ducklm/duck_core/model_client.py

218 lines
8.1 KiB
Python

import json
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import httpx
import yaml
# Module-level logger; ModelClient.chat reports per-request latency and token usage through it.
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RoleConfig:
role: str
provider: str
base_url: str
model: str
purpose: str
structured_output: bool
temperature: float
max_output_tokens: int
system_prompt: str
response_schema: str | None = None
@dataclass
class ModelResponse:
role: str
model: str
content: str
reasoning_content: str | None
raw: dict[str, Any]
latency_ms: float
prompt_tokens: int | None = None
completion_tokens: int | None = None
total_tokens: int | None = None
class ModelClient:
    """Async client for role-configured chat-completion backends.

    Role settings are read once from a YAML file at construction time. Each
    role supplies a backend ``base_url`` (to which ``/chat/completions`` and
    ``/models`` are appended), a model name, sampling defaults, a system-prompt
    file and an optional JSON-schema file for structured output.
    """

    def __init__(self, config_path: str = "config/models.yaml", timeout: float = 120.0):
        """Load the role definitions from *config_path*.

        Args:
            config_path: YAML file containing a ``default_provider`` key and a
                ``models`` mapping of role name to ``RoleConfig`` fields.
            timeout: Timeout in seconds passed to ``httpx.AsyncClient`` for
                chat and streaming requests.

        Raises:
            OSError: If the config file cannot be read.
            KeyError: If ``default_provider`` or ``models`` is missing.
            TypeError: If a role's settings do not match ``RoleConfig``.
        """
        self.config_path = Path(config_path)
        self.timeout = timeout
        # Explicit UTF-8 so the result does not depend on the host locale.
        data = yaml.safe_load(self.config_path.read_text(encoding="utf-8"))
        self.default_provider = data["default_provider"]
        self._roles = {
            role: RoleConfig(role=role, **settings)
            for role, settings in data["models"].items()
        }

    def list_roles(self) -> dict[str, dict[str, Any]]:
        """Return every role's settings as plain dicts keyed by role name."""
        return {
            role: {
                "provider": cfg.provider,
                "base_url": cfg.base_url,
                "model": cfg.model,
                "purpose": cfg.purpose,
                "structured_output": cfg.structured_output,
                "temperature": cfg.temperature,
                "max_output_tokens": cfg.max_output_tokens,
                "system_prompt": cfg.system_prompt,
                "response_schema": cfg.response_schema,
            }
            for role, cfg in self._roles.items()
        }

    def get_role_config(self, role: str) -> RoleConfig:
        """Return the configuration for *role*.

        Raises:
            KeyError: If *role* is not defined in the loaded config.
        """
        try:
            return self._roles[role]
        except KeyError as exc:
            raise KeyError(f"Unknown model role: {role}") from exc

    def _system_message(self, cfg: RoleConfig) -> dict[str, str] | None:
        """Build the system message from the role's prompt file.

        Returns ``None`` when no prompt path is configured or the path is not
        a regular file.
        """
        # Path("") resolves to ".", which exists, so test for a regular file
        # rather than mere existence; otherwise an empty or directory value
        # would make read_text() raise.
        if not cfg.system_prompt:
            return None
        path = Path(cfg.system_prompt)
        if not path.is_file():
            return None
        return {"role": "system", "content": path.read_text(encoding="utf-8")}

    def _response_format(
        self, cfg: RoleConfig, response_format: dict[str, Any] | None
    ) -> dict[str, Any] | None:
        """Choose the ``response_format`` value for a request.

        An explicit *response_format* always wins. Otherwise roles with
        ``structured_output`` get a ``json_schema`` format when their schema
        file is present and a plain ``json_object`` format when it is not;
        all other roles get ``None`` (the field is omitted).
        """
        if response_format is not None:
            return response_format
        if not cfg.structured_output:
            return None
        if cfg.response_schema:
            schema_path = Path(cfg.response_schema)
            # is_file() rather than exists(): a directory cannot be parsed.
            if schema_path.is_file():
                schema = json.loads(schema_path.read_text(encoding="utf-8"))
                return {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "action_directive",
                        "schema": schema,
                        "strict": True,
                    },
                }
        return {"type": "json_object"}

    def _build_payload(
        self,
        cfg: RoleConfig,
        messages: list[dict[str, str]],
        temperature: float | None,
        max_output_tokens: int | None,
        response_format: dict[str, Any] | None,
        *,
        stream: bool = False,
    ) -> dict[str, Any]:
        """Assemble the JSON request body shared by ``chat`` and ``stream_chat``.

        The caller's *messages* list is copied, never mutated. The role's
        system prompt is prepended only when the caller supplied no system
        message. ``None`` for *temperature* or *max_output_tokens* selects the
        role's configured default.
        """
        outbound = list(messages)
        system_message = self._system_message(cfg)
        if system_message and not any(message["role"] == "system" for message in outbound):
            outbound.insert(0, system_message)
        payload: dict[str, Any] = {
            "model": cfg.model,
            "messages": outbound,
            "temperature": cfg.temperature if temperature is None else temperature,
            "max_tokens": (
                cfg.max_output_tokens if max_output_tokens is None else max_output_tokens
            ),
        }
        if stream:
            payload["stream"] = True
        fmt = self._response_format(cfg, response_format)
        if fmt is not None:
            payload["response_format"] = fmt
        return payload

    @staticmethod
    def _first_choice(raw: dict[str, Any]) -> dict[str, Any]:
        """Return the first entry of ``raw["choices"]``, or ``{}`` if there is none.

        Tolerates a missing key, ``null`` and an empty list, so a response
        without choices does not raise ``IndexError``.
        """
        choices = raw.get("choices") or [{}]
        return choices[0] or {}

    @staticmethod
    def _sse_data(line: str) -> str | None:
        """Extract the payload of an event-stream ``data:`` line.

        Returns ``None`` for lines that are not data lines or carry no
        payload. The space after the colon is optional.
        """
        if not line.startswith("data:"):
            return None
        return line.removeprefix("data:").strip() or None

    @staticmethod
    def _delta_events(chunk: dict[str, Any]) -> list[dict[str, str]]:
        """Convert one decoded stream chunk into zero, one or two delta events.

        Reasoning text is emitted before content text. Chunks without
        choices or without a delta produce no events.
        """
        delta = ModelClient._first_choice(chunk).get("delta") or {}
        events: list[dict[str, str]] = []
        reasoning_delta = delta.get("reasoning_content")
        if reasoning_delta:
            events.append({"type": "reasoning_delta", "delta": reasoning_delta})
        content_delta = delta.get("content")
        if content_delta:
            events.append({"type": "content_delta", "delta": content_delta})
        return events

    async def chat(
        self,
        role: str,
        messages: list[dict[str, str]],
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        response_format: dict[str, Any] | None = None,
    ) -> ModelResponse:
        """Send a non-streaming chat request for *role* and return the reply.

        Args:
            role: Name of a configured role.
            messages: Conversation messages; each needs a ``"role"`` key.
            temperature: Overrides the role's temperature when not ``None``.
            max_output_tokens: Overrides the role's token limit when not ``None``.
            response_format: Explicit ``response_format``; overrides the
                role's structured-output settings when not ``None``.

        Returns:
            A ``ModelResponse`` with the first choice's content, timing and
            token usage.

        Raises:
            KeyError: If *role* is unknown.
            ConnectionError: If the request fails or the backend answers with
                an error status (any ``httpx.HTTPError``).
        """
        cfg = self.get_role_config(role)
        payload = self._build_payload(
            cfg, messages, temperature, max_output_tokens, response_format
        )
        start = time.perf_counter()
        try:
            async with httpx.AsyncClient(timeout=self.timeout, trust_env=False) as client:
                response = await client.post(f"{cfg.base_url}/chat/completions", json=payload)
                response.raise_for_status()
                raw = response.json()
        except httpx.HTTPError as exc:
            raise ConnectionError(f"Model backend unavailable for role {role}: {exc}") from exc
        latency_ms = (time.perf_counter() - start) * 1000
        usage = raw.get("usage") or {}
        message = self._first_choice(raw).get("message") or {}
        content = message.get("content") or ""
        reasoning_content = message.get("reasoning_content")
        logger.info(
            "model role=%s model=%s latency_ms=%.1f usage=%s",
            role,
            cfg.model,
            latency_ms,
            usage,
        )
        return ModelResponse(
            role=role,
            model=cfg.model,
            content=content,
            reasoning_content=reasoning_content,
            raw=raw,
            latency_ms=latency_ms,
            prompt_tokens=usage.get("prompt_tokens"),
            completion_tokens=usage.get("completion_tokens"),
            total_tokens=usage.get("total_tokens"),
        )

    async def stream_chat(
        self,
        role: str,
        messages: list[dict[str, str]],
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        response_format: dict[str, Any] | None = None,
    ):
        """Stream a chat reply for *role* as incremental delta events.

        Takes the same arguments as ``chat``. This is an async generator, so
        nothing runs (and nothing is raised) until iteration starts.

        Yields:
            ``{"type": "reasoning_delta", "delta": str}`` and
            ``{"type": "content_delta", "delta": str}`` dicts, in arrival
            order. Iteration ends at the ``[DONE]`` marker or when the
            response body ends.

        Raises:
            KeyError: If *role* is unknown.
            ConnectionError: If the request fails or the backend answers with
                an error status (any ``httpx.HTTPError``).
        """
        cfg = self.get_role_config(role)
        payload = self._build_payload(
            cfg, messages, temperature, max_output_tokens, response_format, stream=True
        )
        try:
            async with httpx.AsyncClient(timeout=self.timeout, trust_env=False) as client:
                async with client.stream(
                    "POST", f"{cfg.base_url}/chat/completions", json=payload
                ) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        raw_data = self._sse_data(line)
                        if raw_data is None:
                            continue
                        if raw_data == "[DONE]":
                            break
                        for event in self._delta_events(json.loads(raw_data)):
                            yield event
        except httpx.HTTPError as exc:
            raise ConnectionError(f"Model backend unavailable for role {role}: {exc}") from exc

    async def ping(self) -> dict[str, Any]:
        """Probe ``{base_url}/models`` for every role and report reachability.

        Returns:
            A dict keyed by role name. Successful probes carry ``ok=True`` and
            ``latency_ms``; failed probes carry ``ok=False`` and ``error``.
            Both include ``base_url`` and ``model``. HTTP failures are
            reported in the result, not raised.
        """
        results: dict[str, Any] = {}
        # Fixed 10 s timeout: health probes should not wait as long as chats.
        async with httpx.AsyncClient(timeout=10.0, trust_env=False) as client:
            for role, cfg in self._roles.items():
                try:
                    started = time.perf_counter()
                    response = await client.get(f"{cfg.base_url}/models")
                    response.raise_for_status()
                    results[role] = {
                        "ok": True,
                        "base_url": cfg.base_url,
                        "model": cfg.model,
                        "latency_ms": round((time.perf_counter() - started) * 1000, 1),
                    }
                except httpx.HTTPError as exc:
                    results[role] = {
                        "ok": False,
                        "base_url": cfg.base_url,
                        "model": cfg.model,
                        "error": str(exc),
                    }
        return results