from __future__ import annotations import json from dataclasses import dataclass, field from typing import Any import uuid from urllib import error, request from config import ServerConfig from gigachat import GigaChatAuthManager, GigaChatError from oauth import OAuthError, QwenOAuthManager, QWEN_OAUTH_ALLOWED_MODELS class ModelProviderError(RuntimeError): pass @dataclass(slots=True) class ProviderCapabilities: tool_calling: bool = False web_search: bool = False oauth_auth: bool = False @dataclass(slots=True) class CompletionRequest: messages: list[dict[str, Any]] tools: list[dict[str, Any]] tool_choice: str = "auto" preferred_provider: str | None = None require_tools: bool = True @dataclass(slots=True) class CompletionResponse: provider_name: str model_name: str payload: dict[str, Any] selection_reason: str attempted: list[str] = field(default_factory=list) class BaseModelProvider: name = "base" capabilities = ProviderCapabilities() def __init__(self, config: ServerConfig) -> None: self.config = config def is_available(self) -> bool: raise NotImplementedError def unavailable_reason(self) -> str | None: raise NotImplementedError def model_name(self) -> str: raise NotImplementedError def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: raise NotImplementedError class UnavailableModelProvider(BaseModelProvider): def __init__( self, config: ServerConfig, *, name: str, model_name: str, reason: str, capabilities: ProviderCapabilities | None = None, ) -> None: super().__init__(config) self.name = name self._model_name = model_name self._reason = reason if capabilities is not None: self.capabilities = capabilities def is_available(self) -> bool: return False def unavailable_reason(self) -> str | None: return self._reason def model_name(self) -> str: return self._model_name def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: raise ModelProviderError(f"Provider {self.name} is unavailable: {self._reason}") class QwenModelProvider(BaseModelProvider): name = "qwen" capabilities = ProviderCapabilities( tool_calling=True, web_search=True, oauth_auth=True, ) def __init__(self, config: ServerConfig, oauth: QwenOAuthManager) -> None: super().__init__(config) self.oauth = oauth # Resolve model ID to actual model name self._model_id = config.model if config.model in QWEN_OAUTH_ALLOWED_MODELS else "coder-model" self._model_name = oauth.get_model_name_for_id(self._model_id) def is_available(self) -> bool: creds = self.oauth.load_credentials() return bool(creds and creds.get("access_token")) def unavailable_reason(self) -> str | None: if self.is_available(): return None return "Qwen OAuth is not configured" def model_name(self) -> str: return self._model_name @staticmethod def _normalize_content(value: Any) -> list[dict[str, str]]: if isinstance(value, list): normalized: list[dict[str, str]] = [] for item in value: if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str): normalized.append({"type": "text", "text": item["text"]}) if normalized: return normalized text = "" if value is None else str(value) return [{"type": "text", "text": text}] def _normalize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: normalized: list[dict[str, Any]] = [] for message in messages: item = {k: v for k, v in message.items() if k != "content"} item["content"] = self._normalize_content(message.get("content")) normalized.append(item) return normalized def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: creds = self.oauth.get_valid_credentials() base_url = self.oauth.get_openai_base_url(creds) payload = { "model": self.model_name(), "messages": self._normalize_messages(completion_request.messages), "max_tokens": 8000, "metadata": { "sessionId": str(uuid.uuid4()), "promptId": uuid.uuid4().hex[:12], }, "vl_high_resolution_images": True, } if completion_request.tools: payload["tools"] = completion_request.tools data = json.dumps(payload).encode("utf-8") # Add DashScope-specific headers for OAuth tokens user_agent = "QwenCode/unknown (linux; x64)" headers = { "Content-Type": "application/json", "Authorization": f"Bearer {creds['access_token']}", "User-Agent": user_agent, "X-DashScope-CacheControl": "enable", "X-DashScope-UserAgent": user_agent, "X-DashScope-AuthType": "qwen-oauth", "Accept": "application/json", } req = request.Request( f"{base_url}/chat/completions", data=data, headers=headers, method="POST", ) try: with request.urlopen(req, timeout=180) as response: return json.loads(response.read().decode("utf-8")) except error.HTTPError as exc: body = exc.read().decode("utf-8", errors="replace") raise ModelProviderError( f"Provider {self.name} request failed with HTTP {exc.code}: {body}" ) from exc except OAuthError as exc: raise ModelProviderError(str(exc)) from exc class GigaChatModelProvider(BaseModelProvider): name = "gigachat" capabilities = ProviderCapabilities( tool_calling=True, web_search=False, oauth_auth=False, ) def __init__(self, config: ServerConfig, auth: GigaChatAuthManager) -> None: super().__init__(config) self.auth = auth def is_available(self) -> bool: return self.auth.is_configured() def unavailable_reason(self) -> str | None: if self.is_available(): return None return "GigaChat auth key is not configured" def model_name(self) -> str: return self.config.gigachat_model @staticmethod def _convert_tool_schema(tool: dict[str, Any]) -> dict[str, Any]: return dict(tool.get("function") or {}) def _convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: converted: list[dict[str, Any]] = [] for message in messages: role = message.get("role") if role in {"system", "user"}: converted.append( { "role": role, "content": message.get("content", ""), } ) continue if role == "assistant": payload: dict[str, Any] = { "role": "assistant", "content": message.get("content", ""), } tool_calls = message.get("tool_calls") or [] if tool_calls: first_call = tool_calls[0] raw_arguments = first_call.get("function", {}).get("arguments", "{}") if isinstance(raw_arguments, str): try: arguments = json.loads(raw_arguments) except json.JSONDecodeError: arguments = {"raw": raw_arguments} else: arguments = raw_arguments payload["function_call"] = { "name": first_call.get("function", {}).get("name"), "arguments": arguments, } if message.get("functions_state_id"): payload["functions_state_id"] = message["functions_state_id"] converted.append(payload) continue if role == "tool": converted.append( { "role": "function", "name": message.get("name") or "tool_result", "content": message.get("content", ""), } ) return converted def _normalize_response(self, payload: dict[str, Any]) -> dict[str, Any]: choices = payload.get("choices") or [] if not choices: return payload choice = choices[0] message = dict(choice.get("message") or {}) normalized_message: dict[str, Any] = { "role": "assistant", "content": message.get("content", "") or "", } function_call = message.get("function_call") if function_call: arguments = function_call.get("arguments", {}) if not isinstance(arguments, str): arguments = json.dumps(arguments, ensure_ascii=False) normalized_message["tool_calls"] = [ { "id": uuid.uuid4().hex, "type": "function", "function": { "name": function_call.get("name"), "arguments": arguments, }, } ] if message.get("functions_state_id"): normalized_message["functions_state_id"] = message.get("functions_state_id") choice["message"] = normalized_message payload["choices"] = choices return payload def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: try: access_token = self.auth.get_valid_token() except GigaChatError as exc: raise ModelProviderError(str(exc)) from exc api_base = self.config.gigachat_api_base_url.rstrip("/") payload: dict[str, Any] = { "model": self.model_name(), "messages": self._convert_messages(completion_request.messages), } if completion_request.tools: payload["functions"] = [ self._convert_tool_schema(tool) for tool in completion_request.tools ] payload["function_call"] = "auto" data = json.dumps(payload, ensure_ascii=False).encode("utf-8") req = request.Request( f"{api_base}/chat/completions", data=data, headers={ "Content-Type": "application/json", "Accept": "application/json", "Authorization": f"Bearer {access_token}", }, method="POST", ) try: with request.urlopen(req, timeout=180) as response: raw = json.loads(response.read().decode("utf-8")) except error.HTTPError as exc: body = exc.read().decode("utf-8", errors="replace") raise ModelProviderError( f"Provider {self.name} request failed with HTTP {exc.code}: {body}" ) from exc return self._normalize_response(raw) class ProviderRegistry: def __init__(self, providers: list[BaseModelProvider]) -> None: self._providers = {provider.name: provider for provider in providers} def get(self, name: str) -> BaseModelProvider | None: return self._providers.get(name) def list_names(self) -> list[str]: return sorted(self._providers.keys()) def statuses(self) -> list[dict[str, Any]]: return [ { "name": provider.name, "model": provider.model_name(), "available": provider.is_available(), "reason": provider.unavailable_reason(), "capabilities": { "tool_calling": provider.capabilities.tool_calling, "web_search": provider.capabilities.web_search, "oauth_auth": provider.capabilities.oauth_auth, }, } for provider in self._providers.values() ] class ModelRouter: def __init__(self, config: ServerConfig, registry: ProviderRegistry) -> None: self.config = config self.registry = registry def _candidate_names(self, preferred_provider: str | None) -> list[str]: names: list[str] = [] for name in [preferred_provider, self.config.default_provider, *self.config.fallback_providers]: if name and name not in names: names.append(name) for name in self.registry.list_names(): if name not in names: names.append(name) return names def _supports_request( self, provider: BaseModelProvider, completion_request: CompletionRequest, ) -> bool: if completion_request.require_tools and not provider.capabilities.tool_calling: return False return True def complete(self, completion_request: CompletionRequest) -> CompletionResponse: attempted: list[str] = [] reasons: list[str] = [] for name in self._candidate_names(completion_request.preferred_provider): provider = self.registry.get(name) if not provider: reasons.append(f"{name}: unknown provider") continue attempted.append(name) if not provider.is_available(): reason = provider.unavailable_reason() or "provider unavailable" reasons.append(f"{name}: {reason}") continue if not self._supports_request(provider, completion_request): reasons.append(f"{name}: missing required capabilities") continue selection_reason = "selected preferred provider" if completion_request.preferred_provider and name != completion_request.preferred_provider: selection_reason = f"preferred provider unavailable, fell back to {name}" elif not completion_request.preferred_provider and name != self.config.default_provider: selection_reason = f"default provider unavailable, fell back to {name}" elif name == self.config.default_provider: selection_reason = "selected default provider" try: payload = provider.complete(completion_request) except ModelProviderError as exc: reasons.append(f"{name}: request failed: {exc}") continue return CompletionResponse( provider_name=name, model_name=provider.model_name(), payload=payload, selection_reason=selection_reason, attempted=attempted, ) details = "; ".join(reasons) if reasons else "no providers registered" raise ModelProviderError(f"No model provider available: {details}") def build_provider_registry(config: ServerConfig, oauth: QwenOAuthManager) -> ProviderRegistry: providers: list[BaseModelProvider] = [ QwenModelProvider(config, oauth), GigaChatModelProvider(config, GigaChatAuthManager(config)), UnavailableModelProvider( config, name="yandexgpt", model_name=config.yandexgpt_model, reason="YandexGPT provider is not implemented yet", capabilities=ProviderCapabilities(), ), ] return ProviderRegistry(providers)