255 lines
8.5 KiB
Python
255 lines
8.5 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
from urllib import error, request
|
|
|
|
from config import ServerConfig
|
|
from oauth import OAuthError, QwenOAuthManager
|
|
|
|
|
|
class ModelProviderError(RuntimeError):
    """Error raised when a model provider cannot serve a completion request.

    ``ModelRouter.complete`` treats this exception as "try the next
    candidate", and raises it itself once every candidate has been rejected.
    """
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ProviderCapabilities:
|
|
tool_calling: bool = False
|
|
web_search: bool = False
|
|
oauth_auth: bool = False
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class CompletionRequest:
|
|
messages: list[dict[str, Any]]
|
|
tools: list[dict[str, Any]]
|
|
tool_choice: str = "auto"
|
|
preferred_provider: str | None = None
|
|
require_tools: bool = True
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class CompletionResponse:
|
|
provider_name: str
|
|
model_name: str
|
|
payload: dict[str, Any]
|
|
selection_reason: str
|
|
attempted: list[str] = field(default_factory=list)
|
|
|
|
|
|
class BaseModelProvider:
    """Common interface every model provider implements.

    Subclasses override ``name`` and ``capabilities`` and implement the four
    methods below; each base implementation raises ``NotImplementedError``.
    """

    name = "base"
    # One instance shared by every provider that does not override it, so it
    # should be treated as read-only.
    capabilities = ProviderCapabilities()

    def __init__(self, config: ServerConfig) -> None:
        """Keep a reference to the server configuration."""
        self.config = config

    def is_available(self) -> bool:
        """Report whether the provider can currently serve requests."""
        raise NotImplementedError

    def unavailable_reason(self) -> str | None:
        """Explain why the provider is unavailable, or return ``None``."""
        raise NotImplementedError

    def model_name(self) -> str:
        """Return the name of the model this provider uses."""
        raise NotImplementedError

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Run the completion and return the provider's response payload."""
        raise NotImplementedError
