"""Model provider abstractions and routing for multi-backend chat completions."""
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
import uuid
|
|
from urllib import error, request
|
|
|
|
from config import ServerConfig
|
|
from gigachat import GigaChatAuthManager, GigaChatError
|
|
from oauth import OAuthError, QwenOAuthManager, QWEN_OAUTH_ALLOWED_MODELS
|
|
|
|
|
|
class ModelProviderError(RuntimeError):
    """Raised when a model provider cannot service a completion request."""
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ProviderCapabilities:
|
|
tool_calling: bool = False
|
|
web_search: bool = False
|
|
oauth_auth: bool = False
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class CompletionRequest:
|
|
messages: list[dict[str, Any]]
|
|
tools: list[dict[str, Any]]
|
|
tool_choice: str = "auto"
|
|
preferred_provider: str | None = None
|
|
require_tools: bool = True
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class CompletionResponse:
|
|
provider_name: str
|
|
model_name: str
|
|
payload: dict[str, Any]
|
|
selection_reason: str
|
|
attempted: list[str] = field(default_factory=list)
|
|
|
|
|
|
class BaseModelProvider:
    """Abstract interface every concrete model backend implements.

    Subclasses override ``name``/``capabilities`` and implement the four
    abstract methods below; ``ModelRouter`` relies only on this interface.
    """

    name = "base"
    capabilities = ProviderCapabilities()

    def __init__(self, config: ServerConfig) -> None:
        self.config = config

    def is_available(self) -> bool:
        """Return True when the backend is configured and can serve requests."""
        raise NotImplementedError

    def unavailable_reason(self) -> str | None:
        """Return a human-readable reason when unavailable, else None."""
        raise NotImplementedError

    def model_name(self) -> str:
        """Return the concrete model identifier this provider will use."""
        raise NotImplementedError

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Execute the completion and return an OpenAI-style payload dict."""
        raise NotImplementedError
|
|
|
|
|
|
class UnavailableModelProvider(BaseModelProvider):
    """Placeholder backend that is registered but can never serve requests.

    Used for providers that are planned but not implemented, so they still
    show up in status listings with an explanatory reason.
    """

    def __init__(
        self,
        config: ServerConfig,
        *,
        name: str,
        model_name: str,
        reason: str,
        capabilities: ProviderCapabilities | None = None,
    ) -> None:
        super().__init__(config)
        self.name = name
        self._model_name = model_name
        self._reason = reason
        # Only override the class-level default when explicitly supplied.
        if capabilities is not None:
            self.capabilities = capabilities

    def is_available(self) -> bool:
        # By definition this provider can never serve a request.
        return False

    def unavailable_reason(self) -> str | None:
        return self._reason

    def model_name(self) -> str:
        return self._model_name

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        raise ModelProviderError(f"Provider {self.name} is unavailable: {self._reason}")
|
|
|
|
|
|
class QwenModelProvider(BaseModelProvider):
    """Chat-completion backend for Qwen models served through DashScope.

    Authentication uses OAuth device-flow credentials managed by
    ``QwenOAuthManager``. All failures in :meth:`complete` are converted to
    ``ModelProviderError`` so ``ModelRouter`` can fall back to other providers.
    """

    name = "qwen"
    capabilities = ProviderCapabilities(
        tool_calling=True,
        web_search=True,
        oauth_auth=True,
    )

    def __init__(self, config: ServerConfig, oauth: QwenOAuthManager) -> None:
        super().__init__(config)
        self.oauth = oauth
        # Resolve model ID to actual model name; fall back to the default
        # "coder-model" when the configured model is not OAuth-allowed.
        self._model_id = config.model if config.model in QWEN_OAUTH_ALLOWED_MODELS else "coder-model"
        self._model_name = oauth.get_model_name_for_id(self._model_id)

    def is_available(self) -> bool:
        """Available when stored OAuth credentials contain an access token."""
        creds = self.oauth.load_credentials()
        return bool(creds and creds.get("access_token"))

    def unavailable_reason(self) -> str | None:
        if self.is_available():
            return None
        return "Qwen OAuth is not configured"

    def model_name(self) -> str:
        return self._model_name

    @staticmethod
    def _normalize_content(value: Any) -> list[dict[str, str]]:
        """Coerce message content into a list of ``{"type": "text", ...}`` parts.

        Non-text parts are dropped; non-list content is stringified into a
        single text part (None becomes the empty string).
        """
        if isinstance(value, list):
            normalized: list[dict[str, str]] = []
            for item in value:
                if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
                    normalized.append({"type": "text", "text": item["text"]})
            if normalized:
                return normalized
        text = "" if value is None else str(value)
        return [{"type": "text", "text": text}]

    def _normalize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Normalize the content of each message, preserving all other keys."""
        normalized: list[dict[str, Any]] = []
        for message in messages:
            item = {k: v for k, v in message.items() if k != "content"}
            item["content"] = self._normalize_content(message.get("content"))
            normalized.append(item)
        return normalized

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """POST the request to the DashScope chat-completions endpoint.

        Raises:
            ModelProviderError: on OAuth failure, HTTP error, or network
                error, so the router's fallback chain can engage.
        """
        # BUG FIX: get_valid_credentials() is the call that raises OAuthError
        # (e.g. on a failed token refresh). Previously it ran outside the
        # try/except below, so the OAuthError escaped unconverted and bypassed
        # ModelRouter's fallback, which only catches ModelProviderError.
        try:
            creds = self.oauth.get_valid_credentials()
        except OAuthError as exc:
            raise ModelProviderError(str(exc)) from exc
        base_url = self.oauth.get_openai_base_url(creds)
        payload = {
            "model": self.model_name(),
            "messages": self._normalize_messages(completion_request.messages),
            "max_tokens": 8000,
            "metadata": {
                "sessionId": str(uuid.uuid4()),
                "promptId": uuid.uuid4().hex[:12],
            },
            "vl_high_resolution_images": True,
        }
        if completion_request.tools:
            payload["tools"] = completion_request.tools
        data = json.dumps(payload).encode("utf-8")

        # Add DashScope-specific headers for OAuth tokens
        user_agent = "QwenCode/unknown (linux; x64)"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {creds['access_token']}",
            "User-Agent": user_agent,
            "X-DashScope-CacheControl": "enable",
            "X-DashScope-UserAgent": user_agent,
            "X-DashScope-AuthType": "qwen-oauth",
            "Accept": "application/json",
        }
        req = request.Request(
            f"{base_url}/chat/completions",
            data=data,
            headers=headers,
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=180) as response:
                return json.loads(response.read().decode("utf-8"))
        except error.HTTPError as exc:
            # HTTPError must be caught before URLError (it is a subclass).
            body = exc.read().decode("utf-8", errors="replace")
            raise ModelProviderError(
                f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
            ) from exc
        except error.URLError as exc:
            # BUG FIX: timeouts / DNS failures previously escaped as raw
            # URLError, which the router does not catch; convert them so
            # fallback providers get a chance.
            raise ModelProviderError(
                f"Provider {self.name} request failed: {exc.reason}"
            ) from exc
