import asyncio
import logging
import time
from typing import Optional

import httpx
from fastapi import HTTPException

from app.core.config import settings
logger = logging.getLogger(__name__)
|
|
|
|
# Global lock to prevent concurrent switches and generation requests
|
|
# This is safe for a single-worker MVP (uvicorn without --workers)
|
|
inference_lock = asyncio.Lock()
|
|
|
|
class LLMClient:
|
|
def __init__(self):
|
|
self.base_url = settings.llm_manager_base_url.rstrip("/")
|
|
self.api_key = settings.llm_manager_api_key
|
|
self.headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json"
|
|
}
|
|
|
|
async def get_status(self):
|
|
"""Fetch the current global state of llm-manager."""
|
|
async with httpx.AsyncClient() as client:
|
|
try:
|
|
response = await client.get(
|
|
f"{self.base_url}/status",
|
|
headers=self.headers,
|
|
timeout=10.0
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
except httpx.HTTPError as e:
|
|
logger.error(f"Failed to fetch llm-manager status: {e}")
|
|
raise HTTPException(status_code=502, detail="llm-manager status check failed")
|
|
|
|
async def switch_model(self, model_name: str):
|
|
"""Request llm-manager to switch its active model."""
|
|
async with httpx.AsyncClient() as client:
|
|
try:
|
|
logger.info(f"Requesting llm-manager switch to model: {model_name}")
|
|
response = await client.post(
|
|
f"{self.base_url}/switch/{model_name}",
|
|
headers=self.headers,
|
|
timeout=60.0 # Switching can take a while via LLM manager
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
except httpx.HTTPError as e:
|
|
logger.error(f"Failed to switch model to {model_name}: {e}")
|
|
raise HTTPException(status_code=502, detail=f"Failed to switch model to {model_name}")
|
|
|
|
async def wait_for_model_ready(self, model_name: str, timeout: float = 60.0, poll_interval: float = 2.0):
|
|
"""Wait for the model to be active and not loading/unloading."""
|
|
import time
|
|
start_time = time.time()
|
|
iterations = 0
|
|
while time.time() - start_time < timeout:
|
|
iterations += 1
|
|
status = await self.get_status()
|
|
current_model = status.get("active_model")
|
|
vram_state = status.get("vram_state", "")
|
|
|
|
logger.info(f"Readiness poll #{iterations}: model={current_model}, vram_state={vram_state}")
|
|
|
|
if current_model == model_name and vram_state not in ("loading", "unloading"):
|
|
return True, iterations, status
|
|
|
|
await asyncio.sleep(poll_interval)
|
|
|
|
return False, iterations, None
|
|
|
|
async def chat_completion(self, messages: list, max_tokens: int = None, temperature: float = None):
|
|
"""Generate response via llm-manager."""
|
|
async with httpx.AsyncClient() as client:
|
|
try:
|
|
payload = {
|
|
"messages": messages,
|
|
"stream": False
|
|
}
|
|
if max_tokens is not None:
|
|
payload["max_tokens"] = max_tokens
|
|
if temperature is not None:
|
|
payload["temperature"] = temperature
|
|
|
|
response = await client.post(
|
|
f"{self.base_url}/v1/chat/completions",
|
|
headers=self.headers,
|
|
json=payload,
|
|
timeout=120.0
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
except httpx.HTTPError as e:
|
|
logger.error(f"Failed to generate chat completion: {e}")
|
|
raise HTTPException(status_code=502, detail="Chat completion generation failed")
|
|
|
|
llm_client = LLMClient()
|