# new-qwen/serv/model_router.py
#
# 431 lines
# 16 KiB
# Python

from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any
import uuid
from urllib import error, request
from config import ServerConfig
from gigachat import GigaChatAuthManager, GigaChatError
from oauth import OAuthError, QwenOAuthManager, QWEN_OAUTH_ALLOWED_MODELS
class ModelProviderError(RuntimeError):
    """Raised when a provider is unavailable, its request fails, or none can serve a request."""
@dataclass(slots=True)
class ProviderCapabilities:
tool_calling: bool = False
web_search: bool = False
oauth_auth: bool = False
@dataclass(slots=True)
class CompletionRequest:
messages: list[dict[str, Any]]
tools: list[dict[str, Any]]
tool_choice: str = "auto"
preferred_provider: str | None = None
require_tools: bool = True
@dataclass(slots=True)
class CompletionResponse:
provider_name: str
model_name: str
payload: dict[str, Any]
selection_reason: str
attempted: list[str] = field(default_factory=list)
class BaseModelProvider:
    """Common interface for model providers.

    Subclasses override ``name`` and ``capabilities`` and implement every
    method below; the base versions all raise ``NotImplementedError``.
    """

    name = "base"
    capabilities = ProviderCapabilities()

    def __init__(self, config: ServerConfig) -> None:
        """Store the server configuration for use by subclasses."""
        self.config = config

    def is_available(self) -> bool:
        """Report whether the provider can currently accept requests."""
        raise NotImplementedError

    def unavailable_reason(self) -> str | None:
        """Explain why the provider is unavailable, or return ``None``."""
        raise NotImplementedError

    def model_name(self) -> str:
        """Return the name of the model this provider uses."""
        raise NotImplementedError

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Run a completion and return the response payload."""
        raise NotImplementedError
class UnavailableModelProvider(BaseModelProvider):
    """Placeholder provider that is never available and rejects every request."""

    def __init__(
        self,
        config: ServerConfig,
        *,
        name: str,
        model_name: str,
        reason: str,
        capabilities: ProviderCapabilities | None = None,
    ) -> None:
        """Record the placeholder's identity and the reason it cannot be used.

        When ``capabilities`` is omitted the class-level default is kept.
        """
        super().__init__(config)
        self.name = name
        self._model_name = model_name
        self._reason = reason
        if capabilities is None:
            return
        self.capabilities = capabilities

    def is_available(self) -> bool:
        """Always ``False``."""
        return False

    def unavailable_reason(self) -> str | None:
        """Return the reason supplied at construction."""
        return self._reason

    def model_name(self) -> str:
        """Return the model name supplied at construction."""
        return self._model_name

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Always raise ``ModelProviderError`` carrying the stored reason."""
        message = f"Provider {self.name} is unavailable: {self._reason}"
        raise ModelProviderError(message)
