Refactor server for multi-provider routing
This commit is contained in:
parent
84a6e8b5d8
commit
7fac6fa41e
22
README.md
22
README.md
|
|
@ -1,8 +1,8 @@
|
||||||
# new-qwen
|
# new-qwen
|
||||||
|
|
||||||
Клиент-серверная замена локального агента `qwen-code`.
|
Клиент-серверная замена локального агента `qwen-code` с серверной оркестрацией моделей.
|
||||||
|
|
||||||
- `serv` отвечает за OAuth, сессии, работу с Qwen LLM и вызов инструментов.
|
- `serv` отвечает за OAuth, сессии, маршрутизацию к model provider-ам и вызов инструментов.
|
||||||
- `bot` отвечает за Telegram и пересылку сообщений на сервер.
|
- `bot` отвечает за Telegram и пересылку сообщений на сервер.
|
||||||
|
|
||||||
Проект написан на Python stdlib, чтобы не зависеть от Node/npm в текущем окружении.
|
Проект написан на Python stdlib, чтобы не зависеть от Node/npm в текущем окружении.
|
||||||
|
|
@ -22,6 +22,12 @@ serv/app.py
|
||||||
Qwen OAuth + OpenAI-compatible endpoint
|
Qwen OAuth + OpenAI-compatible endpoint
|
||||||
```
|
```
|
||||||
|
|
||||||
|
На стороне `serv` теперь есть отдельные слои:
|
||||||
|
|
||||||
|
- `model_router.py` - выбор провайдера и fallback policy
|
||||||
|
- `llm.py` - агентный цикл поверх абстракции провайдера
|
||||||
|
- `oauth.py` - auth only для Qwen path
|
||||||
|
|
||||||
## Что уже реализовано
|
## Что уже реализовано
|
||||||
|
|
||||||
- Qwen OAuth Device Flow, совместимый с `qwen-code`
|
- Qwen OAuth Device Flow, совместимый с `qwen-code`
|
||||||
|
|
@ -39,6 +45,8 @@ Qwen OAuth + OpenAI-compatible endpoint
|
||||||
- policy mode для инструментов: `full-access`, `workspace-write`, `read-only`
|
- policy mode для инструментов: `full-access`, `workspace-write`, `read-only`
|
||||||
- live approval flow для инструментов через Telegram
|
- live approval flow для инструментов через Telegram
|
||||||
- provider-based web search с приоритетом DashScope через Qwen OAuth
|
- provider-based web search с приоритетом DashScope через Qwen OAuth
|
||||||
|
- model router с `qwen` как первым провайдером и fallback-ready архитектурой для `gigachat` и `yandexgpt`
|
||||||
|
- router умеет fallback не только по конфигу, но и при runtime-ошибке провайдера
|
||||||
|
|
||||||
## Ограничения текущей реализации
|
## Ограничения текущей реализации
|
||||||
|
|
||||||
|
|
@ -67,6 +75,10 @@ cp serv/.env.example serv/.env
|
||||||
Ключевые параметры сервера:
|
Ключевые параметры сервера:
|
||||||
|
|
||||||
- `NEW_QWEN_STATE_DIR` - где хранить jobs и pending OAuth flows
|
- `NEW_QWEN_STATE_DIR` - где хранить jobs и pending OAuth flows
|
||||||
|
- `NEW_QWEN_DEFAULT_PROVIDER` - основной model provider, сейчас по умолчанию `qwen`
|
||||||
|
- `NEW_QWEN_FALLBACK_PROVIDERS` - fallback-цепочка провайдеров через запятую
|
||||||
|
- `NEW_QWEN_GIGACHAT_MODEL` - имя модели для будущего GigaChat adapter
|
||||||
|
- `NEW_QWEN_YANDEXGPT_MODEL` - имя модели для будущего YandexGPT adapter
|
||||||
- `NEW_QWEN_TOOL_POLICY` - режим инструментов:
|
- `NEW_QWEN_TOOL_POLICY` - режим инструментов:
|
||||||
`full-access` - все инструменты
|
`full-access` - все инструменты
|
||||||
`workspace-write` - без `exec_command`
|
`workspace-write` - без `exec_command`
|
||||||
|
|
@ -130,6 +142,12 @@ curl -X POST http://127.0.0.1:8080/api/v1/auth/device/start
|
||||||
- `POST /api/v1/chat/cancel`
|
- `POST /api/v1/chat/cancel`
|
||||||
- `POST /api/v1/approval/respond`
|
- `POST /api/v1/approval/respond`
|
||||||
|
|
||||||
|
`GET /api/v1/auth/status` теперь также показывает:
|
||||||
|
|
||||||
|
- `default_provider`
|
||||||
|
- `fallback_providers`
|
||||||
|
- список `providers` с availability и capabilities
|
||||||
|
|
||||||
## Telegram Approval Flow
|
## Telegram Approval Flow
|
||||||
|
|
||||||
Если политика инструментов настроена как `ask-shell`, `ask-write` или `ask-all`, бот пришлёт запрос на подтверждение с `approval_id`.
|
Если политика инструментов настроена как `ask-shell`, `ask-write` или `ask-all`, бот пришлёт запрос на подтверждение с `approval_id`.
|
||||||
|
|
|
||||||
|
|
@ -68,6 +68,12 @@ def summarize_event(event: dict[str, Any]) -> str | None:
|
||||||
if event_type == "job_status":
|
if event_type == "job_status":
|
||||||
return event.get("message")
|
return event.get("message")
|
||||||
if event_type == "model_request":
|
if event_type == "model_request":
|
||||||
|
provider = event.get("provider")
|
||||||
|
model = event.get("model")
|
||||||
|
if provider and model:
|
||||||
|
return f"Думаю над ответом через {provider}/{model}"
|
||||||
|
if provider:
|
||||||
|
return f"Думаю над ответом через {provider}"
|
||||||
return "Думаю над ответом"
|
return "Думаю над ответом"
|
||||||
if event_type == "tool_call":
|
if event_type == "tool_call":
|
||||||
return f"Вызываю инструмент: {event.get('name')}"
|
return f"Вызываю инструмент: {event.get('name')}"
|
||||||
|
|
@ -555,6 +561,8 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
||||||
chat_id,
|
chat_id,
|
||||||
"Сервер доступен.\n"
|
"Сервер доступен.\n"
|
||||||
f"OAuth: {'configured' if status.get('authenticated') else 'not configured'}\n"
|
f"OAuth: {'configured' if status.get('authenticated') else 'not configured'}\n"
|
||||||
|
f"default_provider: {status.get('default_provider')}\n"
|
||||||
|
f"fallback_providers: {', '.join(status.get('fallback_providers') or []) or '-'}\n"
|
||||||
f"resource_url: {status.get('resource_url')}\n"
|
f"resource_url: {status.get('resource_url')}\n"
|
||||||
f"expires_at: {status.get('expires_at')}\n"
|
f"expires_at: {status.get('expires_at')}\n"
|
||||||
f"tool_policy: {status.get('tool_policy')}\n"
|
f"tool_policy: {status.get('tool_policy')}\n"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,10 @@
|
||||||
NEW_QWEN_HOST=127.0.0.1
|
NEW_QWEN_HOST=127.0.0.1
|
||||||
NEW_QWEN_PORT=8080
|
NEW_QWEN_PORT=8080
|
||||||
NEW_QWEN_MODEL=qwen3.6-plus
|
NEW_QWEN_MODEL=qwen3.6-plus
|
||||||
|
NEW_QWEN_DEFAULT_PROVIDER=qwen
|
||||||
|
NEW_QWEN_FALLBACK_PROVIDERS=
|
||||||
|
NEW_QWEN_GIGACHAT_MODEL=GigaChat
|
||||||
|
NEW_QWEN_YANDEXGPT_MODEL=yandexgpt
|
||||||
NEW_QWEN_WORKSPACE_ROOT=/home/mirivlad/git
|
NEW_QWEN_WORKSPACE_ROOT=/home/mirivlad/git
|
||||||
NEW_QWEN_SESSION_DIR=/home/mirivlad/git/new-qwen/.new-qwen/sessions
|
NEW_QWEN_SESSION_DIR=/home/mirivlad/git/new-qwen/.new-qwen/sessions
|
||||||
NEW_QWEN_STATE_DIR=/home/mirivlad/git/new-qwen/.new-qwen/state
|
NEW_QWEN_STATE_DIR=/home/mirivlad/git/new-qwen/.new-qwen/state
|
||||||
|
|
|
||||||
40
serv/app.py
40
serv/app.py
|
|
@ -13,6 +13,7 @@ from config import ServerConfig
|
||||||
from approvals import ApprovalStore
|
from approvals import ApprovalStore
|
||||||
from jobs import JobStore
|
from jobs import JobStore
|
||||||
from llm import QwenAgent
|
from llm import QwenAgent
|
||||||
|
from model_router import ModelProviderError, ModelRouter, build_provider_registry
|
||||||
from oauth import DeviceAuthState, OAuthError, QwenOAuthManager
|
from oauth import DeviceAuthState, OAuthError, QwenOAuthManager
|
||||||
from sessions import SessionStore
|
from sessions import SessionStore
|
||||||
from tools import ToolRegistry
|
from tools import ToolRegistry
|
||||||
|
|
@ -23,8 +24,10 @@ class AppState:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.oauth = QwenOAuthManager()
|
self.oauth = QwenOAuthManager()
|
||||||
self.sessions = SessionStore(config.session_dir)
|
self.sessions = SessionStore(config.session_dir)
|
||||||
|
self.providers = build_provider_registry(config, self.oauth)
|
||||||
|
self.router = ModelRouter(config, self.providers)
|
||||||
self.tools = ToolRegistry(config, self.oauth)
|
self.tools = ToolRegistry(config, self.oauth)
|
||||||
self.agent = QwenAgent(config, self.oauth, self.tools)
|
self.agent = QwenAgent(config, self.router, self.tools)
|
||||||
self.jobs = JobStore(
|
self.jobs = JobStore(
|
||||||
config.state_dir / "jobs",
|
config.state_dir / "jobs",
|
||||||
retention_seconds=config.jobs_retention_seconds,
|
retention_seconds=config.jobs_retention_seconds,
|
||||||
|
|
@ -103,12 +106,18 @@ class AppState:
|
||||||
if not creds:
|
if not creds:
|
||||||
return {
|
return {
|
||||||
"authenticated": False,
|
"authenticated": False,
|
||||||
|
"default_provider": self.config.default_provider,
|
||||||
|
"fallback_providers": self.config.fallback_providers,
|
||||||
|
"providers": self.providers.statuses(),
|
||||||
"tool_policy": self.config.tool_policy,
|
"tool_policy": self.config.tool_policy,
|
||||||
"pending_flows": len(self.pending_device_flows),
|
"pending_flows": len(self.pending_device_flows),
|
||||||
"pending_approvals": len(self.approvals.list_pending()),
|
"pending_approvals": len(self.approvals.list_pending()),
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
"authenticated": True,
|
"authenticated": True,
|
||||||
|
"default_provider": self.config.default_provider,
|
||||||
|
"fallback_providers": self.config.fallback_providers,
|
||||||
|
"providers": self.providers.statuses(),
|
||||||
"resource_url": creds.get("resource_url"),
|
"resource_url": creds.get("resource_url"),
|
||||||
"expires_at": creds.get("expiry_date"),
|
"expires_at": creds.get("expiry_date"),
|
||||||
"tool_policy": self.config.tool_policy,
|
"tool_policy": self.config.tool_policy,
|
||||||
|
|
@ -154,7 +163,14 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
return
|
return
|
||||||
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
||||||
|
|
||||||
def _run_chat_job(self, job_id: str, session_id: str, user_id: str, message: str) -> None:
|
def _run_chat_job(
|
||||||
|
self,
|
||||||
|
job_id: str,
|
||||||
|
session_id: str,
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
preferred_provider: str | None = None,
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
if self.app.jobs.is_cancel_requested(job_id):
|
if self.app.jobs.is_cancel_requested(job_id):
|
||||||
reason = "Job canceled before execution started"
|
reason = "Job canceled before execution started"
|
||||||
|
|
@ -181,6 +197,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
arguments,
|
arguments,
|
||||||
),
|
),
|
||||||
is_cancelled=lambda: self.app.jobs.is_cancel_requested(job_id),
|
is_cancelled=lambda: self.app.jobs.is_cancel_requested(job_id),
|
||||||
|
preferred_provider=preferred_provider,
|
||||||
)
|
)
|
||||||
if self.app.jobs.is_cancel_requested(job_id):
|
if self.app.jobs.is_cancel_requested(job_id):
|
||||||
reason = "Job canceled by operator"
|
reason = "Job canceled by operator"
|
||||||
|
|
@ -209,6 +226,8 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
job_id,
|
job_id,
|
||||||
answer=result["answer"],
|
answer=result["answer"],
|
||||||
usage=result.get("usage"),
|
usage=result.get("usage"),
|
||||||
|
provider=result.get("provider"),
|
||||||
|
model=result.get("model"),
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if self.app.jobs.is_cancel_requested(job_id):
|
if self.app.jobs.is_cancel_requested(job_id):
|
||||||
|
|
@ -299,9 +318,14 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
session_id = body.get("session_id") or uuid.uuid4().hex
|
session_id = body.get("session_id") or uuid.uuid4().hex
|
||||||
user_id = str(body.get("user_id") or "anonymous")
|
user_id = str(body.get("user_id") or "anonymous")
|
||||||
message = body["message"]
|
message = body["message"]
|
||||||
|
preferred_provider = body.get("provider")
|
||||||
session = self.app.sessions.load(session_id)
|
session = self.app.sessions.load(session_id)
|
||||||
history = session.get("messages", [])
|
history = session.get("messages", [])
|
||||||
result = self.app.agent.run(history, message)
|
result = self.app.agent.run(
|
||||||
|
history,
|
||||||
|
message,
|
||||||
|
preferred_provider=preferred_provider,
|
||||||
|
)
|
||||||
persisted_messages = result["messages"][1:]
|
persisted_messages = result["messages"][1:]
|
||||||
self.app.sessions.save(
|
self.app.sessions.save(
|
||||||
session_id,
|
session_id,
|
||||||
|
|
@ -320,6 +344,8 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
"answer": result["answer"],
|
"answer": result["answer"],
|
||||||
"events": result["events"],
|
"events": result["events"],
|
||||||
"usage": result.get("usage"),
|
"usage": result.get("usage"),
|
||||||
|
"provider": result.get("provider"),
|
||||||
|
"model": result.get("model"),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
@ -329,10 +355,11 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
session_id = body.get("session_id") or uuid.uuid4().hex
|
session_id = body.get("session_id") or uuid.uuid4().hex
|
||||||
user_id = str(body.get("user_id") or "anonymous")
|
user_id = str(body.get("user_id") or "anonymous")
|
||||||
message = body["message"]
|
message = body["message"]
|
||||||
|
preferred_provider = body.get("provider")
|
||||||
job = self.app.jobs.create(session_id, user_id, message)
|
job = self.app.jobs.create(session_id, user_id, message)
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=self._run_chat_job,
|
target=self._run_chat_job,
|
||||||
args=(job["job_id"], session_id, user_id, message),
|
args=(job["job_id"], session_id, user_id, message, preferred_provider),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
@ -342,6 +369,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
"job_id": job["job_id"],
|
"job_id": job["job_id"],
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"status": "queued",
|
"status": "queued",
|
||||||
|
"provider": preferred_provider or self.app.config.default_provider,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
@ -367,6 +395,8 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
"answer": job.get("answer"),
|
"answer": job.get("answer"),
|
||||||
"usage": job.get("usage"),
|
"usage": job.get("usage"),
|
||||||
"error": job.get("error"),
|
"error": job.get("error"),
|
||||||
|
"provider": job.get("provider"),
|
||||||
|
"model": job.get("model"),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
@ -427,7 +457,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
return
|
return
|
||||||
|
|
||||||
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
||||||
except OAuthError as exc:
|
except (OAuthError, ModelProviderError) as exc:
|
||||||
self._send(HTTPStatus.BAD_GATEWAY, {"error": str(exc)})
|
self._send(HTTPStatus.BAD_GATEWAY, {"error": str(exc)})
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
self._send(HTTPStatus.BAD_REQUEST, {"error": f"Missing field: {exc}"})
|
self._send(HTTPStatus.BAD_REQUEST, {"error": f"Missing field: {exc}"})
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,10 @@ class ServerConfig:
|
||||||
host: str
|
host: str
|
||||||
port: int
|
port: int
|
||||||
model: str
|
model: str
|
||||||
|
default_provider: str
|
||||||
|
fallback_providers: list[str]
|
||||||
|
gigachat_model: str
|
||||||
|
yandexgpt_model: str
|
||||||
workspace_root: Path
|
workspace_root: Path
|
||||||
session_dir: Path
|
session_dir: Path
|
||||||
state_dir: Path
|
state_dir: Path
|
||||||
|
|
@ -59,6 +63,14 @@ class ServerConfig:
|
||||||
host=os.environ.get("NEW_QWEN_HOST", "127.0.0.1"),
|
host=os.environ.get("NEW_QWEN_HOST", "127.0.0.1"),
|
||||||
port=int(os.environ.get("NEW_QWEN_PORT", "8080")),
|
port=int(os.environ.get("NEW_QWEN_PORT", "8080")),
|
||||||
model=os.environ.get("NEW_QWEN_MODEL", "qwen3.6-plus"),
|
model=os.environ.get("NEW_QWEN_MODEL", "qwen3.6-plus"),
|
||||||
|
default_provider=os.environ.get("NEW_QWEN_DEFAULT_PROVIDER", "qwen").strip() or "qwen",
|
||||||
|
fallback_providers=[
|
||||||
|
item.strip()
|
||||||
|
for item in os.environ.get("NEW_QWEN_FALLBACK_PROVIDERS", "").split(",")
|
||||||
|
if item.strip()
|
||||||
|
],
|
||||||
|
gigachat_model=os.environ.get("NEW_QWEN_GIGACHAT_MODEL", "GigaChat").strip(),
|
||||||
|
yandexgpt_model=os.environ.get("NEW_QWEN_YANDEXGPT_MODEL", "yandexgpt").strip(),
|
||||||
workspace_root=workspace_root.resolve(),
|
workspace_root=workspace_root.resolve(),
|
||||||
session_dir=session_dir.resolve(),
|
session_dir=session_dir.resolve(),
|
||||||
state_dir=state_dir.resolve(),
|
state_dir=state_dir.resolve(),
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,8 @@ class JobStore:
|
||||||
"events": [],
|
"events": [],
|
||||||
"answer": None,
|
"answer": None,
|
||||||
"usage": None,
|
"usage": None,
|
||||||
|
"provider": None,
|
||||||
|
"model": None,
|
||||||
"error": None,
|
"error": None,
|
||||||
"cancel_requested": False,
|
"cancel_requested": False,
|
||||||
"cancel_actor": None,
|
"cancel_actor": None,
|
||||||
|
|
@ -141,12 +143,16 @@ class JobStore:
|
||||||
*,
|
*,
|
||||||
answer: str,
|
answer: str,
|
||||||
usage: dict[str, Any] | None,
|
usage: dict[str, Any] | None,
|
||||||
|
provider: str | None = None,
|
||||||
|
model: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
job = self._jobs[job_id]
|
job = self._jobs[job_id]
|
||||||
job["status"] = "completed"
|
job["status"] = "completed"
|
||||||
job["answer"] = answer
|
job["answer"] = answer
|
||||||
job["usage"] = usage
|
job["usage"] = usage
|
||||||
|
job["provider"] = provider
|
||||||
|
job["model"] = model
|
||||||
job["updated_at"] = time.time()
|
job["updated_at"] = time.time()
|
||||||
self._save_job(job)
|
self._save_job(job)
|
||||||
|
|
||||||
|
|
|
||||||
64
serv/llm.py
64
serv/llm.py
|
|
@ -2,10 +2,9 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
from urllib import error, request
|
|
||||||
|
|
||||||
from config import ServerConfig
|
from config import ServerConfig
|
||||||
from oauth import OAuthError, QwenOAuthManager
|
from model_router import CompletionRequest, ModelRouter
|
||||||
from tools import ToolError, ToolRegistry
|
from tools import ToolError, ToolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,37 +17,11 @@ Do not claim to have executed tools unless a tool result confirms it."""
|
||||||
|
|
||||||
|
|
||||||
class QwenAgent:
|
class QwenAgent:
|
||||||
def __init__(self, config: ServerConfig, oauth: QwenOAuthManager, tools: ToolRegistry) -> None:
|
def __init__(self, config: ServerConfig, router: ModelRouter, tools: ToolRegistry) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.oauth = oauth
|
self.router = router
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
|
||||||
def _request_completion(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
|
||||||
creds = self.oauth.get_valid_credentials()
|
|
||||||
base_url = self.oauth.get_openai_base_url(creds)
|
|
||||||
payload = {
|
|
||||||
"model": self.config.model,
|
|
||||||
"messages": messages,
|
|
||||||
"tools": self.tools.schemas(),
|
|
||||||
"tool_choice": "auto",
|
|
||||||
}
|
|
||||||
data = json.dumps(payload).encode("utf-8")
|
|
||||||
req = request.Request(
|
|
||||||
f"{base_url}/chat/completions",
|
|
||||||
data=data,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {creds['access_token']}",
|
|
||||||
},
|
|
||||||
method="POST",
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
with request.urlopen(req, timeout=180) as response:
|
|
||||||
return json.loads(response.read().decode("utf-8"))
|
|
||||||
except error.HTTPError as exc:
|
|
||||||
body = exc.read().decode("utf-8", errors="replace")
|
|
||||||
raise OAuthError(f"LLM request failed with HTTP {exc.code}: {body}") from exc
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
history: list[dict[str, Any]],
|
history: list[dict[str, Any]],
|
||||||
|
|
@ -56,6 +29,7 @@ class QwenAgent:
|
||||||
on_event: Callable[[dict[str, Any]], None] | None = None,
|
on_event: Callable[[dict[str, Any]], None] | None = None,
|
||||||
approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
|
approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
|
||||||
is_cancelled: Callable[[], bool] | None = None,
|
is_cancelled: Callable[[], bool] | None = None,
|
||||||
|
preferred_provider: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
emit = on_event or (lambda _event: None)
|
emit = on_event or (lambda _event: None)
|
||||||
cancel_check = is_cancelled or (lambda: False)
|
cancel_check = is_cancelled or (lambda: False)
|
||||||
|
|
@ -69,11 +43,33 @@ class QwenAgent:
|
||||||
messages.extend(history)
|
messages.extend(history)
|
||||||
messages.append({"role": "user", "content": user_message})
|
messages.append({"role": "user", "content": user_message})
|
||||||
events: list[dict[str, Any]] = []
|
events: list[dict[str, Any]] = []
|
||||||
|
selected_provider: str | None = None
|
||||||
|
selected_model: str | None = None
|
||||||
|
|
||||||
for _ in range(self.config.max_tool_rounds):
|
for _ in range(self.config.max_tool_rounds):
|
||||||
ensure_not_cancelled()
|
ensure_not_cancelled()
|
||||||
emit({"type": "model_request", "message": "Запрашиваю ответ модели"})
|
completion = self.router.complete(
|
||||||
response = self._request_completion(messages)
|
CompletionRequest(
|
||||||
|
messages=messages,
|
||||||
|
tools=self.tools.schemas(),
|
||||||
|
tool_choice="auto",
|
||||||
|
preferred_provider=preferred_provider,
|
||||||
|
require_tools=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
{
|
||||||
|
"type": "model_request",
|
||||||
|
"message": "Запрашиваю ответ модели",
|
||||||
|
"provider": completion.provider_name,
|
||||||
|
"model": completion.model_name,
|
||||||
|
"selection_reason": completion.selection_reason,
|
||||||
|
"attempted": completion.attempted,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
selected_provider = completion.provider_name
|
||||||
|
selected_model = completion.model_name
|
||||||
|
response = completion.payload
|
||||||
ensure_not_cancelled()
|
ensure_not_cancelled()
|
||||||
choice = response["choices"][0]["message"]
|
choice = response["choices"][0]["message"]
|
||||||
tool_calls = choice.get("tool_calls") or []
|
tool_calls = choice.get("tool_calls") or []
|
||||||
|
|
@ -88,6 +84,8 @@ class QwenAgent:
|
||||||
"answer": content or "",
|
"answer": content or "",
|
||||||
"events": events,
|
"events": events,
|
||||||
"usage": response.get("usage"),
|
"usage": response.get("usage"),
|
||||||
|
"provider": selected_provider,
|
||||||
|
"model": selected_model,
|
||||||
"messages": messages + [{"role": "assistant", "content": content or ""}],
|
"messages": messages + [{"role": "assistant", "content": content or ""}],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -164,5 +162,7 @@ class QwenAgent:
|
||||||
"answer": final_message,
|
"answer": final_message,
|
||||||
"events": events,
|
"events": events,
|
||||||
"usage": None,
|
"usage": None,
|
||||||
|
"provider": selected_provider,
|
||||||
|
"model": selected_model,
|
||||||
"messages": messages + [{"role": "assistant", "content": final_message}],
|
"messages": messages + [{"role": "assistant", "content": final_message}],
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,254 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
from urllib import error, request
|
||||||
|
|
||||||
|
from config import ServerConfig
|
||||||
|
from oauth import OAuthError, QwenOAuthManager
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProviderError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ProviderCapabilities:
|
||||||
|
tool_calling: bool = False
|
||||||
|
web_search: bool = False
|
||||||
|
oauth_auth: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class CompletionRequest:
|
||||||
|
messages: list[dict[str, Any]]
|
||||||
|
tools: list[dict[str, Any]]
|
||||||
|
tool_choice: str = "auto"
|
||||||
|
preferred_provider: str | None = None
|
||||||
|
require_tools: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class CompletionResponse:
|
||||||
|
provider_name: str
|
||||||
|
model_name: str
|
||||||
|
payload: dict[str, Any]
|
||||||
|
selection_reason: str
|
||||||
|
attempted: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelProvider:
|
||||||
|
name = "base"
|
||||||
|
capabilities = ProviderCapabilities()
|
||||||
|
|
||||||
|
def __init__(self, config: ServerConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def unavailable_reason(self) -> str | None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def model_name(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class UnavailableModelProvider(BaseModelProvider):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: ServerConfig,
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
model_name: str,
|
||||||
|
reason: str,
|
||||||
|
capabilities: ProviderCapabilities | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.name = name
|
||||||
|
self._model_name = model_name
|
||||||
|
self._reason = reason
|
||||||
|
if capabilities is not None:
|
||||||
|
self.capabilities = capabilities
|
||||||
|
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def unavailable_reason(self) -> str | None:
|
||||||
|
return self._reason
|
||||||
|
|
||||||
|
def model_name(self) -> str:
|
||||||
|
return self._model_name
|
||||||
|
|
||||||
|
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
||||||
|
raise ModelProviderError(f"Provider {self.name} is unavailable: {self._reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class QwenModelProvider(BaseModelProvider):
|
||||||
|
name = "qwen"
|
||||||
|
capabilities = ProviderCapabilities(
|
||||||
|
tool_calling=True,
|
||||||
|
web_search=True,
|
||||||
|
oauth_auth=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, config: ServerConfig, oauth: QwenOAuthManager) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.oauth = oauth
|
||||||
|
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
creds = self.oauth.load_credentials()
|
||||||
|
return bool(creds and creds.get("access_token"))
|
||||||
|
|
||||||
|
def unavailable_reason(self) -> str | None:
|
||||||
|
if self.is_available():
|
||||||
|
return None
|
||||||
|
return "Qwen OAuth is not configured"
|
||||||
|
|
||||||
|
def model_name(self) -> str:
|
||||||
|
return self.config.model
|
||||||
|
|
||||||
|
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
||||||
|
creds = self.oauth.get_valid_credentials()
|
||||||
|
base_url = self.oauth.get_openai_base_url(creds)
|
||||||
|
payload = {
|
||||||
|
"model": self.model_name(),
|
||||||
|
"messages": completion_request.messages,
|
||||||
|
"tools": completion_request.tools,
|
||||||
|
"tool_choice": completion_request.tool_choice,
|
||||||
|
}
|
||||||
|
data = json.dumps(payload).encode("utf-8")
|
||||||
|
req = request.Request(
|
||||||
|
f"{base_url}/chat/completions",
|
||||||
|
data=data,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {creds['access_token']}",
|
||||||
|
},
|
||||||
|
method="POST",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with request.urlopen(req, timeout=180) as response:
|
||||||
|
return json.loads(response.read().decode("utf-8"))
|
||||||
|
except error.HTTPError as exc:
|
||||||
|
body = exc.read().decode("utf-8", errors="replace")
|
||||||
|
raise ModelProviderError(
|
||||||
|
f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
|
||||||
|
) from exc
|
||||||
|
except OAuthError as exc:
|
||||||
|
raise ModelProviderError(str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderRegistry:
|
||||||
|
def __init__(self, providers: list[BaseModelProvider]) -> None:
|
||||||
|
self._providers = {provider.name: provider for provider in providers}
|
||||||
|
|
||||||
|
def get(self, name: str) -> BaseModelProvider | None:
|
||||||
|
return self._providers.get(name)
|
||||||
|
|
||||||
|
def list_names(self) -> list[str]:
|
||||||
|
return sorted(self._providers.keys())
|
||||||
|
|
||||||
|
def statuses(self) -> list[dict[str, Any]]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": provider.name,
|
||||||
|
"model": provider.model_name(),
|
||||||
|
"available": provider.is_available(),
|
||||||
|
"reason": provider.unavailable_reason(),
|
||||||
|
"capabilities": {
|
||||||
|
"tool_calling": provider.capabilities.tool_calling,
|
||||||
|
"web_search": provider.capabilities.web_search,
|
||||||
|
"oauth_auth": provider.capabilities.oauth_auth,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for provider in self._providers.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRouter:
|
||||||
|
def __init__(self, config: ServerConfig, registry: ProviderRegistry) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.registry = registry
|
||||||
|
|
||||||
|
def _candidate_names(self, preferred_provider: str | None) -> list[str]:
|
||||||
|
names: list[str] = []
|
||||||
|
for name in [preferred_provider, self.config.default_provider, *self.config.fallback_providers]:
|
||||||
|
if name and name not in names:
|
||||||
|
names.append(name)
|
||||||
|
for name in self.registry.list_names():
|
||||||
|
if name not in names:
|
||||||
|
names.append(name)
|
||||||
|
return names
|
||||||
|
|
||||||
|
def _supports_request(
|
||||||
|
self,
|
||||||
|
provider: BaseModelProvider,
|
||||||
|
completion_request: CompletionRequest,
|
||||||
|
) -> bool:
|
||||||
|
if completion_request.require_tools and not provider.capabilities.tool_calling:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def complete(self, completion_request: CompletionRequest) -> CompletionResponse:
|
||||||
|
attempted: list[str] = []
|
||||||
|
reasons: list[str] = []
|
||||||
|
for name in self._candidate_names(completion_request.preferred_provider):
|
||||||
|
provider = self.registry.get(name)
|
||||||
|
if not provider:
|
||||||
|
reasons.append(f"{name}: unknown provider")
|
||||||
|
continue
|
||||||
|
attempted.append(name)
|
||||||
|
if not provider.is_available():
|
||||||
|
reason = provider.unavailable_reason() or "provider unavailable"
|
||||||
|
reasons.append(f"{name}: {reason}")
|
||||||
|
continue
|
||||||
|
if not self._supports_request(provider, completion_request):
|
||||||
|
reasons.append(f"{name}: missing required capabilities")
|
||||||
|
continue
|
||||||
|
selection_reason = "selected preferred provider"
|
||||||
|
if completion_request.preferred_provider and name != completion_request.preferred_provider:
|
||||||
|
selection_reason = f"preferred provider unavailable, fell back to {name}"
|
||||||
|
elif not completion_request.preferred_provider and name != self.config.default_provider:
|
||||||
|
selection_reason = f"default provider unavailable, fell back to {name}"
|
||||||
|
elif name == self.config.default_provider:
|
||||||
|
selection_reason = "selected default provider"
|
||||||
|
try:
|
||||||
|
payload = provider.complete(completion_request)
|
||||||
|
except ModelProviderError as exc:
|
||||||
|
reasons.append(f"{name}: request failed: {exc}")
|
||||||
|
continue
|
||||||
|
return CompletionResponse(
|
||||||
|
provider_name=name,
|
||||||
|
model_name=provider.model_name(),
|
||||||
|
payload=payload,
|
||||||
|
selection_reason=selection_reason,
|
||||||
|
attempted=attempted,
|
||||||
|
)
|
||||||
|
details = "; ".join(reasons) if reasons else "no providers registered"
|
||||||
|
raise ModelProviderError(f"No model provider available: {details}")
|
||||||
|
|
||||||
|
|
||||||
|
def build_provider_registry(config: ServerConfig, oauth: QwenOAuthManager) -> ProviderRegistry:
|
||||||
|
providers: list[BaseModelProvider] = [
|
||||||
|
QwenModelProvider(config, oauth),
|
||||||
|
UnavailableModelProvider(
|
||||||
|
config,
|
||||||
|
name="gigachat",
|
||||||
|
model_name=config.gigachat_model,
|
||||||
|
reason="GigaChat provider is not implemented yet",
|
||||||
|
capabilities=ProviderCapabilities(),
|
||||||
|
),
|
||||||
|
UnavailableModelProvider(
|
||||||
|
config,
|
||||||
|
name="yandexgpt",
|
||||||
|
model_name=config.yandexgpt_model,
|
||||||
|
reason="YandexGPT provider is not implemented yet",
|
||||||
|
capabilities=ProviderCapabilities(),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return ProviderRegistry(providers)
|
||||||
Loading…
Reference in New Issue