import asyncio
import logging
import time

import httpx
from fastapi import HTTPException

from app.core.config import settings

logger = logging.getLogger(__name__)

# Global lock to prevent concurrent switches and generation requests.
# This is safe for a single-worker MVP (uvicorn without --workers).
inference_lock = asyncio.Lock()


class LLMClient:
    def __init__(self):
        self.base_url = settings.llm_manager_base_url.rstrip("/")
        self.api_key = settings.llm_manager_api_key
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    async def get_status(self):
        """Fetch the current global state of llm-manager."""
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(
                    f"{self.base_url}/status",
                    headers=self.headers,
                    timeout=10.0,
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPError as e:
                logger.error(f"Failed to fetch llm-manager status: {e}")
                raise HTTPException(status_code=502, detail="llm-manager status check failed")

    async def switch_model(self, model_name: str):
        """Request llm-manager to switch its active model."""
        async with httpx.AsyncClient() as client:
            try:
                logger.info(f"Requesting llm-manager switch to model: {model_name}")
                response = await client.post(
                    f"{self.base_url}/switch/{model_name}",
                    headers=self.headers,
                    timeout=60.0,  # Switching can take a while via llm-manager
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPError as e:
                logger.error(f"Failed to switch model to {model_name}: {e}")
                raise HTTPException(status_code=502, detail=f"Failed to switch model to {model_name}")

    async def wait_for_model_ready(self, model_name: str, timeout: float = 60.0, poll_interval: float = 2.0):
        """Wait for the model to be active and not loading/unloading."""
        start_time = time.time()
        iterations = 0
        while time.time() - start_time < timeout:
            iterations += 1
            status = await self.get_status()
            current_model = status.get("active_model")
            vram_state = status.get("vram_state", "")
            logger.info(f"Readiness poll #{iterations}: model={current_model}, vram_state={vram_state}")
            if current_model == model_name and vram_state not in ("loading", "unloading"):
                return True, iterations, status
            await asyncio.sleep(poll_interval)
        return False, iterations, None

    async def chat_completion(self, messages: list, max_tokens: int | None = None, temperature: float | None = None):
        """Generate a response via llm-manager."""
        async with httpx.AsyncClient() as client:
            try:
                payload = {
                    "messages": messages,
                    "stream": False,
                }
                if max_tokens is not None:
                    payload["max_tokens"] = max_tokens
                if temperature is not None:
                    payload["temperature"] = temperature
                response = await client.post(
                    f"{self.base_url}/v1/chat/completions",
                    headers=self.headers,
                    json=payload,
                    timeout=120.0,
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPError as e:
                logger.error(f"Failed to generate chat completion: {e}")
                raise HTTPException(status_code=502, detail="Chat completion generation failed")


llm_client = LLMClient()
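

# --- Usage sketch (illustrative only, not part of the client) ---------------
# A minimal example of how this module is meant to be used from an async
# caller such as a FastAPI route: hold inference_lock across the switch,
# readiness poll, and generation so concurrent requests cannot race a model
# swap. Assumptions: the model name "example-model" is hypothetical, and the
# response is assumed to follow the OpenAI-style /v1/chat/completions shape
# (choices[0].message.content); adjust to whatever llm-manager returns.
async def _example_generate(prompt: str) -> str:
    async with inference_lock:  # serialize switch + generation across requests
        await llm_client.switch_model("example-model")
        ready, polls, _status = await llm_client.wait_for_model_ready("example-model")
        if not ready:
            raise HTTPException(status_code=504, detail="Model did not become ready in time")
        result = await llm_client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=512,
        )
        return result["choices"][0]["message"]["content"]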