Refactor server for multi-provider routing
This commit is contained in:
parent
84a6e8b5d8
commit
7fac6fa41e
22
README.md
22
README.md
|
|
@ -1,8 +1,8 @@
|
|||
# new-qwen
|
||||
|
||||
Клиент-серверная замена локального агента `qwen-code`.
|
||||
Клиент-серверная замена локального агента `qwen-code` с серверной оркестрацией моделей.
|
||||
|
||||
- `serv` отвечает за OAuth, сессии, работу с Qwen LLM и вызов инструментов.
|
||||
- `serv` отвечает за OAuth, сессии, маршрутизацию к model provider-ам и вызов инструментов.
|
||||
- `bot` отвечает за Telegram и пересылку сообщений на сервер.
|
||||
|
||||
Проект написан на Python stdlib, чтобы не зависеть от Node/npm в текущем окружении.
|
||||
|
|
@ -22,6 +22,12 @@ serv/app.py
|
|||
Qwen OAuth + OpenAI-compatible endpoint
|
||||
```
|
||||
|
||||
На стороне `serv` теперь есть отдельные слои:
|
||||
|
||||
- `model_router.py` - выбор провайдера и fallback policy
|
||||
- `llm.py` - агентный цикл поверх абстракции провайдера
|
||||
- `oauth.py` - только аутентификация для пути через Qwen
|
||||
|
||||
## Что уже реализовано
|
||||
|
||||
- Qwen OAuth Device Flow, совместимый с `qwen-code`
|
||||
|
|
@ -39,6 +45,8 @@ Qwen OAuth + OpenAI-compatible endpoint
|
|||
- policy mode для инструментов: `full-access`, `workspace-write`, `read-only`
|
||||
- live approval flow для инструментов через Telegram
|
||||
- provider-based web search с приоритетом DashScope через Qwen OAuth
|
||||
- model router с `qwen` как первым провайдером и fallback-ready архитектурой для `gigachat` и `yandexgpt`
|
||||
- router выполняет fallback не только по конфигу, но и при runtime-ошибке провайдера
|
||||
|
||||
## Ограничения текущей реализации
|
||||
|
||||
|
|
@ -67,6 +75,10 @@ cp serv/.env.example serv/.env
|
|||
Ключевые параметры сервера:
|
||||
|
||||
- `NEW_QWEN_STATE_DIR` - где хранить jobs и pending OAuth flows
|
||||
- `NEW_QWEN_DEFAULT_PROVIDER` - основной model provider, сейчас по умолчанию `qwen`
|
||||
- `NEW_QWEN_FALLBACK_PROVIDERS` - fallback-цепочка провайдеров через запятую
|
||||
- `NEW_QWEN_GIGACHAT_MODEL` - имя модели для будущего GigaChat adapter
|
||||
- `NEW_QWEN_YANDEXGPT_MODEL` - имя модели для будущего YandexGPT adapter
|
||||
- `NEW_QWEN_TOOL_POLICY` - режим инструментов:
|
||||
`full-access` - все инструменты
|
||||
`workspace-write` - без `exec_command`
|
||||
|
|
@ -130,6 +142,12 @@ curl -X POST http://127.0.0.1:8080/api/v1/auth/device/start
|
|||
- `POST /api/v1/chat/cancel`
|
||||
- `POST /api/v1/approval/respond`
|
||||
|
||||
`GET /api/v1/auth/status` теперь также показывает:
|
||||
|
||||
- `default_provider`
|
||||
- `fallback_providers`
|
||||
- список `providers` с availability и capabilities
|
||||
|
||||
## Telegram Approval Flow
|
||||
|
||||
Если политика инструментов настроена как `ask-shell`, `ask-write` или `ask-all`, бот пришлёт запрос на подтверждение с `approval_id`.
|
||||
|
|
|
|||
|
|
@ -68,6 +68,12 @@ def summarize_event(event: dict[str, Any]) -> str | None:
|
|||
if event_type == "job_status":
|
||||
return event.get("message")
|
||||
if event_type == "model_request":
|
||||
provider = event.get("provider")
|
||||
model = event.get("model")
|
||||
if provider and model:
|
||||
return f"Думаю над ответом через {provider}/{model}"
|
||||
if provider:
|
||||
return f"Думаю над ответом через {provider}"
|
||||
return "Думаю над ответом"
|
||||
if event_type == "tool_call":
|
||||
return f"Вызываю инструмент: {event.get('name')}"
|
||||
|
|
@ -555,6 +561,8 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
|||
chat_id,
|
||||
"Сервер доступен.\n"
|
||||
f"OAuth: {'configured' if status.get('authenticated') else 'not configured'}\n"
|
||||
f"default_provider: {status.get('default_provider')}\n"
|
||||
f"fallback_providers: {', '.join(status.get('fallback_providers') or []) or '-'}\n"
|
||||
f"resource_url: {status.get('resource_url')}\n"
|
||||
f"expires_at: {status.get('expires_at')}\n"
|
||||
f"tool_policy: {status.get('tool_policy')}\n"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
NEW_QWEN_HOST=127.0.0.1
|
||||
NEW_QWEN_PORT=8080
|
||||
NEW_QWEN_MODEL=qwen3.6-plus
|
||||
NEW_QWEN_DEFAULT_PROVIDER=qwen
|
||||
NEW_QWEN_FALLBACK_PROVIDERS=
|
||||
NEW_QWEN_GIGACHAT_MODEL=GigaChat
|
||||
NEW_QWEN_YANDEXGPT_MODEL=yandexgpt
|
||||
NEW_QWEN_WORKSPACE_ROOT=/home/mirivlad/git
|
||||
NEW_QWEN_SESSION_DIR=/home/mirivlad/git/new-qwen/.new-qwen/sessions
|
||||
NEW_QWEN_STATE_DIR=/home/mirivlad/git/new-qwen/.new-qwen/state
|
||||
|
|
|
|||
40
serv/app.py
40
serv/app.py
|
|
@ -13,6 +13,7 @@ from config import ServerConfig
|
|||
from approvals import ApprovalStore
|
||||
from jobs import JobStore
|
||||
from llm import QwenAgent
|
||||
from model_router import ModelProviderError, ModelRouter, build_provider_registry
|
||||
from oauth import DeviceAuthState, OAuthError, QwenOAuthManager
|
||||
from sessions import SessionStore
|
||||
from tools import ToolRegistry
|
||||
|
|
@ -23,8 +24,10 @@ class AppState:
|
|||
self.config = config
|
||||
self.oauth = QwenOAuthManager()
|
||||
self.sessions = SessionStore(config.session_dir)
|
||||
self.providers = build_provider_registry(config, self.oauth)
|
||||
self.router = ModelRouter(config, self.providers)
|
||||
self.tools = ToolRegistry(config, self.oauth)
|
||||
self.agent = QwenAgent(config, self.oauth, self.tools)
|
||||
self.agent = QwenAgent(config, self.router, self.tools)
|
||||
self.jobs = JobStore(
|
||||
config.state_dir / "jobs",
|
||||
retention_seconds=config.jobs_retention_seconds,
|
||||
|
|
@ -103,12 +106,18 @@ class AppState:
|
|||
if not creds:
|
||||
return {
|
||||
"authenticated": False,
|
||||
"default_provider": self.config.default_provider,
|
||||
"fallback_providers": self.config.fallback_providers,
|
||||
"providers": self.providers.statuses(),
|
||||
"tool_policy": self.config.tool_policy,
|
||||
"pending_flows": len(self.pending_device_flows),
|
||||
"pending_approvals": len(self.approvals.list_pending()),
|
||||
}
|
||||
return {
|
||||
"authenticated": True,
|
||||
"default_provider": self.config.default_provider,
|
||||
"fallback_providers": self.config.fallback_providers,
|
||||
"providers": self.providers.statuses(),
|
||||
"resource_url": creds.get("resource_url"),
|
||||
"expires_at": creds.get("expiry_date"),
|
||||
"tool_policy": self.config.tool_policy,
|
||||
|
|
@ -154,7 +163,14 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
return
|
||||
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
||||
|
||||
def _run_chat_job(self, job_id: str, session_id: str, user_id: str, message: str) -> None:
|
||||
def _run_chat_job(
|
||||
self,
|
||||
job_id: str,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
preferred_provider: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
if self.app.jobs.is_cancel_requested(job_id):
|
||||
reason = "Job canceled before execution started"
|
||||
|
|
@ -181,6 +197,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
arguments,
|
||||
),
|
||||
is_cancelled=lambda: self.app.jobs.is_cancel_requested(job_id),
|
||||
preferred_provider=preferred_provider,
|
||||
)
|
||||
if self.app.jobs.is_cancel_requested(job_id):
|
||||
reason = "Job canceled by operator"
|
||||
|
|
@ -209,6 +226,8 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
job_id,
|
||||
answer=result["answer"],
|
||||
usage=result.get("usage"),
|
||||
provider=result.get("provider"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
except Exception as exc:
|
||||
if self.app.jobs.is_cancel_requested(job_id):
|
||||
|
|
@ -299,9 +318,14 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
session_id = body.get("session_id") or uuid.uuid4().hex
|
||||
user_id = str(body.get("user_id") or "anonymous")
|
||||
message = body["message"]
|
||||
preferred_provider = body.get("provider")
|
||||
session = self.app.sessions.load(session_id)
|
||||
history = session.get("messages", [])
|
||||
result = self.app.agent.run(history, message)
|
||||
result = self.app.agent.run(
|
||||
history,
|
||||
message,
|
||||
preferred_provider=preferred_provider,
|
||||
)
|
||||
persisted_messages = result["messages"][1:]
|
||||
self.app.sessions.save(
|
||||
session_id,
|
||||
|
|
@ -320,6 +344,8 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
"answer": result["answer"],
|
||||
"events": result["events"],
|
||||
"usage": result.get("usage"),
|
||||
"provider": result.get("provider"),
|
||||
"model": result.get("model"),
|
||||
},
|
||||
)
|
||||
return
|
||||
|
|
@ -329,10 +355,11 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
session_id = body.get("session_id") or uuid.uuid4().hex
|
||||
user_id = str(body.get("user_id") or "anonymous")
|
||||
message = body["message"]
|
||||
preferred_provider = body.get("provider")
|
||||
job = self.app.jobs.create(session_id, user_id, message)
|
||||
thread = threading.Thread(
|
||||
target=self._run_chat_job,
|
||||
args=(job["job_id"], session_id, user_id, message),
|
||||
args=(job["job_id"], session_id, user_id, message, preferred_provider),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
|
|
@ -342,6 +369,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
"job_id": job["job_id"],
|
||||
"session_id": session_id,
|
||||
"status": "queued",
|
||||
"provider": preferred_provider or self.app.config.default_provider,
|
||||
},
|
||||
)
|
||||
return
|
||||
|
|
@ -367,6 +395,8 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
"answer": job.get("answer"),
|
||||
"usage": job.get("usage"),
|
||||
"error": job.get("error"),
|
||||
"provider": job.get("provider"),
|
||||
"model": job.get("model"),
|
||||
},
|
||||
)
|
||||
return
|
||||
|
|
@ -427,7 +457,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
return
|
||||
|
||||
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
||||
except OAuthError as exc:
|
||||
except (OAuthError, ModelProviderError) as exc:
|
||||
self._send(HTTPStatus.BAD_GATEWAY, {"error": str(exc)})
|
||||
except KeyError as exc:
|
||||
self._send(HTTPStatus.BAD_REQUEST, {"error": f"Missing field: {exc}"})
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ class ServerConfig:
|
|||
host: str
|
||||
port: int
|
||||
model: str
|
||||
default_provider: str
|
||||
fallback_providers: list[str]
|
||||
gigachat_model: str
|
||||
yandexgpt_model: str
|
||||
workspace_root: Path
|
||||
session_dir: Path
|
||||
state_dir: Path
|
||||
|
|
@ -59,6 +63,14 @@ class ServerConfig:
|
|||
host=os.environ.get("NEW_QWEN_HOST", "127.0.0.1"),
|
||||
port=int(os.environ.get("NEW_QWEN_PORT", "8080")),
|
||||
model=os.environ.get("NEW_QWEN_MODEL", "qwen3.6-plus"),
|
||||
default_provider=os.environ.get("NEW_QWEN_DEFAULT_PROVIDER", "qwen").strip() or "qwen",
|
||||
fallback_providers=[
|
||||
item.strip()
|
||||
for item in os.environ.get("NEW_QWEN_FALLBACK_PROVIDERS", "").split(",")
|
||||
if item.strip()
|
||||
],
|
||||
gigachat_model=os.environ.get("NEW_QWEN_GIGACHAT_MODEL", "GigaChat").strip(),
|
||||
yandexgpt_model=os.environ.get("NEW_QWEN_YANDEXGPT_MODEL", "yandexgpt").strip(),
|
||||
workspace_root=workspace_root.resolve(),
|
||||
session_dir=session_dir.resolve(),
|
||||
state_dir=state_dir.resolve(),
|
||||
|
|
|
|||
|
|
@ -64,6 +64,8 @@ class JobStore:
|
|||
"events": [],
|
||||
"answer": None,
|
||||
"usage": None,
|
||||
"provider": None,
|
||||
"model": None,
|
||||
"error": None,
|
||||
"cancel_requested": False,
|
||||
"cancel_actor": None,
|
||||
|
|
@ -141,12 +143,16 @@ class JobStore:
|
|||
*,
|
||||
answer: str,
|
||||
usage: dict[str, Any] | None,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
with self._lock:
|
||||
job = self._jobs[job_id]
|
||||
job["status"] = "completed"
|
||||
job["answer"] = answer
|
||||
job["usage"] = usage
|
||||
job["provider"] = provider
|
||||
job["model"] = model
|
||||
job["updated_at"] = time.time()
|
||||
self._save_job(job)
|
||||
|
||||
|
|
|
|||
64
serv/llm.py
64
serv/llm.py
|
|
@ -2,10 +2,9 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
from typing import Any, Callable
|
||||
from urllib import error, request
|
||||
|
||||
from config import ServerConfig
|
||||
from oauth import OAuthError, QwenOAuthManager
|
||||
from model_router import CompletionRequest, ModelRouter
|
||||
from tools import ToolError, ToolRegistry
|
||||
|
||||
|
||||
|
|
@ -18,37 +17,11 @@ Do not claim to have executed tools unless a tool result confirms it."""
|
|||
|
||||
|
||||
class QwenAgent:
|
||||
def __init__(self, config: ServerConfig, oauth: QwenOAuthManager, tools: ToolRegistry) -> None:
|
||||
def __init__(self, config: ServerConfig, router: ModelRouter, tools: ToolRegistry) -> None:
|
||||
self.config = config
|
||||
self.oauth = oauth
|
||||
self.router = router
|
||||
self.tools = tools
|
||||
|
||||
def _request_completion(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
creds = self.oauth.get_valid_credentials()
|
||||
base_url = self.oauth.get_openai_base_url(creds)
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
"tools": self.tools.schemas(),
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
req = request.Request(
|
||||
f"{base_url}/chat/completions",
|
||||
data=data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {creds['access_token']}",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with request.urlopen(req, timeout=180) as response:
|
||||
return json.loads(response.read().decode("utf-8"))
|
||||
except error.HTTPError as exc:
|
||||
body = exc.read().decode("utf-8", errors="replace")
|
||||
raise OAuthError(f"LLM request failed with HTTP {exc.code}: {body}") from exc
|
||||
|
||||
def run(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
|
|
@ -56,6 +29,7 @@ class QwenAgent:
|
|||
on_event: Callable[[dict[str, Any]], None] | None = None,
|
||||
approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
|
||||
is_cancelled: Callable[[], bool] | None = None,
|
||||
preferred_provider: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
emit = on_event or (lambda _event: None)
|
||||
cancel_check = is_cancelled or (lambda: False)
|
||||
|
|
@ -69,11 +43,33 @@ class QwenAgent:
|
|||
messages.extend(history)
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
events: list[dict[str, Any]] = []
|
||||
selected_provider: str | None = None
|
||||
selected_model: str | None = None
|
||||
|
||||
for _ in range(self.config.max_tool_rounds):
|
||||
ensure_not_cancelled()
|
||||
emit({"type": "model_request", "message": "Запрашиваю ответ модели"})
|
||||
response = self._request_completion(messages)
|
||||
completion = self.router.complete(
|
||||
CompletionRequest(
|
||||
messages=messages,
|
||||
tools=self.tools.schemas(),
|
||||
tool_choice="auto",
|
||||
preferred_provider=preferred_provider,
|
||||
require_tools=True,
|
||||
)
|
||||
)
|
||||
emit(
|
||||
{
|
||||
"type": "model_request",
|
||||
"message": "Запрашиваю ответ модели",
|
||||
"provider": completion.provider_name,
|
||||
"model": completion.model_name,
|
||||
"selection_reason": completion.selection_reason,
|
||||
"attempted": completion.attempted,
|
||||
}
|
||||
)
|
||||
selected_provider = completion.provider_name
|
||||
selected_model = completion.model_name
|
||||
response = completion.payload
|
||||
ensure_not_cancelled()
|
||||
choice = response["choices"][0]["message"]
|
||||
tool_calls = choice.get("tool_calls") or []
|
||||
|
|
@ -88,6 +84,8 @@ class QwenAgent:
|
|||
"answer": content or "",
|
||||
"events": events,
|
||||
"usage": response.get("usage"),
|
||||
"provider": selected_provider,
|
||||
"model": selected_model,
|
||||
"messages": messages + [{"role": "assistant", "content": content or ""}],
|
||||
}
|
||||
|
||||
|
|
@ -164,5 +162,7 @@ class QwenAgent:
|
|||
"answer": final_message,
|
||||
"events": events,
|
||||
"usage": None,
|
||||
"provider": selected_provider,
|
||||
"model": selected_model,
|
||||
"messages": messages + [{"role": "assistant", "content": final_message}],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,254 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from urllib import error, request
|
||||
|
||||
from config import ServerConfig
|
||||
from oauth import OAuthError, QwenOAuthManager
|
||||
|
||||
|
||||
class ModelProviderError(RuntimeError):
    """Raised when a model provider is unavailable or its request fails."""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderCapabilities:
|
||||
tool_calling: bool = False
|
||||
web_search: bool = False
|
||||
oauth_auth: bool = False
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CompletionRequest:
|
||||
messages: list[dict[str, Any]]
|
||||
tools: list[dict[str, Any]]
|
||||
tool_choice: str = "auto"
|
||||
preferred_provider: str | None = None
|
||||
require_tools: bool = True
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CompletionResponse:
|
||||
provider_name: str
|
||||
model_name: str
|
||||
payload: dict[str, Any]
|
||||
selection_reason: str
|
||||
attempted: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class BaseModelProvider:
    """Interface every model provider implements.

    Subclasses override all four methods below; the base versions only raise
    ``NotImplementedError``.
    """

    # Registry key of the provider; subclasses or instances override it.
    name = "base"
    # Class-level default: a capability set with every flag switched off.
    capabilities = ProviderCapabilities()

    def __init__(self, config: ServerConfig) -> None:
        """Remember the server configuration for later use by subclasses."""
        self.config = config

    def is_available(self) -> bool:
        """Return whether the provider can currently serve requests."""
        raise NotImplementedError

    def unavailable_reason(self) -> str | None:
        """Return a description of why the provider is unavailable, or None."""
        raise NotImplementedError

    def model_name(self) -> str:
        """Return the name of the model this provider uses."""
        raise NotImplementedError

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Execute the completion request and return the provider's payload."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class UnavailableModelProvider(BaseModelProvider):
    """Placeholder provider that always reports itself as unavailable.

    It lets a provider name appear in the registry (with a model name and an
    explanation) before a real adapter exists for it.
    """

    def __init__(
        self,
        config: ServerConfig,
        *,
        name: str,
        model_name: str,
        reason: str,
        capabilities: ProviderCapabilities | None = None,
    ) -> None:
        """Store the placeholder's identity and the reason it cannot be used.

        Args:
            config: Server configuration, forwarded to the base class.
            name: Provider name used as the registry key.
            model_name: Model name returned by ``model_name()``.
            reason: Text returned by ``unavailable_reason()``.
            capabilities: Optional capability set; when omitted, the
                class-level default from the base class stays in effect.
        """
        super().__init__(config)
        self._reason = reason
        self._model_name = model_name
        self.name = name
        if capabilities is None:
            return
        self.capabilities = capabilities

    def is_available(self) -> bool:
        """Always return False: a placeholder can never serve requests."""
        return False

    def unavailable_reason(self) -> str | None:
        """Return the reason supplied at construction time."""
        return self._reason

    def model_name(self) -> str:
        """Return the model name supplied at construction time."""
        return self._model_name

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """Refuse the request.

        Raises:
            ModelProviderError: Always, naming the provider and the reason.
        """
        message = f"Provider {self.name} is unavailable: {self._reason}"
        raise ModelProviderError(message)
|
||||
|
||||
|
||||
class QwenModelProvider(BaseModelProvider):
    """Provider that calls the Qwen chat-completions endpoint with OAuth credentials."""

    name = "qwen"
    capabilities = ProviderCapabilities(
        tool_calling=True,
        web_search=True,
        oauth_auth=True,
    )

    def __init__(self, config: ServerConfig, oauth: QwenOAuthManager) -> None:
        """Keep the server configuration and the OAuth manager used for requests."""
        super().__init__(config)
        self.oauth = oauth

    def is_available(self) -> bool:
        """Return True when stored credentials contain a non-empty access token."""
        creds = self.oauth.load_credentials()
        return bool(creds and creds.get("access_token"))

    def unavailable_reason(self) -> str | None:
        """Return None when available, otherwise a fixed explanation."""
        if self.is_available():
            return None
        return "Qwen OAuth is not configured"

    def model_name(self) -> str:
        """Return the model name from the server configuration."""
        return self.config.model

    def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
        """POST the request to ``{base_url}/chat/completions`` and return the JSON body.

        Every failure is reported as ``ModelProviderError`` so the router,
        which catches only that type, can fall back to another provider.

        Raises:
            ModelProviderError: If credentials cannot be obtained, the HTTP
                call fails (HTTP error status, network error, timeout), or the
                response body is not valid UTF-8 JSON.
        """
        # The OAuth calls must sit inside a try block: previously they ran
        # before it, so the OAuthError handler could never fire.
        try:
            creds = self.oauth.get_valid_credentials()
            base_url = self.oauth.get_openai_base_url(creds)
        except OAuthError as exc:
            raise ModelProviderError(str(exc)) from exc

        payload = {
            "model": self.model_name(),
            "messages": completion_request.messages,
            "tools": completion_request.tools,
            "tool_choice": completion_request.tool_choice,
        }
        req = request.Request(
            f"{base_url}/chat/completions",
            data=json.dumps(payload).encode("utf-8"),
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {creds['access_token']}",
            },
            method="POST",
        )

        try:
            with request.urlopen(req, timeout=180) as response:
                raw = response.read()
        except error.HTTPError as exc:
            body = exc.read().decode("utf-8", errors="replace")
            raise ModelProviderError(
                f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
            ) from exc
        except OSError as exc:
            # URLError (DNS/connection failures) and socket timeouts are OSError
            # subclasses; HTTPError is handled above because it carries a body.
            raise ModelProviderError(f"Provider {self.name} request failed: {exc}") from exc

        try:
            return json.loads(raw.decode("utf-8"))
        except ValueError as exc:
            # Covers both UnicodeDecodeError and json.JSONDecodeError.
            raise ModelProviderError(
                f"Provider {self.name} returned an invalid JSON response: {exc}"
            ) from exc
|
||||
|
||||
|
||||
class ProviderRegistry:
    """Name-indexed collection of model providers."""

    def __init__(self, providers: list[BaseModelProvider]) -> None:
        """Index the providers by their ``name``; a later duplicate replaces an earlier one."""
        self._providers = {}
        for provider in providers:
            self._providers[provider.name] = provider

    def get(self, name: str) -> BaseModelProvider | None:
        """Return the provider registered under ``name``, or None if there is none."""
        return self._providers.get(name)

    def list_names(self) -> list[str]:
        """Return all registered provider names in sorted order."""
        return sorted(self._providers)

    def statuses(self) -> list[dict[str, Any]]:
        """Return one status dict per provider, in registration order.

        Each dict holds the provider's name, model, availability, the reason
        it is unavailable (if any) and its capability flags.
        """
        report = []
        for provider in self._providers.values():
            entry = {
                "name": provider.name,
                "model": provider.model_name(),
                "available": provider.is_available(),
                "reason": provider.unavailable_reason(),
            }
            flags = provider.capabilities
            entry["capabilities"] = {
                "tool_calling": flags.tool_calling,
                "web_search": flags.web_search,
                "oauth_auth": flags.oauth_auth,
            }
            report.append(entry)
        return report
|
||||
|
||||
|
||||
class ModelRouter:
    """Choose a provider for each completion request and fall back on failure."""

    def __init__(self, config: ServerConfig, registry: ProviderRegistry) -> None:
        """Keep the configuration (default/fallback names) and the provider registry."""
        self.config = config
        self.registry = registry

    def _candidate_names(self, preferred_provider: str | None) -> list[str]:
        """Return provider names in the order they should be tried, without duplicates.

        Order: the preferred provider, the configured default, the configured
        fallbacks, then every remaining registered name.
        """
        configured = [
            candidate
            for candidate in (
                preferred_provider,
                self.config.default_provider,
                *self.config.fallback_providers,
            )
            if candidate  # drop None / empty names coming from request or config
        ]
        ordered: list[str] = []
        for candidate in [*configured, *self.registry.list_names()]:
            if candidate not in ordered:
                ordered.append(candidate)
        return ordered

    def _supports_request(
        self,
        provider: BaseModelProvider,
        completion_request: CompletionRequest,
    ) -> bool:
        """Return False only when tools are required and the provider lacks tool calling."""
        needs_tools = completion_request.require_tools
        return not (needs_tools and not provider.capabilities.tool_calling)

    def complete(self, completion_request: CompletionRequest) -> CompletionResponse:
        """Run the request on the first candidate provider that succeeds.

        Candidates that are unknown, unavailable, lack a required capability
        or raise ``ModelProviderError`` are skipped, and the reason is recorded.

        Raises:
            ModelProviderError: If no candidate produced a response; the
                message lists the recorded reasons.
        """
        preferred = completion_request.preferred_provider
        attempted: list[str] = []
        reasons: list[str] = []

        for name in self._candidate_names(preferred):
            provider = self.registry.get(name)
            if not provider:
                reasons.append(f"{name}: unknown provider")
                continue

            # Only registered providers count as attempted.
            attempted.append(name)

            if not provider.is_available():
                why = provider.unavailable_reason() or "provider unavailable"
                reasons.append(f"{name}: {why}")
                continue

            if not self._supports_request(provider, completion_request):
                reasons.append(f"{name}: missing required capabilities")
                continue

            default = self.config.default_provider
            if preferred:
                if name != preferred:
                    selection_reason = f"preferred provider unavailable, fell back to {name}"
                elif name == default:
                    selection_reason = "selected default provider"
                else:
                    selection_reason = "selected preferred provider"
            elif name != default:
                selection_reason = f"default provider unavailable, fell back to {name}"
            else:
                selection_reason = "selected default provider"

            try:
                payload = provider.complete(completion_request)
            except ModelProviderError as exc:
                # A runtime failure moves on to the next candidate.
                reasons.append(f"{name}: request failed: {exc}")
                continue

            return CompletionResponse(
                provider_name=name,
                model_name=provider.model_name(),
                payload=payload,
                selection_reason=selection_reason,
                attempted=attempted,
            )

        details = "; ".join(reasons) if reasons else "no providers registered"
        raise ModelProviderError(f"No model provider available: {details}")
|
||||
|
||||
|
||||
def build_provider_registry(config: ServerConfig, oauth: QwenOAuthManager) -> ProviderRegistry:
    """Build the registry: the Qwen provider plus GigaChat and YandexGPT placeholders.

    The two placeholders are registered as unavailable, with no capabilities
    enabled, using the model names from the configuration.
    """
    providers: list[BaseModelProvider] = [QwenModelProvider(config, oauth)]

    # (provider name, configured model name, reason it cannot be used)
    placeholders = (
        ("gigachat", config.gigachat_model, "GigaChat provider is not implemented yet"),
        ("yandexgpt", config.yandexgpt_model, "YandexGPT provider is not implemented yet"),
    )
    for provider_name, configured_model, why in placeholders:
        providers.append(
            UnavailableModelProvider(
                config,
                name=provider_name,
                model_name=configured_model,
                reason=why,
                capabilities=ProviderCapabilities(),
            )
        )
    return ProviderRegistry(providers)
|
||||
Loading…
Reference in New Issue