|
|
|
|
|
|
class UnavailableModelProvider(BaseModelProvider):
    """Placeholder provider that is registered but can never serve a request.

    It reports itself as unavailable with a fixed reason, and ``complete``
    always raises ``ModelProviderError``.
    """

    def __init__(
        self,
        config: ServerConfig,
        *,
        name: str,
        model_name: str,
        reason: str,
        capabilities: ProviderCapabilities | None = None,
    ) -> None:
        """Store the identity and the fixed unavailability reason.

        Args:
            config: Server configuration handed to the base class.
            name: Provider name.
            model_name: Value returned by ``model_name()``.
            reason: Text returned by ``unavailable_reason()``.
            capabilities: Optional capability flags; when omitted, the
                class-level default inherited from the base class applies.
        """
        super().__init__(config)
        self.name = name
        self._model_name = model_name
        self._reason = reason
        # Only shadow the inherited class attribute when flags were supplied.
        if capabilities is not None:
            self.capabilities = capabilities

    def is_available(self) -> bool:
        """Always ``False``."""
        return False

    def unavailable_reason(self) -> str | None:
        """Return the reason given at construction time."""
        return self._reason

    def model_name(self) -> str:
        """Return the model name given at construction time."""
        return self._model_name

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Reject the request by raising ``ModelProviderError``."""
        message = f"Provider {self.name} is unavailable: {self._reason}"
        raise ModelProviderError(message)
|
|
|
|
|
|
class QwenModelProvider(BaseModelProvider):
    """Provider that posts chat completions to the Qwen endpoint.

    Credentials and the API base URL come from the injected
    ``QwenOAuthManager``; requests go to ``<base_url>/chat/completions``.
    Every failure that ``complete`` can anticipate is reported as
    ``ModelProviderError`` so that ``ModelRouter`` can fall back to the next
    candidate provider.
    """

    name = "qwen"
    capabilities = ProviderCapabilities(
        tool_calling=True,
        web_search=True,
        oauth_auth=True,
    )
    # Upper bound, in seconds, for one HTTP round trip; override on a
    # subclass or instance to tune it.
    request_timeout = 180

    def __init__(self, config: ServerConfig, oauth: QwenOAuthManager) -> None:
        """Store the server configuration and the OAuth manager."""
        super().__init__(config)
        self.oauth = oauth

    def is_available(self) -> bool:
        """Return ``True`` when stored credentials contain an access token."""
        creds = self.oauth.load_credentials()
        return bool(creds and creds.get("access_token"))

    def unavailable_reason(self) -> str | None:
        """Return ``None`` when available, otherwise a short explanation."""
        if self.is_available():
            return None
        return "Qwen OAuth is not configured"

    def model_name(self) -> str:
        """Return the model name taken from the server configuration."""
        return self.config.model

    def _build_payload(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Assemble the JSON body for the chat-completions call."""
        payload: dict[str, Any] = {
            "model": self.model_name(),
            "messages": completion_request.messages,
        }
        # Send ``tools``/``tool_choice`` only when tools are actually offered:
        # OpenAI-style APIs may reject an empty ``tools`` array or a
        # ``tool_choice`` that has no tools to choose from.
        if completion_request.tools:
            payload["tools"] = completion_request.tools
            payload["tool_choice"] = completion_request.tool_choice
        return payload

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Send the request and return the decoded JSON response.

        Args:
            completion_request: Messages, tools and tool choice to forward.

        Returns:
            The decoded JSON body of the HTTP response.

        Raises:
            ModelProviderError: If credentials cannot be obtained, the HTTP
                request fails (error status, network error or timeout), or
                the response body is not valid JSON.
        """
        # The OAuth calls must sit inside their own ``try``: an OAuthError
        # that escapes here would bypass the router's fallback handling.
        try:
            creds = self.oauth.get_valid_credentials()
            base_url = self.oauth.get_openai_base_url(creds)
        except OAuthError as exc:
            raise ModelProviderError(str(exc)) from exc

        access_token = creds.get("access_token") if creds else None
        if not access_token:
            raise ModelProviderError(
                f"Provider {self.name} has no access token in its credentials"
            )

        data = json.dumps(self._build_payload(completion_request)).encode("utf-8")
        req = request.Request(
            # Tolerate a base URL that already ends with a slash.
            f"{base_url.rstrip('/')}/chat/completions",
            data=data,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {access_token}",
            },
            method="POST",
        )

        try:
            with request.urlopen(req, timeout=self.request_timeout) as response:
                raw = response.read()
        except error.HTTPError as exc:
            body = exc.read().decode("utf-8", errors="replace")
            raise ModelProviderError(
                f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
            ) from exc
        except OSError as exc:
            # Covers urllib's URLError as well as socket-level timeouts and
            # connection errors raised while reading the body.
            raise ModelProviderError(
                f"Provider {self.name} request failed: {exc}"
            ) from exc

        try:
            return json.loads(raw.decode("utf-8"))
        except ValueError as exc:
            # JSONDecodeError and UnicodeDecodeError both derive from ValueError.
            raise ModelProviderError(
                f"Provider {self.name} returned an invalid JSON response: {exc}"
            ) from exc
|
|
|
|
|
|
class ProviderRegistry:
    """Lookup table of providers keyed by their ``name`` attribute."""

    def __init__(self, providers: list[BaseModelProvider]) -> None:
        """Index the given providers by name.

        If two providers share a name, the one later in the list wins.
        """
        by_name = {}
        for entry in providers:
            by_name[entry.name] = entry
        self._providers = by_name

    def get(self, name: str) -> BaseModelProvider | None:
        """Return the provider registered under *name*, or ``None``."""
        try:
            return self._providers[name]
        except KeyError:
            return None

    def list_names(self) -> list[str]:
        """Return all registered provider names in sorted order."""
        return sorted(self._providers)

    @staticmethod
    def _status_of(provider: BaseModelProvider) -> dict[str, Any]:
        """Build the status record for a single provider."""
        flags = provider.capabilities
        return {
            "name": provider.name,
            "model": provider.model_name(),
            "available": provider.is_available(),
            "reason": provider.unavailable_reason(),
            "capabilities": {
                "tool_calling": flags.tool_calling,
                "web_search": flags.web_search,
                "oauth_auth": flags.oauth_auth,
            },
        }

    def statuses(self) -> list[dict[str, Any]]:
        """Return one status record per provider, in registration order."""
        report = []
        for provider in self._providers.values():
            report.append(self._status_of(provider))
        return report
|
|
|
|
|
|
class ModelRouter:
    """Choose a provider for each completion request, with fallback.

    Candidates are tried in order: the request's preferred provider, the
    configured default, the configured fallbacks, then every remaining
    registered provider.
    """

    def __init__(self, config: ServerConfig, registry: ProviderRegistry) -> None:
        """Store the configuration and the provider registry."""
        self.registry = registry
        self.config = config

    def _candidate_names(self, preferred_provider: str | None) -> list[str]:
        """Return the de-duplicated provider names in the order to try them."""
        configured = [
            preferred_provider,
            self.config.default_provider,
            *self.config.fallback_providers,
        ]
        # Empty or missing configured names are dropped; registry names are
        # appended unchanged.
        ordered = [candidate for candidate in configured if candidate]
        ordered.extend(self.registry.list_names())
        # dict.fromkeys keeps the first occurrence of each name, in order.
        return list(dict.fromkeys(ordered))

    def _supports_request(
        self,
        provider: BaseModelProvider,
        completion_request: CompletionRequest,
    ) -> bool:
        """Return whether *provider* has the capabilities the request needs."""
        if not completion_request.require_tools:
            return True
        return bool(provider.capabilities.tool_calling)

    def _selection_reason(self, name: str, preferred: str | None) -> str:
        """Describe why the provider called *name* ended up being used."""
        default = self.config.default_provider
        if preferred:
            if name != preferred:
                return f"preferred provider unavailable, fell back to {name}"
            if name == default:
                return "selected default provider"
            return "selected preferred provider"
        if name != default:
            return f"default provider unavailable, fell back to {name}"
        return "selected default provider"

    def complete(self, completion_request: CompletionRequest) -> CompletionResponse:
        """Run the request on the first candidate that can handle it.

        A candidate is skipped when it is not registered, reports itself
        unavailable, lacks a required capability, or raises
        ``ModelProviderError`` from its own ``complete``.

        Raises:
            ModelProviderError: If no candidate produced a response; the
                message lists the reason recorded for each skipped candidate.
        """
        preferred = completion_request.preferred_provider
        tried: list[str] = []
        failures: list[str] = []

        for name in self._candidate_names(preferred):
            provider = self.registry.get(name)
            if not provider:
                failures.append(f"{name}: unknown provider")
                continue

            tried.append(name)

            if not provider.is_available():
                why = provider.unavailable_reason() or "provider unavailable"
                failures.append(f"{name}: {why}")
                continue

            if not self._supports_request(provider, completion_request):
                failures.append(f"{name}: missing required capabilities")
                continue

            selection_reason = self._selection_reason(name, preferred)
            try:
                payload = provider.complete(completion_request)
            except ModelProviderError as exc:
                failures.append(f"{name}: request failed: {exc}")
            else:
                return CompletionResponse(
                    provider_name=name,
                    model_name=provider.model_name(),
                    payload=payload,
                    selection_reason=selection_reason,
                    attempted=tried,
                )

        if failures:
            details = "; ".join(failures)
        else:
            details = "no providers registered"
        raise ModelProviderError(f"No model provider available: {details}")
|
|
|
|
|
|
def build_provider_registry(config: ServerConfig, oauth: QwenOAuthManager) -> ProviderRegistry:
    """Create the registry holding every provider this server knows about.

    Qwen is the only working provider; GigaChat and YandexGPT are registered
    as unavailable placeholders with default (all-false) capabilities.
    """
    providers: list[BaseModelProvider] = [QwenModelProvider(config, oauth)]

    # (name, configured model, reason shown while the provider is unimplemented)
    placeholders = (
        ("gigachat", config.gigachat_model, "GigaChat provider is not implemented yet"),
        ("yandexgpt", config.yandexgpt_model, "YandexGPT provider is not implemented yet"),
    )
    for provider_name, configured_model, why in placeholders:
        providers.append(
            UnavailableModelProvider(
                config,
                name=provider_name,
                model_name=configured_model,
                reason=why,
                capabilities=ProviderCapabilities(),
            )
        )

    return ProviderRegistry(providers)
|