class QwenModelProvider(BaseModelProvider):
    """Provider that sends chat completions to Qwen using OAuth credentials.

    Requests are POSTed to ``/chat/completions`` under the base URL that the
    OAuth manager derives from the current credentials. Every failure raised
    by ``complete`` is a ``ModelProviderError`` so that callers catching that
    type (such as a router with fallbacks) can move on to another provider.
    """

    name = "qwen"
    capabilities = ProviderCapabilities(
        tool_calling=True,
        web_search=True,
        oauth_auth=True,
    )

    # Limits applied to every completion request.
    _MAX_TOKENS = 8000
    _REQUEST_TIMEOUT_SECONDS = 180

    def __init__(self, config: ServerConfig, oauth: QwenOAuthManager) -> None:
        """Bind the provider to ``config`` and resolve the model to use."""
        super().__init__(config)
        self.oauth = oauth
        # Resolve model ID to actual model name; ids outside the OAuth
        # allow-list are replaced by "coder-model".
        self._model_id = config.model if config.model in QWEN_OAUTH_ALLOWED_MODELS else "coder-model"
        self._model_name = oauth.get_model_name_for_id(self._model_id)

    def is_available(self) -> bool:
        """Return ``True`` when stored credentials contain an access token."""
        creds = self.oauth.load_credentials()
        return bool(creds and creds.get("access_token"))

    def unavailable_reason(self) -> str | None:
        """Return ``None`` when available, otherwise a fixed explanation."""
        if self.is_available():
            return None
        return "Qwen OAuth is not configured"

    def model_name(self) -> str:
        """Return the model name resolved at construction time."""
        return self._model_name

    @staticmethod
    def _normalize_content(value: Any) -> list[dict[str, str]]:
        """Coerce message content into a list of ``{"type": "text"}`` parts.

        A list keeps only its well-formed text parts. If none remain, or the
        value is not a list, the value is stringified into a single text part
        (``None`` becomes the empty string).
        """
        if isinstance(value, list):
            text_parts = [
                {"type": "text", "text": item["text"]}
                for item in value
                if isinstance(item, dict)
                and item.get("type") == "text"
                and isinstance(item.get("text"), str)
            ]
            if text_parts:
                return text_parts
        text = "" if value is None else str(value)
        return [{"type": "text", "text": text}]

    def _normalize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Copy each message, replacing ``content`` with its normalized form."""
        normalized: list[dict[str, Any]] = []
        for message in messages:
            item = {key: value for key, value in message.items() if key != "content"}
            item["content"] = self._normalize_content(message.get("content"))
            normalized.append(item)
        return normalized

    def _build_payload(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Assemble the JSON body for a chat completion request."""
        payload: dict[str, Any] = {
            "model": self.model_name(),
            "messages": self._normalize_messages(completion_request.messages),
            "max_tokens": self._MAX_TOKENS,
            "metadata": {
                # Fresh identifiers are generated for every request.
                "sessionId": str(uuid.uuid4()),
                "promptId": uuid.uuid4().hex[:12],
            },
            "vl_high_resolution_images": True,
        }
        if completion_request.tools:
            payload["tools"] = completion_request.tools
        return payload

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Send the completion request and return the decoded JSON response.

        Raises:
            ModelProviderError: if credentials cannot be obtained, the HTTP
                request fails (error status, network error or timeout), or
                the response body is not valid JSON.
        """
        # Credential lookup is what raises OAuthError, so it must sit inside
        # the handler; otherwise the error escapes untranslated.
        try:
            creds = self.oauth.get_valid_credentials()
            base_url = self.oauth.get_openai_base_url(creds)
        except OAuthError as exc:
            raise ModelProviderError(str(exc)) from exc

        data = json.dumps(self._build_payload(completion_request)).encode("utf-8")
        # Add DashScope-specific headers for OAuth tokens
        user_agent = "QwenCode/unknown (linux; x64)"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {creds['access_token']}",
            "User-Agent": user_agent,
            "X-DashScope-CacheControl": "enable",
            "X-DashScope-UserAgent": user_agent,
            "X-DashScope-AuthType": "qwen-oauth",
            "Accept": "application/json",
        }
        req = request.Request(
            f"{base_url}/chat/completions",
            data=data,
            headers=headers,
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=self._REQUEST_TIMEOUT_SECONDS) as response:
                raw_body = response.read()
        except error.HTTPError as exc:
            # HTTPError must be handled before its base classes below.
            body = exc.read().decode("utf-8", errors="replace")
            raise ModelProviderError(
                f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
            ) from exc
        except OSError as exc:
            # Covers URLError (DNS, refused connection) and socket timeouts.
            raise ModelProviderError(
                f"Provider {self.name} request failed: {exc}"
            ) from exc

        try:
            return json.loads(raw_body.decode("utf-8"))
        except ValueError as exc:
            # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            raise ModelProviderError(
                f"Provider {self.name} returned an invalid JSON response: {exc}"
            ) from exc
class GigaChatModelProvider(BaseModelProvider):
    """Provider that sends chat completions to GigaChat.

    Incoming messages and tools use the ``tool_calls`` / ``tools`` layout;
    they are converted to GigaChat's ``function_call`` / ``functions`` layout
    on the way out, and responses are converted back. Every failure raised by
    ``complete`` is a ``ModelProviderError`` so that callers catching that
    type (such as a router with fallbacks) can move on to another provider.
    """

    name = "gigachat"
    capabilities = ProviderCapabilities(
        tool_calling=True,
        web_search=False,
        oauth_auth=False,
    )

    # Timeout applied to every completion request.
    _REQUEST_TIMEOUT_SECONDS = 180

    def __init__(self, config: ServerConfig, auth: GigaChatAuthManager) -> None:
        """Bind the provider to ``config`` and the given auth manager."""
        super().__init__(config)
        self.auth = auth

    def is_available(self) -> bool:
        """Return whether the auth manager reports itself as configured."""
        return self.auth.is_configured()

    def unavailable_reason(self) -> str | None:
        """Return ``None`` when available, otherwise a fixed explanation."""
        if self.is_available():
            return None
        return "GigaChat auth key is not configured"

    def model_name(self) -> str:
        """Return the GigaChat model name from the server configuration."""
        return self.config.gigachat_model

    @staticmethod
    def _convert_tool_schema(tool: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of the tool's ``function`` object (empty if absent)."""
        return dict(tool.get("function") or {})

    @staticmethod
    def _parse_arguments(raw_arguments: Any) -> Any:
        """Decode JSON-string arguments; non-strings pass through unchanged.

        A string that is not valid JSON is wrapped as ``{"raw": <string>}``.
        """
        if not isinstance(raw_arguments, str):
            return raw_arguments
        try:
            return json.loads(raw_arguments)
        except json.JSONDecodeError:
            return {"raw": raw_arguments}

    def _convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert messages to the layout sent to GigaChat.

        * ``system`` / ``user`` messages keep only role and content.
        * ``assistant`` messages carry at most one ``function_call``, built
          from the first entry of ``tool_calls``; later entries are not sent.
        * ``tool`` messages become ``function`` messages.
        * Messages with any other role are dropped.

        ``None`` content is sent as an empty string.
        """
        converted: list[dict[str, Any]] = []
        # Function names of earlier assistant tool calls, keyed by call id, so
        # that a tool result lacking "name" can still be attributed through
        # its "tool_call_id".
        call_names: dict[str, Any] = {}
        for message in messages:
            role = message.get("role")
            content = message.get("content", "")
            if content is None:
                # Avoid serializing a JSON null for messages without text.
                content = ""

            if role in {"system", "user"}:
                converted.append({"role": role, "content": content})
                continue

            if role == "assistant":
                payload: dict[str, Any] = {"role": "assistant", "content": content}
                tool_calls = message.get("tool_calls") or []
                for call in tool_calls:
                    call_id = call.get("id")
                    call_name = (call.get("function") or {}).get("name")
                    if isinstance(call_id, str) and call_name:
                        call_names[call_id] = call_name
                if tool_calls:
                    first_function = tool_calls[0].get("function") or {}
                    payload["function_call"] = {
                        "name": first_function.get("name"),
                        "arguments": self._parse_arguments(
                            first_function.get("arguments", "{}")
                        ),
                    }
                if message.get("functions_state_id"):
                    payload["functions_state_id"] = message["functions_state_id"]
                converted.append(payload)
                continue

            if role == "tool":
                call_id = message.get("tool_call_id")
                resolved_name = call_names.get(call_id) if isinstance(call_id, str) else None
                converted.append(
                    {
                        "role": "function",
                        # An explicit name wins, then the name of the matching
                        # earlier call, then the generic placeholder.
                        "name": message.get("name") or resolved_name or "tool_result",
                        "content": content,
                    }
                )
        return converted

    def _normalize_response(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Rewrite the first choice's message into the ``tool_calls`` layout.

        The payload is modified in place and returned. A payload without
        choices is returned untouched. Only the first choice is rewritten.
        """
        choices = payload.get("choices") or []
        if not choices:
            return payload
        choice = choices[0]
        message = dict(choice.get("message") or {})
        normalized_message: dict[str, Any] = {
            "role": "assistant",
            "content": message.get("content", "") or "",
        }
        function_call = message.get("function_call")
        if function_call:
            arguments = function_call.get("arguments", {})
            if not isinstance(arguments, str):
                # Arguments are carried as a JSON string in the tool_calls layout.
                arguments = json.dumps(arguments, ensure_ascii=False)
            normalized_message["tool_calls"] = [
                {
                    # The response carries no call id here, so one is generated.
                    "id": uuid.uuid4().hex,
                    "type": "function",
                    "function": {
                        "name": function_call.get("name"),
                        "arguments": arguments,
                    },
                }
            ]
        if message.get("functions_state_id"):
            normalized_message["functions_state_id"] = message.get("functions_state_id")
        choice["message"] = normalized_message
        payload["choices"] = choices
        return payload

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Send the completion request and return the normalized response.

        Raises:
            ModelProviderError: if a token cannot be obtained, the HTTP
                request fails (error status, network error or timeout), or
                the response body is not valid JSON.
        """
        try:
            access_token = self.auth.get_valid_token()
        except GigaChatError as exc:
            raise ModelProviderError(str(exc)) from exc

        api_base = self.config.gigachat_api_base_url.rstrip("/")
        payload: dict[str, Any] = {
            "model": self.model_name(),
            "messages": self._convert_messages(completion_request.messages),
        }
        if completion_request.tools:
            payload["functions"] = [
                self._convert_tool_schema(tool)
                for tool in completion_request.tools
            ]
            payload["function_call"] = "auto"
        data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        req = request.Request(
            f"{api_base}/chat/completions",
            data=data,
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {access_token}",
            },
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=self._REQUEST_TIMEOUT_SECONDS) as response:
                raw_body = response.read()
        except error.HTTPError as exc:
            # HTTPError must be handled before its base classes below.
            body = exc.read().decode("utf-8", errors="replace")
            raise ModelProviderError(
                f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
            ) from exc
        except OSError as exc:
            # Covers URLError (DNS, refused connection) and socket timeouts.
            raise ModelProviderError(
                f"Provider {self.name} request failed: {exc}"
            ) from exc

        try:
            raw = json.loads(raw_body.decode("utf-8"))
        except ValueError as exc:
            # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            raise ModelProviderError(
                f"Provider {self.name} returned an invalid JSON response: {exc}"
            ) from exc
        return self._normalize_response(raw)
class ProviderRegistry:
    """Lookup table of model providers keyed by their ``name``."""

    def __init__(self, providers: list[BaseModelProvider]) -> None:
        """Index ``providers`` by name; a later duplicate replaces an earlier one."""
        self._providers: dict[str, BaseModelProvider] = {}
        for entry in providers:
            self._providers[entry.name] = entry

    def get(self, name: str) -> BaseModelProvider | None:
        """Return the provider registered under ``name``, or ``None``."""
        return self._providers.get(name)

    def list_names(self) -> list[str]:
        """Return all registered provider names in sorted order."""
        return sorted(self._providers)

    @staticmethod
    def _describe(provider: BaseModelProvider) -> dict[str, Any]:
        """Build the status record for a single provider."""
        flags = provider.capabilities
        return {
            "name": provider.name,
            "model": provider.model_name(),
            "available": provider.is_available(),
            "reason": provider.unavailable_reason(),
            "capabilities": {
                "tool_calling": flags.tool_calling,
                "web_search": flags.web_search,
                "oauth_auth": flags.oauth_auth,
            },
        }

    def statuses(self) -> list[dict[str, Any]]:
        """Return one status record per provider, in registration order."""
        report: list[dict[str, Any]] = []
        for provider in self._providers.values():
            report.append(self._describe(provider))
        return report
class ModelRouter:
    """Chooses a provider for each completion request, with fallbacks."""

    def __init__(self, config: ServerConfig, registry: ProviderRegistry) -> None:
        """Keep the configuration and the registry used for lookups."""
        self.config = config
        self.registry = registry

    def _candidate_names(self, preferred_provider: str | None) -> list[str]:
        """Return provider names in the order they should be tried.

        The preferred provider comes first, then the configured default, then
        the configured fallbacks, then every remaining registered name.
        Empty entries among the configured names and duplicates are skipped.
        """
        configured = (
            preferred_provider,
            self.config.default_provider,
            *self.config.fallback_providers,
        )
        # A dict keeps first-seen order while discarding repeats.
        ordered = dict.fromkeys(entry for entry in configured if entry)
        for registered in self.registry.list_names():
            ordered.setdefault(registered, None)
        return list(ordered)

    def _supports_request(
        self,
        provider: BaseModelProvider,
        completion_request: CompletionRequest,
    ) -> bool:
        """Return ``False`` only when tools are required but unsupported."""
        needs_tools = completion_request.require_tools
        return not (needs_tools and not provider.capabilities.tool_calling)

    def _selection_reason(self, chosen: str, preferred: str | None) -> str:
        """Describe why ``chosen`` was picked relative to preferred/default."""
        default_name = self.config.default_provider
        if preferred and chosen != preferred:
            return f"preferred provider unavailable, fell back to {chosen}"
        if not preferred and chosen != default_name:
            return f"default provider unavailable, fell back to {chosen}"
        if chosen == default_name:
            return "selected default provider"
        return "selected preferred provider"

    def complete(self, completion_request: CompletionRequest) -> CompletionResponse:
        """Run the request on the first candidate provider that succeeds.

        Candidates that are unknown, unavailable, lack required capabilities,
        or raise ``ModelProviderError`` are skipped, each with a recorded
        reason.

        Raises:
            ModelProviderError: if no candidate produced a response; the
                message lists the reason recorded for each candidate.
        """
        preferred = completion_request.preferred_provider
        tried: list[str] = []
        failures: list[str] = []

        for candidate in self._candidate_names(preferred):
            provider = self.registry.get(candidate)
            if not provider:
                failures.append(f"{candidate}: unknown provider")
                continue

            # Every registered candidate counts as attempted, even if it is
            # skipped by one of the checks below.
            tried.append(candidate)

            if not provider.is_available():
                why = provider.unavailable_reason() or "provider unavailable"
                failures.append(f"{candidate}: {why}")
                continue

            if not self._supports_request(provider, completion_request):
                failures.append(f"{candidate}: missing required capabilities")
                continue

            selection_reason = self._selection_reason(candidate, preferred)
            try:
                payload = provider.complete(completion_request)
            except ModelProviderError as exc:
                failures.append(f"{candidate}: request failed: {exc}")
                continue

            return CompletionResponse(
                provider_name=candidate,
                model_name=provider.model_name(),
                payload=payload,
                selection_reason=selection_reason,
                attempted=tried,
            )

        details = "; ".join(failures) if failures else "no providers registered"
        raise ModelProviderError(f"No model provider available: {details}")
def build_provider_registry(config: ServerConfig, oauth: QwenOAuthManager) -> ProviderRegistry:
    """Create the registry holding the Qwen, GigaChat and YandexGPT providers.

    YandexGPT is registered as an always-unavailable placeholder with
    default (all-off) capabilities.
    """
    qwen = QwenModelProvider(config, oauth)
    gigachat = GigaChatModelProvider(config, GigaChatAuthManager(config))
    yandex_placeholder = UnavailableModelProvider(
        config,
        name="yandexgpt",
        model_name=config.yandexgpt_model,
        reason="YandexGPT provider is not implemented yet",
        capabilities=ProviderCapabilities(),
    )
    return ProviderRegistry([qwen, gigachat, yandex_placeholder])