"""Model providers, a provider registry, and a fallback-aware model router.

A provider turns a provider-agnostic ``CompletionRequest`` into a raw
chat-completion payload. The router walks an ordered list of candidate
provider names and returns the first successful completion, collecting a
reason for every candidate it skipped.
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any
from urllib import error, request

from config import ServerConfig
from oauth import OAuthError, QwenOAuthManager


class ModelProviderError(RuntimeError):
    """Raised when a provider, or the router as a whole, cannot complete a request."""


@dataclass(slots=True)
class ProviderCapabilities:
    """Feature flags a provider advertises; the router checks ``tool_calling``."""

    tool_calling: bool = False
    web_search: bool = False
    oauth_auth: bool = False


@dataclass(slots=True)
class CompletionRequest:
    """A provider-agnostic chat-completion request.

    Attributes:
        messages: Chat messages forwarded to the provider as-is.
        tools: Tool definitions forwarded to the provider as-is.
        tool_choice: Value sent as ``tool_choice`` when ``tools`` is non-empty.
        preferred_provider: Provider name to try before the configured default.
        require_tools: When true, providers without ``tool_calling`` are skipped.
    """

    messages: list[dict[str, Any]]
    tools: list[dict[str, Any]]
    tool_choice: str = "auto"
    preferred_provider: str | None = None
    require_tools: bool = True


@dataclass(slots=True)
class CompletionResponse:
    """Result of a routed completion.

    Attributes:
        provider_name: Name of the provider that produced ``payload``.
        model_name: Model reported by that provider.
        payload: Decoded response body returned by the provider.
        selection_reason: Human-readable explanation of why this provider won.
        attempted: Registered provider names considered, in order, including
            the winner (unknown names are not listed).
    """

    provider_name: str
    model_name: str
    payload: dict[str, Any]
    selection_reason: str
    attempted: list[str] = field(default_factory=list)


class BaseModelProvider:
    """Interface every model provider implements.

    Subclasses set ``name`` and ``capabilities`` and override the four
    methods below.
    """

    name = "base"
    capabilities = ProviderCapabilities()

    def __init__(self, config: ServerConfig) -> None:
        self.config = config

    def is_available(self) -> bool:
        """Return whether the provider can currently serve requests."""
        raise NotImplementedError

    def unavailable_reason(self) -> str | None:
        """Return why the provider is unavailable, or ``None`` when it is available."""
        raise NotImplementedError

    def model_name(self) -> str:
        """Return the model identifier this provider targets."""
        raise NotImplementedError

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Run a completion and return the decoded response payload.

        Implementations should raise ``ModelProviderError`` on failure; that
        is the only exception ``ModelRouter`` catches to fall back to the next
        candidate.
        """
        raise NotImplementedError


