new-qwen/serv/model_router.py

255 lines
8.5 KiB
Python

from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any
from urllib import error, request
from config import ServerConfig
from oauth import OAuthError, QwenOAuthManager
class ModelProviderError(RuntimeError):
    """Raised when a model provider cannot serve a completion request."""
@dataclass(slots=True)
class ProviderCapabilities:
tool_calling: bool = False
web_search: bool = False
oauth_auth: bool = False
@dataclass(slots=True)
class CompletionRequest:
messages: list[dict[str, Any]]
tools: list[dict[str, Any]]
tool_choice: str = "auto"
preferred_provider: str | None = None
require_tools: bool = True
@dataclass(slots=True)
class CompletionResponse:
provider_name: str
model_name: str
payload: dict[str, Any]
selection_reason: str
attempted: list[str] = field(default_factory=list)
class BaseModelProvider:
    """Interface every model provider implements.

    Subclasses set ``name`` / ``capabilities`` and override the four methods
    below; the base implementations only raise ``NotImplementedError``.
    """

    # Key under which ProviderRegistry stores the provider.
    name = "base"
    # No capabilities by default. NOTE(review): this single instance is shared
    # by every subclass/instance that does not override it — avoid mutating it.
    capabilities = ProviderCapabilities()

    def __init__(self, config: ServerConfig) -> None:
        self.config = config

    def is_available(self) -> bool:
        """Report whether the provider can currently accept requests."""
        raise NotImplementedError

    def unavailable_reason(self) -> str | None:
        """Explain why the provider is unavailable, or ``None`` if it is usable."""
        raise NotImplementedError

    def model_name(self) -> str:
        """Return the model identifier this provider sends requests to."""
        raise NotImplementedError

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Run one completion and return the provider's decoded response."""
        raise NotImplementedError
