# chat-frontend/backend/app/core/llm_client.py

import asyncio
import logging
import time

import httpx
from fastapi import HTTPException

from app.core.config import settings

logger = logging.getLogger(__name__)

# Global lock to prevent concurrent switches and generation requests.
# This is safe for a single-worker MVP (uvicorn without --workers).
inference_lock = asyncio.Lock()
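
# Intended usage (a sketch, not enforced by this module -- callers are assumed
# to hold the lock across the whole switch-then-generate sequence):
#
#     async with inference_lock:
#         await llm_client.switch_model(name)
#         await llm_client.wait_for_model_ready(name)
#         result = await llm_client.chat_completion(messages)
#
# Holding the lock for the full sequence prevents a concurrent request from
# swapping the active model out from under an in-flight generation.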


class LLMClient:
    def __init__(self):
        self.base_url = settings.llm_manager_base_url.rstrip("/")
        self.api_key = settings.llm_manager_api_key
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    async def get_status(self) -> dict:
        """Fetch the current global state of llm-manager."""
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(
                    f"{self.base_url}/status",
                    headers=self.headers,
                    timeout=10.0,
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPError as e:
                logger.error(f"Failed to fetch llm-manager status: {e}")
                raise HTTPException(status_code=502, detail="llm-manager status check failed")

    async def switch_model(self, model_name: str) -> dict:
        """Request llm-manager to switch its active model."""
        async with httpx.AsyncClient() as client:
            try:
                logger.info(f"Requesting llm-manager switch to model: {model_name}")
                response = await client.post(
                    f"{self.base_url}/switch/{model_name}",
                    headers=self.headers,
                    timeout=60.0,  # Switching can take a while via the LLM manager.
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPError as e:
                logger.error(f"Failed to switch model to {model_name}: {e}")
                raise HTTPException(status_code=502, detail=f"Failed to switch model to {model_name}")

    async def wait_for_model_ready(
        self, model_name: str, timeout: float = 60.0, poll_interval: float = 2.0
    ) -> tuple[bool, int, dict | None]:
        """Poll /status until the model is active and not loading/unloading.

        Returns (ready, poll_count, status); status is None on timeout.
        """
        start_time = time.time()
        iterations = 0
        while time.time() - start_time < timeout:
            iterations += 1
            status = await self.get_status()
            current_model = status.get("active_model")
            vram_state = status.get("vram_state", "")
            logger.info(f"Readiness poll #{iterations}: model={current_model}, vram_state={vram_state}")
            if current_model == model_name and vram_state not in ("loading", "unloading"):
                return True, iterations, status
            await asyncio.sleep(poll_interval)
        return False, iterations, None

    async def chat_completion(
        self,
        messages: list,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> dict:
        """Generate a response via llm-manager."""
        async with httpx.AsyncClient() as client:
            try:
                payload = {
                    "messages": messages,
                    "stream": False,
                }
                # Only forward sampling parameters the caller actually set.
                if max_tokens is not None:
                    payload["max_tokens"] = max_tokens
                if temperature is not None:
                    payload["temperature"] = temperature
                response = await client.post(
                    f"{self.base_url}/v1/chat/completions",
                    headers=self.headers,
                    json=payload,
                    timeout=120.0,
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPError as e:
                logger.error(f"Failed to generate chat completion: {e}")
                raise HTTPException(status_code=502, detail="Chat completion generation failed")


llm_client = LLMClient()
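
# Example wiring from a FastAPI route (a sketch; the router, endpoint path, and
# request shape below are illustrative assumptions, not part of this module):
#
#     from fastapi import APIRouter, HTTPException
#     from app.core.llm_client import llm_client, inference_lock
#
#     router = APIRouter()
#
#     @router.post("/chat/{model_name}")
#     async def chat(model_name: str, body: dict):
#         async with inference_lock:
#             status = await llm_client.get_status()
#             if status.get("active_model") != model_name:
#                 await llm_client.switch_model(model_name)
#                 ready, _, _ = await llm_client.wait_for_model_ready(model_name)
#                 if not ready:
#                     raise HTTPException(status_code=503, detail="Model not ready")
#             return await llm_client.chat_completion(body["messages"])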