class UnavailableModelProvider(BaseModelProvider):
    """Placeholder provider that is registered but can never serve requests."""

    def __init__(
        self,
        config: ServerConfig,
        *,
        name: str,
        model_name: str,
        reason: str,
        capabilities: ProviderCapabilities | None = None,
    ) -> None:
        super().__init__(config)
        self.name = name
        self._model_name = model_name
        self._reason = reason
        if capabilities is not None:
            self.capabilities = capabilities

    def is_available(self) -> bool:
        return False

    def unavailable_reason(self) -> str | None:
        return self._reason

    def model_name(self) -> str:
        return self._model_name

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Always raise ``ModelProviderError`` carrying the configured reason."""
        raise ModelProviderError(f"Provider {self.name} is unavailable: {self._reason}")


class QwenModelProvider(BaseModelProvider):
    """Qwen provider that POSTs to ``<base_url>/chat/completions`` with an OAuth bearer token."""

    name = "qwen"
    capabilities = ProviderCapabilities(
        tool_calling=True,
        web_search=True,
        oauth_auth=True,
    )
    # Seconds passed to ``urlopen``; override on a subclass or instance if needed.
    request_timeout: float = 180

    def __init__(self, config: ServerConfig, oauth: QwenOAuthManager) -> None:
        super().__init__(config)
        self.oauth = oauth

    def is_available(self) -> bool:
        """Return true when stored OAuth credentials contain an access token."""
        creds = self.oauth.load_credentials()
        return bool(creds and creds.get("access_token"))

    def unavailable_reason(self) -> str | None:
        if self.is_available():
            return None
        return "Qwen OAuth is not configured"

    def model_name(self) -> str:
        return self.config.model

    def _build_payload(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Build the JSON body for the chat-completions request."""
        payload: dict[str, Any] = {
            "model": self.model_name(),
            "messages": completion_request.messages,
        }
        # OpenAI-style APIs reject an empty ``tools`` array, and ``tool_choice``
        # is meaningless (and rejected) without tools, so send both only together.
        if completion_request.tools:
            payload["tools"] = completion_request.tools
            payload["tool_choice"] = completion_request.tool_choice
        return payload

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Send the request and return the decoded JSON response.

        Raises:
            ModelProviderError: on OAuth failure, HTTP error status, network
                error or timeout, or a response body that is not valid JSON.
        """
        # Credential refresh can fail; wrap it so the router can fall back
        # instead of crashing on a raw OAuthError.
        try:
            creds = self.oauth.get_valid_credentials()
            base_url = self.oauth.get_openai_base_url(creds)
        except OAuthError as exc:
            raise ModelProviderError(str(exc)) from exc

        data = json.dumps(self._build_payload(completion_request)).encode("utf-8")
        req = request.Request(
            f"{base_url}/chat/completions",
            data=data,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {creds['access_token']}",
            },
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=self.request_timeout) as response:
                raw_body = response.read()
        except error.HTTPError as exc:
            # HTTPError subclasses URLError/OSError, so it must be handled first.
            body = exc.read().decode("utf-8", errors="replace")
            raise ModelProviderError(
                f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
            ) from exc
        except OSError as exc:
            # URLError (DNS/connection failures), TimeoutError, connection resets.
            raise ModelProviderError(f"Provider {self.name} request failed: {exc}") from exc

        try:
            return json.loads(raw_body.decode("utf-8"))
        except ValueError as exc:
            # Covers both UnicodeDecodeError and json.JSONDecodeError.
            raise ModelProviderError(
                f"Provider {self.name} returned an invalid response: {exc}"
            ) from exc


class ProviderRegistry:
    """Name-indexed collection of model providers."""

    def __init__(self, providers: list[BaseModelProvider]) -> None:
        # A later provider with a duplicate name replaces the earlier one.
        self._providers = {provider.name: provider for provider in providers}

    def get(self, name: str) -> BaseModelProvider | None:
        """Return the provider registered under ``name``, or ``None``."""
        return self._providers.get(name)

    def list_names(self) -> list[str]:
        """Return registered provider names in alphabetical order."""
        return sorted(self._providers.keys())

    def statuses(self) -> list[dict[str, Any]]:
        """Return a status summary per provider, in registration order."""
        return [
            {
                "name": provider.name,
                "model": provider.model_name(),
                "available": provider.is_available(),
                "reason": provider.unavailable_reason(),
                "capabilities": {
                    "tool_calling": provider.capabilities.tool_calling,
                    "web_search": provider.capabilities.web_search,
                    "oauth_auth": provider.capabilities.oauth_auth,
                },
            }
            for provider in self._providers.values()
        ]


class ModelRouter:
    """Routes completion requests to the first usable provider in a fixed order."""

    def __init__(self, config: ServerConfig, registry: ProviderRegistry) -> None:
        self.config = config
        self.registry = registry

    def _candidate_names(self, preferred_provider: str | None) -> list[str]:
        """Return provider names to try, most preferred first, without duplicates.

        Order: the requested provider, the configured default, the configured
        fallbacks (empty names are ignored), then every other registered
        provider alphabetically.
        """
        configured = [
            name
            for name in (
                preferred_provider,
                self.config.default_provider,
                *(self.config.fallback_providers or ()),
            )
            if name
        ]
        # dict.fromkeys keeps first-occurrence order while dropping duplicates.
        return list(dict.fromkeys([*configured, *self.registry.list_names()]))

    def _supports_request(
        self,
        provider: BaseModelProvider,
        completion_request: CompletionRequest,
    ) -> bool:
        """Return whether ``provider`` has the capabilities the request requires."""
        if completion_request.require_tools and not provider.capabilities.tool_calling:
            return False
        return True

    def _selection_reason(self, name: str, preferred_provider: str | None) -> str:
        """Explain why ``name`` is being used, for ``CompletionResponse.selection_reason``."""
        if preferred_provider:
            if name == preferred_provider:
                return "selected preferred provider"
            # Any earlier skip (unknown, unavailable, incapable, failed) lands here.
            return f"preferred provider unavailable, fell back to {name}"
        if name == self.config.default_provider:
            return "selected default provider"
        return f"default provider unavailable, fell back to {name}"

    def complete(self, completion_request: CompletionRequest) -> CompletionResponse:
        """Complete the request with the first candidate provider that succeeds.

        Raises:
            ModelProviderError: if every candidate is unknown, unavailable,
                lacks required capabilities, or fails its request; the message
                lists the per-candidate reasons.
        """
        attempted: list[str] = []
        reasons: list[str] = []
        preferred = completion_request.preferred_provider
        for name in self._candidate_names(preferred):
            provider = self.registry.get(name)
            if provider is None:
                reasons.append(f"{name}: unknown provider")
                continue
            attempted.append(name)
            if not provider.is_available():
                reason = provider.unavailable_reason() or "provider unavailable"
                reasons.append(f"{name}: {reason}")
                continue
            if not self._supports_request(provider, completion_request):
                reasons.append(f"{name}: missing required capabilities")
                continue
            selection_reason = self._selection_reason(name, preferred)
            try:
                payload = provider.complete(completion_request)
            except ModelProviderError as exc:
                reasons.append(f"{name}: request failed: {exc}")
                continue
            return CompletionResponse(
                provider_name=name,
                model_name=provider.model_name(),
                payload=payload,
                selection_reason=selection_reason,
                attempted=attempted,
            )
        details = "; ".join(reasons) if reasons else "no providers registered"
        raise ModelProviderError(f"No model provider available: {details}")


def build_provider_registry(config: ServerConfig, oauth: QwenOAuthManager) -> ProviderRegistry:
    """Build the registry: a working Qwen provider plus unimplemented placeholders."""
    providers: list[BaseModelProvider] = [
        QwenModelProvider(config, oauth),
        UnavailableModelProvider(
            config,
            name="gigachat",
            model_name=config.gigachat_model,
            reason="GigaChat provider is not implemented yet",
            capabilities=ProviderCapabilities(),
        ),
        UnavailableModelProvider(
            config,
            name="yandexgpt",
            model_name=config.yandexgpt_model,
            reason="YandexGPT provider is not implemented yet",
            capabilities=ProviderCapabilities(),
        ),
    ]
    return ProviderRegistry(providers)