Add GigaChat model provider
This commit is contained in:
parent
7fac6fa41e
commit
db89c14b37
10
README.md
10
README.md
|
|
@ -27,6 +27,7 @@ Qwen OAuth + OpenAI-compatible endpoint
|
|||
- `model_router.py` - выбор провайдера и fallback policy
|
||||
- `llm.py` - агентный цикл поверх абстракции провайдера
|
||||
- `oauth.py` - auth only для Qwen path
|
||||
- `gigachat.py` - token management для GigaChat
|
||||
|
||||
## Что уже реализовано
|
||||
|
||||
|
|
@ -47,6 +48,7 @@ Qwen OAuth + OpenAI-compatible endpoint
|
|||
- provider-based web search с приоритетом DashScope через Qwen OAuth
|
||||
- model router с `qwen` как первым провайдером и fallback-ready архитектурой для `gigachat` и `yandexgpt`
|
||||
- router умеет fallback не только по конфигу, но и при runtime-ошибке провайдера
|
||||
- реальный adapter для `gigachat` с token fetch и нормализацией function calling под внутренний agent loop
|
||||
|
||||
## Ограничения текущей реализации
|
||||
|
||||
|
|
@ -77,7 +79,11 @@ cp serv/.env.example serv/.env
|
|||
- `NEW_QWEN_STATE_DIR` - где хранить jobs и pending OAuth flows
|
||||
- `NEW_QWEN_DEFAULT_PROVIDER` - основной model provider, сейчас по умолчанию `qwen`
|
||||
- `NEW_QWEN_FALLBACK_PROVIDERS` - fallback-цепочка провайдеров через запятую
|
||||
- `NEW_QWEN_GIGACHAT_MODEL` - имя модели для будущего GigaChat adapter
|
||||
- `NEW_QWEN_GIGACHAT_MODEL` - имя модели GigaChat
|
||||
- `NEW_QWEN_GIGACHAT_AUTH_KEY` - ключ авторизации GigaChat для `Authorization: Basic ...`
|
||||
- `NEW_QWEN_GIGACHAT_SCOPE` - scope для получения access token, по умолчанию `GIGACHAT_API_PERS`
|
||||
- `NEW_QWEN_GIGACHAT_API_BASE_URL` - базовый URL inference API GigaChat
|
||||
- `NEW_QWEN_GIGACHAT_OAUTH_URL` - URL получения access token GigaChat
|
||||
- `NEW_QWEN_YANDEXGPT_MODEL` - имя модели для будущего YandexGPT adapter
|
||||
- `NEW_QWEN_TOOL_POLICY` - режим инструментов:
|
||||
`full-access` - все инструменты
|
||||
|
|
@ -144,6 +150,8 @@ curl -X POST http://127.0.0.1:8080/api/v1/auth/device/start
|
|||
|
||||
`GET /api/v1/auth/status` теперь также показывает:
|
||||
|
||||
- `ready`
|
||||
- `available_providers`
|
||||
- `default_provider`
|
||||
- `fallback_providers`
|
||||
- список `providers` с availability и capabilities
|
||||
|
|
|
|||
14
bot/app.py
14
bot/app.py
|
|
@ -153,8 +153,18 @@ def start_auth_flow(
|
|||
|
||||
def ensure_auth(api: TelegramAPI, config: BotConfig, state: dict[str, Any], chat_id: int) -> bool:
|
||||
status = get_json(f"{config.server_url}/api/v1/auth/status")
|
||||
if status.get("authenticated"):
|
||||
if status.get("ready") or status.get("available_providers"):
|
||||
return True
|
||||
default_provider = status.get("default_provider")
|
||||
fallback_providers = status.get("fallback_providers") or []
|
||||
if default_provider != "qwen" and "qwen" not in fallback_providers:
|
||||
api.send_message(
|
||||
chat_id,
|
||||
"На сервере нет доступных model provider-ов. "
|
||||
f"Текущий default_provider: {default_provider}. "
|
||||
"Для GigaChat/YandexGPT нужно настроить серверные credentials.",
|
||||
)
|
||||
return False
|
||||
start_auth_flow(api, config, state, chat_id)
|
||||
return False
|
||||
|
||||
|
|
@ -561,6 +571,8 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
|||
chat_id,
|
||||
"Сервер доступен.\n"
|
||||
f"OAuth: {'configured' if status.get('authenticated') else 'not configured'}\n"
|
||||
f"ready: {status.get('ready')}\n"
|
||||
f"available_providers: {', '.join(status.get('available_providers') or []) or '-'}\n"
|
||||
f"default_provider: {status.get('default_provider')}\n"
|
||||
f"fallback_providers: {', '.join(status.get('fallback_providers') or []) or '-'}\n"
|
||||
f"resource_url: {status.get('resource_url')}\n"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,10 @@ NEW_QWEN_MODEL=qwen3.6-plus
|
|||
NEW_QWEN_DEFAULT_PROVIDER=qwen
|
||||
NEW_QWEN_FALLBACK_PROVIDERS=
|
||||
NEW_QWEN_GIGACHAT_MODEL=GigaChat
|
||||
NEW_QWEN_GIGACHAT_AUTH_KEY=
|
||||
NEW_QWEN_GIGACHAT_SCOPE=GIGACHAT_API_PERS
|
||||
NEW_QWEN_GIGACHAT_API_BASE_URL=https://gigachat.devices.sberbank.ru/api/v1
|
||||
NEW_QWEN_GIGACHAT_OAUTH_URL=https://ngw.devices.sberbank.ru:9443/api/v2/oauth
|
||||
NEW_QWEN_YANDEXGPT_MODEL=yandexgpt
|
||||
NEW_QWEN_WORKSPACE_ROOT=/home/mirivlad/git
|
||||
NEW_QWEN_SESSION_DIR=/home/mirivlad/git/new-qwen/.new-qwen/sessions
|
||||
|
|
|
|||
14
serv/app.py
14
serv/app.py
|
|
@ -103,21 +103,31 @@ class AppState:
|
|||
|
||||
def auth_status(self) -> dict[str, Any]:
|
||||
creds = self.oauth.load_credentials()
|
||||
providers = self.providers.statuses()
|
||||
available_providers = [
|
||||
item["name"]
|
||||
for item in providers
|
||||
if item.get("available")
|
||||
]
|
||||
if not creds:
|
||||
return {
|
||||
"authenticated": False,
|
||||
"ready": bool(available_providers),
|
||||
"available_providers": available_providers,
|
||||
"default_provider": self.config.default_provider,
|
||||
"fallback_providers": self.config.fallback_providers,
|
||||
"providers": self.providers.statuses(),
|
||||
"providers": providers,
|
||||
"tool_policy": self.config.tool_policy,
|
||||
"pending_flows": len(self.pending_device_flows),
|
||||
"pending_approvals": len(self.approvals.list_pending()),
|
||||
}
|
||||
return {
|
||||
"authenticated": True,
|
||||
"ready": bool(available_providers),
|
||||
"available_providers": available_providers,
|
||||
"default_provider": self.config.default_provider,
|
||||
"fallback_providers": self.config.fallback_providers,
|
||||
"providers": self.providers.statuses(),
|
||||
"providers": providers,
|
||||
"resource_url": creds.get("resource_url"),
|
||||
"expires_at": creds.get("expiry_date"),
|
||||
"tool_policy": self.config.tool_policy,
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ class ServerConfig:
|
|||
default_provider: str
|
||||
fallback_providers: list[str]
|
||||
gigachat_model: str
|
||||
gigachat_auth_key: str
|
||||
gigachat_scope: str
|
||||
gigachat_api_base_url: str
|
||||
gigachat_oauth_url: str
|
||||
yandexgpt_model: str
|
||||
workspace_root: Path
|
||||
session_dir: Path
|
||||
|
|
@ -70,6 +74,16 @@ class ServerConfig:
|
|||
if item.strip()
|
||||
],
|
||||
gigachat_model=os.environ.get("NEW_QWEN_GIGACHAT_MODEL", "GigaChat").strip(),
|
||||
gigachat_auth_key=os.environ.get("NEW_QWEN_GIGACHAT_AUTH_KEY", "").strip(),
|
||||
gigachat_scope=os.environ.get("NEW_QWEN_GIGACHAT_SCOPE", "GIGACHAT_API_PERS").strip(),
|
||||
gigachat_api_base_url=os.environ.get(
|
||||
"NEW_QWEN_GIGACHAT_API_BASE_URL",
|
||||
"https://gigachat.devices.sberbank.ru/api/v1",
|
||||
).strip(),
|
||||
gigachat_oauth_url=os.environ.get(
|
||||
"NEW_QWEN_GIGACHAT_OAUTH_URL",
|
||||
"https://ngw.devices.sberbank.ru:9443/api/v2/oauth",
|
||||
).strip(),
|
||||
yandexgpt_model=os.environ.get("NEW_QWEN_YANDEXGPT_MODEL", "yandexgpt").strip(),
|
||||
workspace_root=workspace_root.resolve(),
|
||||
session_dir=session_dir.resolve(),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from urllib import error, parse, request
|
||||
|
||||
from config import ServerConfig
|
||||
|
||||
|
||||
class GigaChatError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class GigaChatAuthManager:
|
||||
def __init__(self, config: ServerConfig) -> None:
|
||||
self.config = config
|
||||
self.token_path = config.state_dir / "gigachat_token.json"
|
||||
self.token_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self.config.gigachat_auth_key)
|
||||
|
||||
def _authorization_header(self) -> str:
|
||||
raw = self.config.gigachat_auth_key.strip()
|
||||
if not raw:
|
||||
raise GigaChatError("GigaChat auth key is not configured")
|
||||
if raw.lower().startswith("basic "):
|
||||
return raw
|
||||
return f"Basic {raw}"
|
||||
|
||||
def load_token(self) -> dict[str, Any] | None:
|
||||
if not self.token_path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(self.token_path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
def save_token(self, payload: dict[str, Any]) -> None:
|
||||
self.token_path.write_text(
|
||||
json.dumps(payload, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def fetch_token(self) -> dict[str, Any]:
|
||||
data = parse.urlencode({"scope": self.config.gigachat_scope}).encode("utf-8")
|
||||
req = request.Request(
|
||||
self.config.gigachat_oauth_url,
|
||||
data=data,
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
"RqUID": str(uuid.uuid4()),
|
||||
"Authorization": self._authorization_header(),
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with request.urlopen(req, timeout=60) as response:
|
||||
payload = json.loads(response.read().decode("utf-8"))
|
||||
except error.HTTPError as exc:
|
||||
body = exc.read().decode("utf-8", errors="replace")
|
||||
raise GigaChatError(f"GigaChat token request failed with HTTP {exc.code}: {body}") from exc
|
||||
token = {
|
||||
"access_token": payload["access_token"],
|
||||
"expires_at": int(payload["expires_at"]),
|
||||
}
|
||||
self.save_token(token)
|
||||
return token
|
||||
|
||||
def get_valid_token(self) -> str:
|
||||
if not self.is_configured():
|
||||
raise GigaChatError("GigaChat auth key is not configured")
|
||||
token = self.load_token()
|
||||
now = int(time.time())
|
||||
if token and int(token.get("expires_at", 0)) - now > 30:
|
||||
return str(token["access_token"])
|
||||
refreshed = self.fetch_token()
|
||||
return str(refreshed["access_token"])
|
||||
|
|
@ -94,6 +94,7 @@ class QwenAgent:
|
|||
"role": "assistant",
|
||||
"content": content or "",
|
||||
"tool_calls": tool_calls,
|
||||
**({"functions_state_id": choice["functions_state_id"]} if choice.get("functions_state_id") else {}),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -129,6 +130,7 @@ class QwenAgent:
|
|||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": call["id"],
|
||||
"name": tool_name,
|
||||
"content": self.tools.encode_result(result),
|
||||
}
|
||||
)
|
||||
|
|
@ -148,6 +150,7 @@ class QwenAgent:
|
|||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": call["id"],
|
||||
"name": tool_name,
|
||||
"content": self.tools.encode_result(result),
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ from __future__ import annotations
|
|||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
import uuid
|
||||
from urllib import error, request
|
||||
|
||||
from config import ServerConfig
|
||||
from gigachat import GigaChatAuthManager, GigaChatError
|
||||
from oauth import OAuthError, QwenOAuthManager
|
||||
|
||||
|
||||
|
|
@ -143,6 +145,148 @@ class QwenModelProvider(BaseModelProvider):
|
|||
raise ModelProviderError(str(exc)) from exc
|
||||
|
||||
|
||||
class GigaChatModelProvider(BaseModelProvider):
|
||||
name = "gigachat"
|
||||
capabilities = ProviderCapabilities(
|
||||
tool_calling=True,
|
||||
web_search=False,
|
||||
oauth_auth=False,
|
||||
)
|
||||
|
||||
def __init__(self, config: ServerConfig, auth: GigaChatAuthManager) -> None:
|
||||
super().__init__(config)
|
||||
self.auth = auth
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.auth.is_configured()
|
||||
|
||||
def unavailable_reason(self) -> str | None:
|
||||
if self.is_available():
|
||||
return None
|
||||
return "GigaChat auth key is not configured"
|
||||
|
||||
def model_name(self) -> str:
|
||||
return self.config.gigachat_model
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_schema(tool: dict[str, Any]) -> dict[str, Any]:
|
||||
return dict(tool.get("function") or {})
|
||||
|
||||
def _convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
converted: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
if role in {"system", "user"}:
|
||||
converted.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": message.get("content", ""),
|
||||
}
|
||||
)
|
||||
continue
|
||||
if role == "assistant":
|
||||
payload: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": message.get("content", ""),
|
||||
}
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if tool_calls:
|
||||
first_call = tool_calls[0]
|
||||
raw_arguments = first_call.get("function", {}).get("arguments", "{}")
|
||||
if isinstance(raw_arguments, str):
|
||||
try:
|
||||
arguments = json.loads(raw_arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {"raw": raw_arguments}
|
||||
else:
|
||||
arguments = raw_arguments
|
||||
payload["function_call"] = {
|
||||
"name": first_call.get("function", {}).get("name"),
|
||||
"arguments": arguments,
|
||||
}
|
||||
if message.get("functions_state_id"):
|
||||
payload["functions_state_id"] = message["functions_state_id"]
|
||||
converted.append(payload)
|
||||
continue
|
||||
if role == "tool":
|
||||
converted.append(
|
||||
{
|
||||
"role": "function",
|
||||
"name": message.get("name") or "tool_result",
|
||||
"content": message.get("content", ""),
|
||||
}
|
||||
)
|
||||
return converted
|
||||
|
||||
def _normalize_response(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
choices = payload.get("choices") or []
|
||||
if not choices:
|
||||
return payload
|
||||
choice = choices[0]
|
||||
message = dict(choice.get("message") or {})
|
||||
normalized_message: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": message.get("content", "") or "",
|
||||
}
|
||||
function_call = message.get("function_call")
|
||||
if function_call:
|
||||
arguments = function_call.get("arguments", {})
|
||||
if not isinstance(arguments, str):
|
||||
arguments = json.dumps(arguments, ensure_ascii=False)
|
||||
normalized_message["tool_calls"] = [
|
||||
{
|
||||
"id": uuid.uuid4().hex,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_call.get("name"),
|
||||
"arguments": arguments,
|
||||
},
|
||||
}
|
||||
]
|
||||
if message.get("functions_state_id"):
|
||||
normalized_message["functions_state_id"] = message.get("functions_state_id")
|
||||
choice["message"] = normalized_message
|
||||
payload["choices"] = choices
|
||||
return payload
|
||||
|
||||
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
||||
try:
|
||||
access_token = self.auth.get_valid_token()
|
||||
except GigaChatError as exc:
|
||||
raise ModelProviderError(str(exc)) from exc
|
||||
api_base = self.config.gigachat_api_base_url.rstrip("/")
|
||||
payload: dict[str, Any] = {
|
||||
"model": self.model_name(),
|
||||
"messages": self._convert_messages(completion_request.messages),
|
||||
}
|
||||
if completion_request.tools:
|
||||
payload["functions"] = [
|
||||
self._convert_tool_schema(tool)
|
||||
for tool in completion_request.tools
|
||||
]
|
||||
payload["function_call"] = "auto"
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
req = request.Request(
|
||||
f"{api_base}/chat/completions",
|
||||
data=data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with request.urlopen(req, timeout=180) as response:
|
||||
raw = json.loads(response.read().decode("utf-8"))
|
||||
except error.HTTPError as exc:
|
||||
body = exc.read().decode("utf-8", errors="replace")
|
||||
raise ModelProviderError(
|
||||
f"Provider {self.name} request failed with HTTP {exc.code}: {body}"
|
||||
) from exc
|
||||
return self._normalize_response(raw)
|
||||
|
||||
|
||||
class ProviderRegistry:
|
||||
def __init__(self, providers: list[BaseModelProvider]) -> None:
|
||||
self._providers = {provider.name: provider for provider in providers}
|
||||
|
|
@ -236,13 +380,7 @@ class ModelRouter:
|
|||
def build_provider_registry(config: ServerConfig, oauth: QwenOAuthManager) -> ProviderRegistry:
|
||||
providers: list[BaseModelProvider] = [
|
||||
QwenModelProvider(config, oauth),
|
||||
UnavailableModelProvider(
|
||||
config,
|
||||
name="gigachat",
|
||||
model_name=config.gigachat_model,
|
||||
reason="GigaChat provider is not implemented yet",
|
||||
capabilities=ProviderCapabilities(),
|
||||
),
|
||||
GigaChatModelProvider(config, GigaChatAuthManager(config)),
|
||||
UnavailableModelProvider(
|
||||
config,
|
||||
name="yandexgpt",
|
||||
|
|
|
|||
Loading…
Reference in New Issue