|
|
|
|
|
|
class GigaChatModelProvider(BaseModelProvider):
    """Chat-completion backend for Sber GigaChat.

    GigaChat uses a function-call API instead of OpenAI tool calls, so
    requests are converted on the way in (:meth:`_convert_messages`,
    :meth:`_convert_tool_schema`) and responses are normalized back to the
    OpenAI ``tool_calls`` shape (:meth:`_normalize_response`).
    """

    name = "gigachat"
    capabilities = ProviderCapabilities(
        tool_calling=True,
        web_search=False,
        oauth_auth=False,
    )

    def __init__(self, config: ServerConfig, auth: GigaChatAuthManager) -> None:
        super().__init__(config)
        self.auth = auth

    def is_available(self) -> bool:
        return self.auth.is_configured()

    def unavailable_reason(self) -> str | None:
        if self.is_available():
            return None
        return "GigaChat auth key is not configured"

    def model_name(self) -> str:
        return self.config.gigachat_model

    @staticmethod
    def _convert_tool_schema(tool: dict[str, Any]) -> dict[str, Any]:
        """Unwrap an OpenAI tool schema to GigaChat's bare function schema."""
        return dict(tool.get("function") or {})

    def _convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert OpenAI-style messages to GigaChat's message format.

        NOTE(review): messages with roles other than system/user/assistant/tool
        are silently dropped; only the FIRST tool call of an assistant message
        is forwarded — presumably GigaChat supports a single function_call per
        message; confirm against the GigaChat API.
        """
        converted: list[dict[str, Any]] = []
        for message in messages:
            role = message.get("role")
            if role in {"system", "user"}:
                converted.append(
                    {
                        "role": role,
                        "content": message.get("content", ""),
                    }
                )
                continue
            if role == "assistant":
                payload: dict[str, Any] = {
                    "role": "assistant",
                    "content": message.get("content", ""),
                }
                tool_calls = message.get("tool_calls") or []
                if tool_calls:
                    first_call = tool_calls[0]
                    raw_arguments = first_call.get("function", {}).get("arguments", "{}")
                    # GigaChat expects arguments as an object, not a JSON string.
                    if isinstance(raw_arguments, str):
                        try:
                            arguments = json.loads(raw_arguments)
                        except json.JSONDecodeError:
                            arguments = {"raw": raw_arguments}
                    else:
                        arguments = raw_arguments
                    payload["function_call"] = {
                        "name": first_call.get("function", {}).get("name"),
                        "arguments": arguments,
                    }
                # Preserve GigaChat's conversational function state, if present.
                if message.get("functions_state_id"):
                    payload["functions_state_id"] = message["functions_state_id"]
                converted.append(payload)
                continue
            if role == "tool":
                converted.append(
                    {
                        "role": "function",
                        "name": message.get("name") or "tool_result",
                        "content": message.get("content", ""),
                    }
                )
        return converted

    def _normalize_response(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Rewrite the first choice's function_call into OpenAI tool_calls.

        Mutates and returns ``payload``; payloads without choices pass through
        unchanged.
        """
        choices = payload.get("choices") or []
        if not choices:
            return payload
        choice = choices[0]
        message = dict(choice.get("message") or {})
        normalized_message: dict[str, Any] = {
            "role": "assistant",
            "content": message.get("content", "") or "",
        }
        function_call = message.get("function_call")
        if function_call:
            arguments = function_call.get("arguments", {})
            # OpenAI tool_calls carry arguments as a JSON string.
            if not isinstance(arguments, str):
                arguments = json.dumps(arguments, ensure_ascii=False)
            normalized_message["tool_calls"] = [
                {
                    # GigaChat does not supply call IDs; synthesize one.
                    "id": uuid.uuid4().hex,
                    "type": "function",
                    "function": {
                        "name": function_call.get("name"),
                        "arguments": arguments,
                    },
                }
            ]
        if message.get("functions_state_id"):
            normalized_message["functions_state_id"] = message.get("functions_state_id")
        choice["message"] = normalized_message
        payload["choices"] = choices
        return payload

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """POST the converted request to GigaChat and normalize the response.

        Raises:
            ModelProviderError: on auth failure, HTTP error, or network error,
                so the router's fallback chain can engage.
        """
        try:
            access_token = self.auth.get_valid_token()
        except GigaChatError as exc:
            raise ModelProviderError(str(exc)) from exc
        api_base = self.config.gigachat_api_base_url.rstrip("/")
        payload: dict[str, Any] = {
            "model": self.model_name(),
            "messages": self._convert_messages(completion_request.messages),
        }
        if completion_request.tools:
            payload["functions"] = [
                self._convert_tool_schema(tool)
                for tool in completion_request.tools
            ]
            payload["function_call"] = "auto"
        # ensure_ascii=False keeps Cyrillic text readable in the request body.
        data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        req = request.Request(
            f"{api_base}/chat/completions",
            data=data,
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {access_token}",
            },
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=180) as response:
                raw = json.loads(response.read().decode("utf-8"))
        except error.HTTPError as exc:
            # HTTPError must be caught before URLError (it is a subclass).
            body = exc.read().decode("utf-8", errors="replace")
            raise ModelProviderError(
                f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
            ) from exc
        except error.URLError as exc:
            # BUG FIX: timeouts / DNS failures previously escaped as raw
            # URLError, which ModelRouter does not catch, so a transient
            # network error killed the whole request instead of falling back.
            raise ModelProviderError(
                f"Provider {self.name} request failed: {exc.reason}"
            ) from exc
        return self._normalize_response(raw)