class UnavailableModelProvider(BaseModelProvider):
    """Placeholder provider that is never available.

    Lets a provider be registered (and listed in statuses) before it can
    actually serve requests; ``complete`` always raises ModelProviderError.
    """

    def __init__(
        self,
        config: ServerConfig,
        *,
        name: str,
        model_name: str,
        reason: str,
        capabilities: ProviderCapabilities | None = None,
    ) -> None:
        super().__init__(config)
        self.name = name
        self._model_name = model_name
        self._reason = reason
        # Keep the inherited class-level default unless flags were supplied.
        if capabilities is None:
            return
        self.capabilities = capabilities

    def is_available(self) -> bool:
        """Always ``False``."""
        return False

    def unavailable_reason(self) -> str | None:
        """Return the reason given at construction time."""
        return self._reason

    def model_name(self) -> str:
        """Return the model name given at construction time."""
        return self._model_name

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Always raise ``ModelProviderError`` carrying the configured reason."""
        message = f"Provider {self.name} is unavailable: {self._reason}"
        raise ModelProviderError(message)
class QwenModelProvider(BaseModelProvider):
    """Provider that calls Qwen's OpenAI-compatible ``/chat/completions`` API.

    Requests carry a bearer token obtained from ``QwenOAuthManager``. Every
    failure in ``complete`` is reported as ``ModelProviderError``, so that
    ``ModelRouter`` can fall back to the next candidate. This covers OAuth
    errors, HTTP errors, network errors and malformed bodies.
    """

    name = "qwen"
    capabilities = ProviderCapabilities(
        tool_calling=True,
        web_search=True,
        oauth_auth=True,
    )
    # Seconds to wait on the upstream HTTP call; override per instance/subclass.
    request_timeout: float = 180

    def __init__(self, config: ServerConfig, oauth: QwenOAuthManager) -> None:
        super().__init__(config)
        self.oauth = oauth

    def is_available(self) -> bool:
        """Return True when stored OAuth credentials contain an access token."""
        creds = self.oauth.load_credentials()
        return bool(creds and creds.get("access_token"))

    def unavailable_reason(self) -> str | None:
        """Return ``None`` when available, otherwise why the provider is not."""
        if self.is_available():
            return None
        return "Qwen OAuth is not configured"

    def model_name(self) -> str:
        """Return the model id configured in ``config.model``."""
        return self.config.model

    def _build_payload(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Build the JSON body for one chat-completions call."""
        payload: dict[str, Any] = {
            "model": self.model_name(),
            "messages": completion_request.messages,
        }
        # OpenAI-compatible endpoints reject an empty ``tools`` array and a
        # ``tool_choice`` sent without tools, so include both only when needed.
        if completion_request.tools:
            payload["tools"] = completion_request.tools
            payload["tool_choice"] = completion_request.tool_choice
        return payload

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Send one chat-completions request and return the decoded JSON reply.

        Raises:
            ModelProviderError: on any of the following:
                - an OAuth failure;
                - a non-2xx HTTP status;
                - a network error or timeout;
                - a response body that is not valid UTF-8 JSON.
        """
        # Credential lookup must be inside the try: the router only handles
        # ModelProviderError, so a raw OAuthError would abort fallback.
        # NOTE(review): get_valid_credentials presumably refreshes tokens.
        try:
            creds = self.oauth.get_valid_credentials()
            base_url = self.oauth.get_openai_base_url(creds)
        except OAuthError as exc:
            raise ModelProviderError(str(exc)) from exc

        # Strip a trailing slash so the endpoint never contains "//".
        endpoint = f"{base_url.rstrip('/')}/chat/completions"
        data = json.dumps(self._build_payload(completion_request)).encode("utf-8")
        req = request.Request(
            endpoint,
            data=data,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {creds['access_token']}",
            },
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=self.request_timeout) as response:
                raw = response.read()
        except error.HTTPError as exc:
            body = exc.read().decode("utf-8", errors="replace")
            raise ModelProviderError(
                f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
            ) from exc
        except OSError as exc:
            # URLError (DNS/connect failures), timeouts and connection resets
            # are all OSError subclasses; HTTPError was handled just above.
            reason = getattr(exc, "reason", exc)
            raise ModelProviderError(
                f"Provider {self.name} request failed: {reason}"
            ) from exc

        try:
            return json.loads(raw.decode("utf-8"))
        except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
            raise ModelProviderError(
                f"Provider {self.name} returned an invalid JSON response: {exc}"
            ) from exc
class ProviderRegistry:
    """Name-indexed collection of model providers.

    When two providers share a name, the later one replaces the earlier.
    """

    def __init__(self, providers: list[BaseModelProvider]) -> None:
        by_name: dict[str, BaseModelProvider] = {}
        for entry in providers:
            by_name[entry.name] = entry
        self._providers = by_name

    def get(self, name: str) -> BaseModelProvider | None:
        """Return the provider registered as ``name``, or ``None``."""
        return self._providers.get(name)

    def list_names(self) -> list[str]:
        """Return all registered provider names in alphabetical order."""
        ordered = list(self._providers)
        ordered.sort()
        return ordered

    def statuses(self) -> list[dict[str, Any]]:
        """Describe every provider (registration order) as a JSON-friendly dict."""
        report: list[dict[str, Any]] = []
        for entry in self._providers.values():
            report.append(
                {
                    "name": entry.name,
                    "model": entry.model_name(),
                    "available": entry.is_available(),
                    "reason": entry.unavailable_reason(),
                    "capabilities": {
                        "tool_calling": entry.capabilities.tool_calling,
                        "web_search": entry.capabilities.web_search,
                        "oauth_auth": entry.capabilities.oauth_auth,
                    },
                }
            )
        return report
class ModelRouter:
    """Choose a model provider for each completion request, with fallback.

    Candidates are tried in this order, each name at most once:
        1. the request's preferred provider;
        2. the configured default provider;
        3. the configured fallback providers;
        4. every other registered provider, alphabetically.

    The first candidate that meets all of the following is used:
        - it is available;
        - it supports the request;
        - it completes without raising ``ModelProviderError``.
    """

    def __init__(self, config: ServerConfig, registry: ProviderRegistry) -> None:
        self.config = config
        self.registry = registry

    def _candidate_names(self, preferred_provider: str | None) -> list[str]:
        """Return provider names in try-order, de-duplicated."""
        names: list[str] = []
        configured = [
            preferred_provider,
            self.config.default_provider,
            # Tolerate an unset fallback list rather than failing on ``*None``.
            *(self.config.fallback_providers or ()),
        ]
        for name in configured:
            # Skip unset (None/empty) entries and repeats.
            if name and name not in names:
                names.append(name)
        for name in self.registry.list_names():
            if name not in names:
                names.append(name)
        return names

    def _supports_request(
        self,
        provider: BaseModelProvider,
        completion_request: CompletionRequest,
    ) -> bool:
        """Return False if the request needs tools the provider cannot call."""
        return not (
            completion_request.require_tools and not provider.capabilities.tool_calling
        )

    def _selection_reason(self, name: str, preferred_provider: str | None) -> str:
        """Explain why provider ``name`` was chosen.

        An explicitly requested provider is always reported as "preferred",
        even when it is also the configured default.
        """
        if preferred_provider:
            if name == preferred_provider:
                return "selected preferred provider"
            return f"preferred provider unavailable, fell back to {name}"
        if name == self.config.default_provider:
            return "selected default provider"
        return f"default provider unavailable, fell back to {name}"

    def complete(self, completion_request: CompletionRequest) -> CompletionResponse:
        """Route the request to the first usable provider and return its reply.

        Raises:
            ModelProviderError: if no candidate could serve the request. The
                message lists why each candidate was skipped or failed.
        """
        preferred = completion_request.preferred_provider
        attempted: list[str] = []
        reasons: list[str] = []
        for name in self._candidate_names(preferred):
            provider = self.registry.get(name)
            if provider is None:
                reasons.append(f"{name}: unknown provider")
                continue
            attempted.append(name)
            if not provider.is_available():
                reason = provider.unavailable_reason() or "provider unavailable"
                reasons.append(f"{name}: {reason}")
                continue
            if not self._supports_request(provider, completion_request):
                reasons.append(f"{name}: missing required capabilities")
                continue
            try:
                payload = provider.complete(completion_request)
            except ModelProviderError as exc:
                # Record the failure and fall through to the next candidate.
                reasons.append(f"{name}: request failed: {exc}")
                continue
            return CompletionResponse(
                provider_name=name,
                model_name=provider.model_name(),
                payload=payload,
                selection_reason=self._selection_reason(name, preferred),
                attempted=attempted,
            )
        details = "; ".join(reasons) if reasons else "no providers registered"
        raise ModelProviderError(f"No model provider available: {details}")
def build_provider_registry(config: ServerConfig, oauth: QwenOAuthManager) -> ProviderRegistry:
    """Create the registry of every known provider.

    Qwen is the only working implementation. GigaChat and YandexGPT are
    registered as always-unavailable placeholders with no capabilities, so
    they still appear in status listings.
    """
    providers: list[BaseModelProvider] = [QwenModelProvider(config, oauth)]
    placeholders = (
        ("gigachat", config.gigachat_model, "GigaChat provider is not implemented yet"),
        ("yandexgpt", config.yandexgpt_model, "YandexGPT provider is not implemented yet"),
    )
    for provider_name, model, why in placeholders:
        providers.append(
            UnavailableModelProvider(
                config,
                name=provider_name,
                model_name=model,
                reason=why,
                capabilities=ProviderCapabilities(),
            )
        )
    return ProviderRegistry(providers)