diff --git a/bot/app.py b/bot/app.py index 392b66d..97561d3 100644 --- a/bot/app.py +++ b/bot/app.py @@ -14,6 +14,16 @@ from telegram_api import TelegramAPI STATE_FILE = Path(__file__).resolve().parent.parent / ".new-qwen" / "telegram-state.json" TYPING_INTERVAL_SECONDS = 4.0 +BOT_COMMANDS = [ + {"command": "help", "description": "Список команд"}, + {"command": "status", "description": "Статус сервера и чата"}, + {"command": "auth", "description": "Запустить Qwen OAuth"}, + {"command": "provider", "description": "Выбрать провайдера"}, + {"command": "model", "description": "Выбрать модель"}, + {"command": "session", "description": "Показать сессию"}, + {"command": "clear", "description": "Очистить контекст"}, + {"command": "cancel", "description": "Отменить активный job"}, +] def load_state() -> dict[str, Any]: @@ -26,6 +36,7 @@ def load_state() -> dict[str, Any]: "chat_active_jobs": {}, "chat_queues": {}, "pending_approvals": {}, + "chat_preferences": {}, } state = json.loads(STATE_FILE.read_text(encoding="utf-8")) state.setdefault("sessions", {}) @@ -34,6 +45,7 @@ def load_state() -> dict[str, Any]: state.setdefault("chat_active_jobs", {}) state.setdefault("chat_queues", {}) state.setdefault("pending_approvals", {}) + state.setdefault("chat_preferences", {}) return state @@ -66,6 +78,49 @@ def send_text_chunks(api: TelegramAPI, chat_id: int, text: str) -> None: api.send_message(chat_id, normalized[start : start + chunk_size]) +def ensure_bot_commands(api: TelegramAPI, state: dict[str, Any]) -> None: + api.set_my_commands(BOT_COMMANDS) + + +def get_provider_catalog(config: BotConfig) -> dict[str, Any]: + return get_json(f"{config.server_url}/api/v1/provider-catalog") + + +def get_chat_preferences(state: dict[str, Any], chat_id: int) -> dict[str, Any]: + prefs = state.setdefault("chat_preferences", {}).setdefault(str(chat_id), {}) + prefs.setdefault("provider", None) + prefs.setdefault("model", None) + return prefs + + +def get_selected_provider_and_model( + state: dict[str, Any], + chat_id: int, + catalog: dict[str, Any], +) -> tuple[str | None, str | None]: + prefs = get_chat_preferences(state, chat_id) + provider = prefs.get("provider") or catalog.get("default_provider") + model = prefs.get("model") + providers = {item.get("name"): item for item in catalog.get("providers", [])} + provider_info = providers.get(provider or "") + available_models = provider_info.get("models", []) if provider_info else [] + if not model and available_models: + model = available_models[0].get("id") + return provider, model + + +def format_status_text(job_state: dict[str, Any], text: str) -> str: + del job_state + return (text or "Пустой ответ.")[:4000] + + +def format_final_signature(job_state: dict[str, Any]) -> str: + model = job_state.get("model") if isinstance(job_state.get("model"), str) else None + if not model: + return "" + return f"\n\n{html.escape(model)}" + + def render_markdownish_html(text: str) -> str: normalized = (text or "Пустой ответ.").replace("\r\n", "\n") fence_pattern = re.compile(r"```(?:[^\n`]*)\n(.*?)```", re.DOTALL) @@ -91,7 +146,7 @@ def update_status_message( job_state: dict[str, Any], text: str, ) -> None: - normalized = (text or "Пустой ответ.")[:4000] + normalized = format_status_text(job_state, text) chat_id = int(job_state["chat_id"]) message_id = int(job_state.get("status_message_id") or 0) if message_id: @@ -100,6 +155,7 @@ def update_status_message( message_id = api.send_message(chat_id, normalized) job_state["status_message_id"] = message_id job_state["status_message_text"] = normalized + job_state["status_body_text"] = text or "Пустой ответ." def set_final_message( @@ -109,8 +165,9 @@ def set_final_message( ) -> None: normalized = text or "Пустой ответ." chunk_size = 3800 + signature = format_final_signature(job_state) first_chunk = normalized[:chunk_size] - rendered = render_markdownish_html(first_chunk) + rendered = render_markdownish_html(first_chunk) + signature chat_id = int(job_state["chat_id"]) message_id = int(job_state.get("status_message_id") or 0) if message_id: @@ -145,30 +202,38 @@ def keep_job_typing( def summarize_event(event: dict[str, Any]) -> str | None: event_type = event.get("type") if event_type == "job_status": - return event.get("message") + message = str(event.get("message") or "").strip() + if message == "Запрос принят сервером": + return "Смотрю, что можно сделать." + if message == "Ответ готов": + return "Формулирую ответ." + return message or None if event_type == "model_request": - provider = event.get("provider") - model = event.get("model") - if provider and model: - return f"Думаю над ответом через {provider}/{model}" - if provider: - return f"Думаю над ответом через {provider}" - return "Думаю над ответом" + return "Думаю над ответом." if event_type == "tool_call": - return f"Вызываю инструмент: {event.get('name')}" + tool_name = str(event.get("name") or "") + if tool_name in {"read_file", "list_files", "glob_search", "grep_text", "stat_path"}: + return "Просматриваю файлы и контекст." + if tool_name in {"git_status", "git_diff"}: + return "Проверяю изменения в проекте." + if tool_name in {"exec_command"}: + return "Проверяю окружение и команды." + if tool_name in {"web_search"}: + return "Ищу нужную информацию." + return "Проверяю детали." if event_type == "tool_result": result = event.get("result", {}) if isinstance(result, dict) and "error" in result: - return f"Инструмент {event.get('name')} завершился с ошибкой" - return f"Инструмент {event.get('name')} завершён" + return "Наткнулся на проблему, перепроверяю." + return "Собрал нужные данные." if event_type == "error": return f"Ошибка: {event.get('message')}" if event_type == "approval_result": status = event.get("status") tool_name = event.get("tool_name") if status == "approved": - return f"Подтверждение получено для {tool_name}" - return f"Подтверждение отклонено для {tool_name}" + return f"Получил подтверждение для {tool_name}." + return f"Подтверждение для {tool_name} отклонено." return None @@ -183,6 +248,83 @@ def format_approval_keyboard(approval_id: str) -> dict[str, Any]: } +def format_provider_keyboard(catalog: dict[str, Any]) -> dict[str, Any]: + rows: list[list[dict[str, str]]] = [] + row: list[dict[str, str]] = [] + for item in catalog.get("providers", []): + name = str(item.get("name") or "") + if not name: + continue + row.append({"text": name, "callback_data": f"provider:{name}"}) + if len(row) == 2: + rows.append(row) + row = [] + if row: + rows.append(row) + return {"inline_keyboard": rows} + + +def format_model_keyboard(provider: str, models: list[dict[str, str]]) -> dict[str, Any]: + rows = [ + [{"text": str(item.get("name") or item.get("id") or "-"), "callback_data": f"model:{provider}:{item.get('id')}"}] + for item in models + if item.get("id") + ] + return {"inline_keyboard": rows} + + +def set_chat_provider( + state: dict[str, Any], + chat_id: int, + provider: str | None, + *, + reset_model: bool = True, +) -> dict[str, Any]: + prefs = get_chat_preferences(state, chat_id) + prefs["provider"] = provider + if reset_model: + prefs["model"] = None + return prefs + + +def set_chat_model(state: dict[str, Any], chat_id: int, model: str | None) -> dict[str, Any]: + prefs = get_chat_preferences(state, chat_id) + prefs["model"] = model + return prefs + + +def format_provider_selection( + catalog: dict[str, Any], + selected_provider: str | None, +) -> str: + lines = ["Выбери провайдера."] + for item in catalog.get("providers", []): + name = str(item.get("name") or "") + if not name: + continue + marker = "•" if name == selected_provider else " " + availability = "ready" if item.get("available") else f"unavailable: {item.get('reason') or 'unknown'}" + lines.append(f"{marker} {name} ({availability})") + return "\n".join(lines) + + +def format_model_selection( + provider: str, + models: list[dict[str, str]], + selected_model: str | None, +) -> str: + lines = [f"Выбери модель для {provider}."] + for item in models: + model_id = str(item.get("id") or "") + if not model_id: + continue + marker = "•" if model_id == selected_model else " " + description = str(item.get("description") or "").strip() + suffix = f" ({description})" if description else "" + lines.append(f"{marker} {model_id}{suffix}") + return "\n".join(lines) + + def format_approval_request(event: dict[str, Any]) -> str: return ( "Нужно подтверждение инструмента.\n" @@ -290,7 +432,9 @@ def start_chat_job( delayed: bool = False, ) -> None: session_id = state.setdefault("sessions", {}).get(session_key) - prefix = "Обрабатываю отложенный запрос..." if delayed else "Обрабатываю запрос..." + catalog = get_provider_catalog(config) + provider, model = get_selected_provider_and_model(state, chat_id, catalog) + prefix = "Возвращаюсь к вашему сообщению." if delayed else "Смотрю, что можно сделать." status_message_id = api.send_message(chat_id, prefix) start_result = post_json( f"{config.server_url}/api/v1/chat/start", @@ -298,6 +442,8 @@ def start_chat_job( "session_id": session_id, "user_id": user_id, "message": text, + "provider": provider, + "model": model, }, ) state["sessions"][session_key] = start_result["session_id"] @@ -310,6 +456,12 @@ def start_chat_job( "seen_seq": 0, "status_message_id": status_message_id, "status_message_text": prefix, + "status_body_text": prefix, + "requested_provider": provider, + "requested_model": model, + "provider": start_result.get("provider") or provider, + "model": start_result.get("model") or model, + "fallback_notified": False, "last_typing_at": 0.0, } state.setdefault("chat_active_jobs", {})[str(chat_id)] = start_result["job_id"] @@ -501,8 +653,25 @@ def process_active_jobs( reply_markup=format_approval_keyboard(str(event["approval_id"])), ) continue + if event.get("type") == "model_request": + selection_reason = str(event.get("selection_reason") or "") + if ( + "fell back to" in selection_reason + and not job_state.get("fallback_notified") + ): + requested_provider = job_state.get("requested_provider") or "unknown" + api.send_message( + int(job_state["chat_id"]), + "Выбранный провайдер не ответил, поэтому продолжаю через запасной вариант.\n" + f"Изначально был выбран: {requested_provider}", + ) + job_state["fallback_notified"] = True + if event.get("provider"): + job_state["provider"] = event.get("provider") + if event.get("model"): + job_state["model"] = event.get("model") summary = summarize_event(event) - if summary and summary != job_state.get("status_message_text"): + if summary and summary != job_state.get("status_body_text"): update_status_message(api, job_state, summary) status = poll_result.get("status") @@ -590,7 +759,7 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m if text == "/start": api.send_message( chat_id, - "new-qwen bot готов.\nКоманды: /help, /auth, /status, /session, /cancel, /clear, /approve, /reject.", + "new-qwen bot готов.\nКоманды доступны через Telegram menu: /help, /auth, /status, /provider, /model, /session, /cancel, /clear.", ) return @@ -601,6 +770,8 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m "/auth - начать Qwen OAuth\n" "/auth_check [flow_id] - проверить авторизацию\n" "/status - статус OAuth и сервера\n" + "/provider [name] - выбрать провайдера\n" + "/model [id] - выбрать модель\n" "/session - показать текущую сессию\n" "/cancel - отменить активный запрос и очистить очередь\n" "/approve [approval_id] - подтвердить инструмент\n" @@ -652,10 +823,16 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m if text == "/status": status = get_json(f"{config.server_url}/api/v1/auth/status") - models_status = get_json(f"{config.server_url}/api/v1/models") + catalog = get_provider_catalog(config) queue_size = len(state.setdefault("chat_queues", {}).get(str(chat_id), [])) active_job = state.setdefault("chat_active_jobs", {}).get(str(chat_id)) - models_info = "\n".join([f" - {m['id']}: {m['name']} ({m['description']})" for m in models_status.get("models", [])]) or " Нет доступных моделей" + provider, model = get_selected_provider_and_model(state, chat_id, catalog) + providers_info = "\n".join( + [ + f" - {item['name']}: model={item.get('model')}, available={item.get('available')}" + for item in catalog.get("providers", []) + ] + ) or " Нет доступных провайдеров" send_text_chunks( api, chat_id, @@ -665,15 +842,61 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m f"available_providers: {', '.join(status.get('available_providers') or []) or '-'}\n" f"default_provider: {status.get('default_provider')}\n" f"fallback_providers: {', '.join(status.get('fallback_providers') or []) or '-'}\n" + f"selected_provider: {provider}\n" + f"selected_model: {model}\n" f"resource_url: {status.get('resource_url')}\n" f"expires_at: {status.get('expires_at')}\n" f"tool_policy: {status.get('tool_policy')}\n" - f"Доступные Qwen OAuth модели:\n{models_info}\n" + f"Провайдеры и модели:\n{providers_info}\n" f"active_job: {active_job}\n" f"queued_messages: {queue_size}", ) return + if text.startswith("/provider"): + catalog = get_provider_catalog(config) + parts = text.split(maxsplit=1) + if len(parts) == 1: + provider, _ = get_selected_provider_and_model(state, chat_id, catalog) + api.send_message( + chat_id, + format_provider_selection(catalog, provider), + reply_markup=format_provider_keyboard(catalog), + ) + return + provider_name = parts[1].strip() + provider_names = {str(item.get("name") or "") for item in catalog.get("providers", [])} + if provider_name not in provider_names: + api.send_message(chat_id, f"Неизвестный provider: {provider_name}") + return + set_chat_provider(state, chat_id, provider_name, reset_model=True) + _, model = get_selected_provider_and_model(state, chat_id, catalog) + api.send_message(chat_id, f"Выбран provider: {provider_name}\nmodel: {model or '-'}") + return + + if text.startswith("/model"): + catalog = get_provider_catalog(config) + provider, selected_model = get_selected_provider_and_model(state, chat_id, catalog) + provider_map = {str(item.get("name") or ""): item for item in catalog.get("providers", [])} + provider_info = provider_map.get(provider or "") + models = provider_info.get("models", []) if provider_info else [] + parts = text.split(maxsplit=1) + if len(parts) == 1: + api.send_message( + chat_id, + format_model_selection(provider or "-", models, selected_model), + reply_markup=format_model_keyboard(provider or "-", models), + ) + return + model_id = parts[1].strip() + valid_model_ids = {str(item.get("id") or "") for item in models} + if model_id not in valid_model_ids: + api.send_message(chat_id, f"Неизвестная модель для {provider}: {model_id}") + return + set_chat_model(state, chat_id, model_id) + api.send_message(chat_id, f"Выбрана модель: {model_id}\nprovider: {provider}") + return + if text == "/cancel": canceled = cancel_chat_work( api, @@ -750,6 +973,44 @@ def handle_callback_query( action, approval_id = data.split(":", 1) if action not in {"approve", "reject"}: + if action == "provider": + catalog = get_provider_catalog(config) + provider_names = {str(item.get("name") or "") for item in catalog.get("providers", [])} + if approval_id not in provider_names: + api.answer_callback_query(callback_query_id, "Неизвестный провайдер") + return + set_chat_provider(state, chat_id, approval_id, reset_model=True) + provider, model = get_selected_provider_and_model(state, chat_id, catalog) + api.edit_message_text( + chat_id, + message_id, + f"Выбран provider: {provider}\nmodel: {model or '-'}", + reply_markup={"inline_keyboard": []}, + ) + api.answer_callback_query(callback_query_id, "Провайдер переключен") + return + if action == "model": + provider, model_id = approval_id.split(":", 1) if ":" in approval_id else ("", "") + catalog = get_provider_catalog(config) + provider_map = {str(item.get("name") or ""): item for item in catalog.get("providers", [])} + provider_info = provider_map.get(provider) + model_ids = { + str(item.get("id") or "") + for item in (provider_info.get("models", []) if provider_info else []) + } + if model_id not in model_ids: + api.answer_callback_query(callback_query_id, "Неизвестная модель") + return + set_chat_provider(state, chat_id, provider, reset_model=False) + set_chat_model(state, chat_id, model_id) + api.edit_message_text( + chat_id, + message_id, + f"Выбрана модель: {model_id}\nprovider: {provider}", + reply_markup={"inline_keyboard": []}, + ) + api.answer_callback_query(callback_query_id, "Модель переключена") + return api.answer_callback_query(callback_query_id, "Неизвестное действие") return @@ -782,6 +1043,8 @@ def main() -> None: config = BotConfig.load() api = TelegramAPI(config.token, proxy_url=config.proxy_url) state = load_state() + ensure_bot_commands(api, state) + save_state(state) print("new-qwen bot polling started") while True: try: diff --git a/bot/telegram_api.py b/bot/telegram_api.py index 28f3b83..2f9151a 100644 --- a/bot/telegram_api.py +++ b/bot/telegram_api.py @@ -119,3 +119,6 @@ class TelegramAPI: if text: payload["text"] = text self._post("answerCallbackQuery", payload) + + def set_my_commands(self, commands: list[dict[str, str]]) -> None: + self._post("setMyCommands", {"commands": commands}) diff --git a/serv/app.py b/serv/app.py index 50015a6..492699b 100644 --- a/serv/app.py +++ b/serv/app.py @@ -139,6 +139,13 @@ class AppState: "pending_approvals": len(self.approvals.list_pending()), } + def provider_catalog(self) -> dict[str, Any]: + return { + "default_provider": self.config.default_provider, + "default_model": self.config.model, + "providers": self.providers.statuses(), + } + class RequestHandler(BaseHTTPRequestHandler): server_version = "new-qwen-serv/0.1" @@ -180,6 +187,9 @@ class RequestHandler(BaseHTTPRequestHandler): models = self.app.oauth.get_available_models() self._send(HTTPStatus.OK, {"models": models}) return + if self.path == "/api/v1/provider-catalog": + self._send(HTTPStatus.OK, self.app.provider_catalog()) + return self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"}) def _run_chat_job( @@ -189,6 +199,7 @@ class RequestHandler(BaseHTTPRequestHandler): user_id: str, message: str, preferred_provider: str | None = None, + preferred_model: str | None = None, ) -> None: try: if self.app.jobs.is_cancel_requested(job_id): @@ -217,6 +228,7 @@ class RequestHandler(BaseHTTPRequestHandler): ), is_cancelled=lambda: self.app.jobs.is_cancel_requested(job_id), preferred_provider=preferred_provider, + preferred_model=preferred_model, ) if self.app.jobs.is_cancel_requested(job_id): reason = "Job canceled by operator" @@ -338,12 +350,14 @@ class RequestHandler(BaseHTTPRequestHandler): user_id = str(body.get("user_id") or "anonymous") message = body["message"] preferred_provider = body.get("provider") + preferred_model = body.get("model") session = self.app.sessions.load(session_id) history = session.get("messages", []) result = self.app.agent.run( history, message, preferred_provider=preferred_provider, + preferred_model=preferred_model, ) persisted_messages = result["messages"][1:] self.app.sessions.save( @@ -375,10 +389,11 @@ class RequestHandler(BaseHTTPRequestHandler): user_id = str(body.get("user_id") or "anonymous") message = body["message"] preferred_provider = body.get("provider") + preferred_model = body.get("model") job = self.app.jobs.create(session_id, user_id, message) thread = threading.Thread( target=self._run_chat_job, - args=(job["job_id"], session_id, user_id, message, preferred_provider), + args=(job["job_id"], session_id, user_id, message, preferred_provider, preferred_model), daemon=True, ) thread.start() @@ -389,6 +404,7 @@ class RequestHandler(BaseHTTPRequestHandler): "session_id": session_id, "status": "queued", "provider": preferred_provider or self.app.config.default_provider, + "model": preferred_model or self.app.config.model, }, ) return diff --git a/serv/gigachat.py b/serv/gigachat.py index 76acad1..5aae6ec 100644 --- a/serv/gigachat.py +++ b/serv/gigachat.py @@ -55,6 +55,14 @@ class GigaChatAuthManager: encoding="utf-8", ) + @staticmethod + def _normalize_expires_at(value: Any) -> int: + expires_at = int(value or 0) + # Sber can return unix time in milliseconds. + if expires_at > 10_000_000_000: + return expires_at // 1000 + return expires_at + def fetch_token(self) -> dict[str, Any]: data = parse.urlencode({"scope": self.config.gigachat_scope}).encode("utf-8") req = request.Request( @@ -76,7 +84,7 @@ class GigaChatAuthManager: raise GigaChatError(f"GigaChat token request failed with HTTP {exc.code}: {body}") from exc token = { "access_token": payload["access_token"], - "expires_at": int(payload["expires_at"]), + "expires_at": self._normalize_expires_at(payload.get("expires_at")), } self.save_token(token) return token @@ -86,7 +94,8 @@ class GigaChatAuthManager: raise GigaChatError("GigaChat auth key is not configured") token = self.load_token() now = int(time.time()) - if token and int(token.get("expires_at", 0)) - now > 30: + expires_at = self._normalize_expires_at(token.get("expires_at", 0)) if token else 0 + if token and expires_at - now > 30: return str(token["access_token"]) refreshed = self.fetch_token() return str(refreshed["access_token"]) diff --git a/serv/llm.py b/serv/llm.py index 439ff53..5502249 100644 --- a/serv/llm.py +++ b/serv/llm.py @@ -30,6 +30,7 @@ class QwenAgent: approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None, is_cancelled: Callable[[], bool] | None = None, preferred_provider: str | None = None, + preferred_model: str | None = None, ) -> dict[str, Any]: emit = on_event or (lambda _event: None) cancel_check = is_cancelled or (lambda: False) @@ -54,6 +55,7 @@ class QwenAgent: tools=self.tools.schemas(), tool_choice="auto", preferred_provider=preferred_provider, + preferred_model=preferred_model, require_tools=True, ) ) diff --git a/serv/model_router.py b/serv/model_router.py index c57710a..6dceabc 100644 --- a/serv/model_router.py +++ b/serv/model_router.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from dataclasses import replace from dataclasses import dataclass, field from typing import Any import uuid @@ -28,6 +29,7 @@ class CompletionRequest: tools: list[dict[str, Any]] tool_choice: str = "auto" preferred_provider: str | None = None + preferred_model: str | None = None require_tools: bool = True @@ -53,9 +55,14 @@ class BaseModelProvider: def unavailable_reason(self) -> str | None: raise NotImplementedError - def model_name(self) -> str: + def model_name(self, preferred_model: str | None = None) -> str: + del preferred_model raise NotImplementedError + def list_models(self) -> list[dict[str, str]]: + model = self.model_name() + return [{"id": model, "name": model, "description": ""}] + def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: raise NotImplementedError @@ -83,7 +90,8 @@ class UnavailableModelProvider(BaseModelProvider): def unavailable_reason(self) -> str | None: return self._reason - def model_name(self) -> str: + def model_name(self, preferred_model: str | None = None) -> str: + del preferred_model return self._model_name def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: @@ -114,8 +122,15 @@ class QwenModelProvider(BaseModelProvider): return None return "Qwen OAuth is not configured" - def model_name(self) -> str: - return self._model_name + def model_name(self, preferred_model: str | None = None) -> str: + if preferred_model and preferred_model not in QWEN_OAUTH_ALLOWED_MODELS: + raise ModelProviderError( + f"Provider {self.name} does not support model '{preferred_model}'" + ) + return self.oauth.get_model_name_for_id(preferred_model or self._model_id) + + def list_models(self) -> list[dict[str, str]]: + return self.oauth.get_available_models() @staticmethod def _normalize_content(value: Any) -> list[dict[str, str]]: @@ -140,8 +155,9 @@ class QwenModelProvider(BaseModelProvider): def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: creds = self.oauth.get_valid_credentials() base_url = self.oauth.get_openai_base_url(creds) + model_name = self.model_name(completion_request.preferred_model) payload = { - "model": self.model_name(), + "model": model_name, "messages": self._normalize_messages(completion_request.messages), "max_tokens": 8000, "metadata": { @@ -203,8 +219,18 @@ class GigaChatModelProvider(BaseModelProvider): return None return "GigaChat auth key is not configured" - def model_name(self) -> str: - return self.config.gigachat_model + def model_name(self, preferred_model: str | None = None) -> str: + return (preferred_model or self.config.gigachat_model).strip() or self.config.gigachat_model + + def list_models(self) -> list[dict[str, str]]: + model = self.config.gigachat_model + return [ + { + "id": model, + "name": model, + "description": "Configured GigaChat model", + } + ] @staticmethod def _convert_tool_schema(tool: dict[str, Any]) -> dict[str, Any]: @@ -294,7 +320,7 @@ class GigaChatModelProvider(BaseModelProvider): raise ModelProviderError(str(exc)) from exc api_base = self.config.gigachat_api_base_url.rstrip("/") payload: dict[str, Any] = { - "model": self.model_name(), + "model": self.model_name(completion_request.preferred_model), "messages": self._convert_messages(completion_request.messages), } if completion_request.tools: @@ -340,6 +366,7 @@ class ProviderRegistry: { "name": provider.name, "model": provider.model_name(), + "models": provider.list_models(), "available": provider.is_available(), "reason": provider.unavailable_reason(), "capabilities": { @@ -400,13 +427,16 @@ class ModelRouter: elif name == self.config.default_provider: selection_reason = "selected default provider" try: - payload = provider.complete(completion_request) + provider_request = completion_request + if completion_request.preferred_provider and name != completion_request.preferred_provider: + provider_request = replace(completion_request, preferred_model=None) + payload = provider.complete(provider_request) except ModelProviderError as exc: reasons.append(f"{name}: request failed: {exc}") continue return CompletionResponse( provider_name=name, - model_name=provider.model_name(), + model_name=provider.model_name(provider_request.preferred_model), payload=payload, selection_reason=selection_reason, attempted=attempted,