|
|
|
|
|
|
class ProviderRegistry:
    """Lookup table of model providers keyed by their unique names."""

    def __init__(self, providers: list[BaseModelProvider]) -> None:
        # Later providers with a duplicate name overwrite earlier ones,
        # mirroring plain dict-comprehension semantics.
        self._providers: dict[str, BaseModelProvider] = {}
        for provider in providers:
            self._providers[provider.name] = provider

    def get(self, name: str) -> BaseModelProvider | None:
        """Return the provider registered under ``name``, or None."""
        return self._providers.get(name)

    def list_names(self) -> list[str]:
        """Return all registered provider names in sorted order."""
        return sorted(self._providers)

    def statuses(self) -> list[dict[str, Any]]:
        """Return a status/capability summary dict for every provider."""
        summaries: list[dict[str, Any]] = []
        for provider in self._providers.values():
            caps = provider.capabilities
            summaries.append(
                {
                    "name": provider.name,
                    "model": provider.model_name(),
                    "available": provider.is_available(),
                    "reason": provider.unavailable_reason(),
                    "capabilities": {
                        "tool_calling": caps.tool_calling,
                        "web_search": caps.web_search,
                        "oauth_auth": caps.oauth_auth,
                    },
                }
            )
        return summaries
|
|
|
|
|
|
class ModelRouter:
    """Selects an available provider for a request, with ordered fallback.

    Candidate order: caller preference, then the configured default, then the
    configured fallbacks, then every other registered provider.
    """

    def __init__(self, config: ServerConfig, registry: ProviderRegistry) -> None:
        self.config = config
        self.registry = registry

    def _candidate_names(self, preferred_provider: str | None) -> list[str]:
        """Build the de-duplicated, priority-ordered list of provider names."""
        ordered: list[str] = []
        for candidate in (preferred_provider, self.config.default_provider, *self.config.fallback_providers):
            if candidate and candidate not in ordered:
                ordered.append(candidate)
        # Every other registered provider comes last, in sorted order.
        for registered in self.registry.list_names():
            if registered not in ordered:
                ordered.append(registered)
        return ordered

    def _supports_request(
        self,
        provider: BaseModelProvider,
        completion_request: CompletionRequest,
    ) -> bool:
        """Check capability requirements (currently only tool calling)."""
        if completion_request.require_tools and not provider.capabilities.tool_calling:
            return False
        return True

    def _describe_selection(self, name: str, preferred: str | None) -> str:
        """Explain why ``name`` was chosen, given the caller's preference."""
        if preferred and name != preferred:
            return f"preferred provider unavailable, fell back to {name}"
        if not preferred and name != self.config.default_provider:
            return f"default provider unavailable, fell back to {name}"
        if name == self.config.default_provider:
            return "selected default provider"
        return "selected preferred provider"

    def complete(self, completion_request: CompletionRequest) -> CompletionResponse:
        """Try candidates in priority order until one completes the request.

        Raises:
            ModelProviderError: when every candidate is unknown, unavailable,
                under-capable, or fails at request time.
        """
        tried: list[str] = []
        failures: list[str] = []
        for name in self._candidate_names(completion_request.preferred_provider):
            provider = self.registry.get(name)
            if not provider:
                failures.append(f"{name}: unknown provider")
                continue
            tried.append(name)
            if not provider.is_available():
                failures.append(f"{name}: {provider.unavailable_reason() or 'provider unavailable'}")
                continue
            if not self._supports_request(provider, completion_request):
                failures.append(f"{name}: missing required capabilities")
                continue
            reason = self._describe_selection(name, completion_request.preferred_provider)
            try:
                payload = provider.complete(completion_request)
            except ModelProviderError as exc:
                # Record the failure and keep walking the fallback chain.
                failures.append(f"{name}: request failed: {exc}")
                continue
            return CompletionResponse(
                provider_name=name,
                model_name=provider.model_name(),
                payload=payload,
                selection_reason=reason,
                attempted=tried,
            )
        details = "; ".join(failures) if failures else "no providers registered"
        raise ModelProviderError(f"No model provider available: {details}")
|
|
|
|
|
|
def build_provider_registry(config: ServerConfig, oauth: QwenOAuthManager) -> ProviderRegistry:
    """Assemble the default registry: Qwen (OAuth), GigaChat, and a YandexGPT stub."""
    qwen = QwenModelProvider(config, oauth)
    gigachat = GigaChatModelProvider(config, GigaChatAuthManager(config))
    # YandexGPT is surfaced in status listings even though it cannot serve yet.
    yandex_stub = UnavailableModelProvider(
        config,
        name="yandexgpt",
        model_name=config.yandexgpt_model,
        reason="YandexGPT provider is not implemented yet",
        capabilities=ProviderCapabilities(),
    )
    return ProviderRegistry([qwen, gigachat, yandex_stub])
|