218 lines
8.1 KiB
Python
218 lines
8.1 KiB
Python
import json
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
import yaml
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class RoleConfig:
|
|
role: str
|
|
provider: str
|
|
base_url: str
|
|
model: str
|
|
purpose: str
|
|
structured_output: bool
|
|
temperature: float
|
|
max_output_tokens: int
|
|
system_prompt: str
|
|
response_schema: str | None = None
|
|
|
|
|
|
@dataclass
|
|
class ModelResponse:
|
|
role: str
|
|
model: str
|
|
content: str
|
|
reasoning_content: str | None
|
|
raw: dict[str, Any]
|
|
latency_ms: float
|
|
prompt_tokens: int | None = None
|
|
completion_tokens: int | None = None
|
|
total_tokens: int | None = None
|
|
|
|
|
|
class ModelClient:
    """Async client for role-configured, OpenAI-style chat completion backends.

    Role settings are read once from a YAML file when the client is created.
    Each role names a backend ``base_url`` and model together with request
    defaults; chat requests are posted to ``{base_url}/chat/completions``.
    """

    def __init__(self, config_path: str = "config/models.yaml", timeout: float = 120.0):
        """Load the role definitions from ``config_path``.

        Args:
            config_path: YAML file containing a ``default_provider`` entry and
                a ``models`` mapping of role name to ``RoleConfig`` fields.
            timeout: Timeout handed to ``httpx.AsyncClient`` for chat requests.

        Raises:
            OSError: If the config file cannot be read.
            KeyError: If ``default_provider`` or ``models`` is missing.
            TypeError: If a role's settings do not match ``RoleConfig``.
        """
        self.config_path = Path(config_path)
        self.timeout = timeout
        # Read as UTF-8 explicitly so parsing does not depend on the locale.
        data = yaml.safe_load(self.config_path.read_text(encoding="utf-8"))
        self.default_provider = data["default_provider"]
        self._roles = {
            role: RoleConfig(role=role, **settings)
            for role, settings in data["models"].items()
        }

    def list_roles(self) -> dict[str, dict[str, Any]]:
        """Return every role's settings as plain dicts keyed by role name."""
        return {
            role: {
                "provider": cfg.provider,
                "base_url": cfg.base_url,
                "model": cfg.model,
                "purpose": cfg.purpose,
                "structured_output": cfg.structured_output,
                "temperature": cfg.temperature,
                "max_output_tokens": cfg.max_output_tokens,
                "system_prompt": cfg.system_prompt,
                "response_schema": cfg.response_schema,
            }
            for role, cfg in self._roles.items()
        }

    def get_role_config(self, role: str) -> RoleConfig:
        """Return the configuration for ``role``.

        Raises:
            KeyError: If no role with that name was loaded.
        """
        try:
            return self._roles[role]
        except KeyError as exc:
            raise KeyError(f"Unknown model role: {role}") from exc

    @staticmethod
    def _url(cfg: RoleConfig, path: str) -> str:
        """Join ``cfg.base_url`` and ``path`` with exactly one slash between them."""
        # A configured base_url ending in "/" would otherwise yield "//path".
        return f"{cfg.base_url.rstrip('/')}/{path}"

    @staticmethod
    def _first_choice(body: dict[str, Any]) -> dict[str, Any]:
        """Return the first ``choices`` entry of ``body``, or ``{}`` if there is none."""
        # ``choices`` may be missing, null, or an empty list (for example in a
        # streaming chunk that only carries usage); none of those may raise.
        choices = body.get("choices") or [{}]
        return choices[0] or {}

    @staticmethod
    def _sse_data(line: str) -> str | None:
        """Return the payload of an SSE ``data`` line, or ``None`` for any other line."""
        # The space after the colon is optional in the event-stream format.
        if not line.startswith("data:"):
            return None
        return line[len("data:"):].strip()

    def _system_message(self, cfg: RoleConfig) -> dict[str, str] | None:
        """Build the system message from the role's prompt file.

        Returns:
            A ``{"role": "system", "content": ...}`` dict, or ``None`` when no
            prompt path is configured or it does not point to a regular file.
        """
        if not cfg.system_prompt:
            # An empty path would resolve to the current directory.
            return None
        path = Path(cfg.system_prompt)
        if not path.is_file():
            return None
        return {"role": "system", "content": path.read_text(encoding="utf-8")}

    def _response_format(
        self, cfg: RoleConfig, response_format: dict[str, Any] | None
    ) -> dict[str, Any] | None:
        """Pick the ``response_format`` value for a request.

        An explicit ``response_format`` always wins. Otherwise roles without
        structured output get ``None``; roles with a readable schema file get a
        strict ``json_schema`` format; all others get ``{"type": "json_object"}``.
        """
        if response_format is not None:
            return response_format
        if not cfg.structured_output:
            return None
        if cfg.response_schema:
            schema_path = Path(cfg.response_schema)
            if schema_path.is_file():
                schema = json.loads(schema_path.read_text(encoding="utf-8"))
                return {
                    "type": "json_schema",
                    "json_schema": {"name": "action_directive", "schema": schema, "strict": True},
                }
        return {"type": "json_object"}

    def _build_payload(
        self,
        cfg: RoleConfig,
        messages: list[dict[str, str]],
        temperature: float | None,
        max_output_tokens: int | None,
        response_format: dict[str, Any] | None,
        *,
        stream: bool = False,
    ) -> dict[str, Any]:
        """Assemble the JSON body shared by ``chat`` and ``stream_chat``."""
        # Copy so the caller's list is never mutated by the system-message insert.
        outbound = list(messages)
        system_message = self._system_message(cfg)
        # The role's prompt is only used when the caller supplied no system message.
        if system_message and not any(m.get("role") == "system" for m in outbound):
            outbound.insert(0, system_message)

        payload: dict[str, Any] = {
            "model": cfg.model,
            "messages": outbound,
            "temperature": cfg.temperature if temperature is None else temperature,
            "max_tokens": cfg.max_output_tokens if max_output_tokens is None else max_output_tokens,
        }
        if stream:
            payload["stream"] = True
        fmt = self._response_format(cfg, response_format)
        if fmt is not None:
            payload["response_format"] = fmt
        return payload

    async def chat(
        self,
        role: str,
        messages: list[dict[str, str]],
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        response_format: dict[str, Any] | None = None,
    ) -> ModelResponse:
        """Send one chat completion request and return the parsed result.

        Args:
            role: Name of a configured role.
            messages: Chat messages; the role's system prompt is prepended
                when none of them has role ``"system"``.
            temperature: Overrides the role's temperature when not ``None``.
            max_output_tokens: Overrides the role's token limit when not ``None``.
            response_format: Overrides the role's default response format.

        Raises:
            KeyError: If ``role`` is not configured.
            ConnectionError: If the request fails or returns an error status.
        """
        cfg = self.get_role_config(role)
        payload = self._build_payload(cfg, messages, temperature, max_output_tokens, response_format)

        start = time.perf_counter()
        try:
            async with httpx.AsyncClient(timeout=self.timeout, trust_env=False) as client:
                response = await client.post(self._url(cfg, "chat/completions"), json=payload)
                response.raise_for_status()
                raw = response.json()
        except httpx.HTTPError as exc:
            raise ConnectionError(f"Model backend unavailable for role {role}: {exc}") from exc

        latency_ms = (time.perf_counter() - start) * 1000
        usage = raw.get("usage") or {}
        message = self._first_choice(raw).get("message") or {}
        content = message.get("content") or ""
        reasoning_content = message.get("reasoning_content")
        logger.info("model role=%s model=%s latency_ms=%.1f usage=%s", role, cfg.model, latency_ms, usage)
        return ModelResponse(
            role=role,
            model=cfg.model,
            content=content,
            reasoning_content=reasoning_content,
            raw=raw,
            latency_ms=latency_ms,
            prompt_tokens=usage.get("prompt_tokens"),
            completion_tokens=usage.get("completion_tokens"),
            total_tokens=usage.get("total_tokens"),
        )

    async def stream_chat(
        self,
        role: str,
        messages: list[dict[str, str]],
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        response_format: dict[str, Any] | None = None,
    ):
        """Stream a chat completion as incremental events.

        Takes the same arguments as ``chat``. Yields dicts of the form
        ``{"type": "reasoning_delta", "delta": str}`` and
        ``{"type": "content_delta", "delta": str}`` until the backend sends
        ``[DONE]`` or closes the stream.

        Raises:
            KeyError: If ``role`` is not configured.
            ConnectionError: If the request fails or returns an error status.
        """
        cfg = self.get_role_config(role)
        payload = self._build_payload(
            cfg, messages, temperature, max_output_tokens, response_format, stream=True
        )

        try:
            async with httpx.AsyncClient(timeout=self.timeout, trust_env=False) as client:
                async with client.stream(
                    "POST", self._url(cfg, "chat/completions"), json=payload
                ) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        raw_data = self._sse_data(line)
                        if not raw_data:
                            # Not a data line, or a data line with no payload.
                            continue
                        if raw_data == "[DONE]":
                            break
                        delta = self._first_choice(json.loads(raw_data)).get("delta") or {}
                        reasoning_delta = delta.get("reasoning_content")
                        content_delta = delta.get("content")
                        if reasoning_delta:
                            yield {"type": "reasoning_delta", "delta": reasoning_delta}
                        if content_delta:
                            yield {"type": "content_delta", "delta": content_delta}
        except httpx.HTTPError as exc:
            raise ConnectionError(f"Model backend unavailable for role {role}: {exc}") from exc

    async def ping(self) -> dict[str, Any]:
        """Probe ``{base_url}/models`` for every role and report reachability.

        Returns:
            A dict keyed by role name. Each value holds ``ok``, ``base_url`` and
            ``model``, plus ``latency_ms`` on success or ``error`` on failure.
        """
        results: dict[str, Any] = {}
        async with httpx.AsyncClient(timeout=10.0, trust_env=False) as client:
            for role, cfg in self._roles.items():
                try:
                    started = time.perf_counter()
                    response = await client.get(self._url(cfg, "models"))
                    response.raise_for_status()
                    results[role] = {
                        "ok": True,
                        "base_url": cfg.base_url,
                        "model": cfg.model,
                        "latency_ms": round((time.perf_counter() - started) * 1000, 1),
                    }
                except httpx.HTTPError as exc:
                    # A failing backend is reported, not raised, so the
                    # remaining roles are still probed.
                    results[role] = {
                        "ok": False,
                        "base_url": cfg.base_url,
                        "model": cfg.model,
                        "error": str(exc),
                    }
        return results
|