Improve provider selection UX and fix GigaChat token refresh

This commit is contained in:
mirivlad 2026-04-09 04:14:31 +08:00
parent b8c55a8194
commit 99d0be151f
6 changed files with 357 additions and 34 deletions

View File

@ -14,6 +14,16 @@ from telegram_api import TelegramAPI
STATE_FILE = Path(__file__).resolve().parent.parent / ".new-qwen" / "telegram-state.json" STATE_FILE = Path(__file__).resolve().parent.parent / ".new-qwen" / "telegram-state.json"
TYPING_INTERVAL_SECONDS = 4.0 TYPING_INTERVAL_SECONDS = 4.0
# Command menu registered with Telegram via setMyCommands (see ensure_bot_commands).
# The descriptions are user-facing and intentionally in Russian.
BOT_COMMANDS = [
    {"command": "help", "description": "Список команд"},
    {"command": "status", "description": "Статус сервера и чата"},
    {"command": "auth", "description": "Запустить Qwen OAuth"},
    {"command": "provider", "description": "Выбрать провайдера"},
    {"command": "model", "description": "Выбрать модель"},
    {"command": "session", "description": "Показать сессию"},
    {"command": "clear", "description": "Очистить контекст"},
    {"command": "cancel", "description": "Отменить активный job"},
]
def load_state() -> dict[str, Any]: def load_state() -> dict[str, Any]:
@ -26,6 +36,7 @@ def load_state() -> dict[str, Any]:
"chat_active_jobs": {}, "chat_active_jobs": {},
"chat_queues": {}, "chat_queues": {},
"pending_approvals": {}, "pending_approvals": {},
"chat_preferences": {},
} }
state = json.loads(STATE_FILE.read_text(encoding="utf-8")) state = json.loads(STATE_FILE.read_text(encoding="utf-8"))
state.setdefault("sessions", {}) state.setdefault("sessions", {})
@ -34,6 +45,7 @@ def load_state() -> dict[str, Any]:
state.setdefault("chat_active_jobs", {}) state.setdefault("chat_active_jobs", {})
state.setdefault("chat_queues", {}) state.setdefault("chat_queues", {})
state.setdefault("pending_approvals", {}) state.setdefault("pending_approvals", {})
state.setdefault("chat_preferences", {})
return state return state
@ -66,6 +78,49 @@ def send_text_chunks(api: TelegramAPI, chat_id: int, text: str) -> None:
api.send_message(chat_id, normalized[start : start + chunk_size]) api.send_message(chat_id, normalized[start : start + chunk_size])
def ensure_bot_commands(api: TelegramAPI, state: dict[str, Any]) -> None:
    """Publish BOT_COMMANDS as the bot's Telegram command menu.

    NOTE(review): `state` is currently unused — presumably reserved for
    caching the registered command set; confirm before removing it.
    """
    api.set_my_commands(BOT_COMMANDS)
def get_provider_catalog(config: BotConfig) -> dict[str, Any]:
    """Fetch the server's provider/model catalog (GET /api/v1/provider-catalog)."""
    return get_json(f"{config.server_url}/api/v1/provider-catalog")
def get_chat_preferences(state: dict[str, Any], chat_id: int) -> dict[str, Any]:
    """Return the mutable per-chat preference dict, creating it on first use.

    Preferences are stored under state["chat_preferences"][str(chat_id)] and
    always contain "provider" and "model" keys (None until the user picks).
    """
    all_prefs = state.setdefault("chat_preferences", {})
    chat_prefs = all_prefs.setdefault(str(chat_id), {})
    for key in ("provider", "model"):
        chat_prefs.setdefault(key, None)
    return chat_prefs
def get_selected_provider_and_model(
    state: dict[str, Any],
    chat_id: int,
    catalog: dict[str, Any],
) -> tuple[str | None, str | None]:
    """Resolve the effective (provider, model) pair for a chat.

    Falls back to the catalog's default provider when the chat has no
    explicit choice, and to the provider's first listed model when no
    model is pinned for the chat.
    """
    prefs = get_chat_preferences(state, chat_id)
    provider = prefs.get("provider") or catalog.get("default_provider")
    model = prefs.get("model")
    by_name = {entry.get("name"): entry for entry in catalog.get("providers", [])}
    entry = by_name.get(provider or "")
    listed = entry.get("models", []) if entry else []
    if not model and listed:
        model = listed[0].get("id")
    return provider, model
def format_status_text(job_state: dict[str, Any], text: str) -> str:
    """Truncate status text to 4000 chars; empty text becomes a placeholder.

    `job_state` is accepted for interface symmetry but not used yet.
    """
    del job_state
    body = text if text else "Пустой ответ."
    return body[:4000]
def format_final_signature(job_state: dict[str, Any]) -> str:
    """Build an italic HTML signature with the model name, or "" if unknown.

    Non-string / empty model values yield no signature at all.
    """
    model = job_state.get("model")
    if not isinstance(model, str) or not model:
        return ""
    return "\n\n<i>" + html.escape(model) + "</i>"
def render_markdownish_html(text: str) -> str: def render_markdownish_html(text: str) -> str:
normalized = (text or "Пустой ответ.").replace("\r\n", "\n") normalized = (text or "Пустой ответ.").replace("\r\n", "\n")
fence_pattern = re.compile(r"```(?:[^\n`]*)\n(.*?)```", re.DOTALL) fence_pattern = re.compile(r"```(?:[^\n`]*)\n(.*?)```", re.DOTALL)
@ -91,7 +146,7 @@ def update_status_message(
job_state: dict[str, Any], job_state: dict[str, Any],
text: str, text: str,
) -> None: ) -> None:
normalized = (text or "Пустой ответ.")[:4000] normalized = format_status_text(job_state, text)
chat_id = int(job_state["chat_id"]) chat_id = int(job_state["chat_id"])
message_id = int(job_state.get("status_message_id") or 0) message_id = int(job_state.get("status_message_id") or 0)
if message_id: if message_id:
@ -100,6 +155,7 @@ def update_status_message(
message_id = api.send_message(chat_id, normalized) message_id = api.send_message(chat_id, normalized)
job_state["status_message_id"] = message_id job_state["status_message_id"] = message_id
job_state["status_message_text"] = normalized job_state["status_message_text"] = normalized
job_state["status_body_text"] = text or "Пустой ответ."
def set_final_message( def set_final_message(
@ -109,8 +165,9 @@ def set_final_message(
) -> None: ) -> None:
normalized = text or "Пустой ответ." normalized = text or "Пустой ответ."
chunk_size = 3800 chunk_size = 3800
signature = format_final_signature(job_state)
first_chunk = normalized[:chunk_size] first_chunk = normalized[:chunk_size]
rendered = render_markdownish_html(first_chunk) rendered = render_markdownish_html(first_chunk) + signature
chat_id = int(job_state["chat_id"]) chat_id = int(job_state["chat_id"])
message_id = int(job_state.get("status_message_id") or 0) message_id = int(job_state.get("status_message_id") or 0)
if message_id: if message_id:
@ -145,30 +202,38 @@ def keep_job_typing(
def summarize_event(event: dict[str, Any]) -> str | None: def summarize_event(event: dict[str, Any]) -> str | None:
event_type = event.get("type") event_type = event.get("type")
if event_type == "job_status": if event_type == "job_status":
return event.get("message") message = str(event.get("message") or "").strip()
if message == "Запрос принят сервером":
return "Смотрю, что можно сделать."
if message == "Ответ готов":
return "Формулирую ответ."
return message or None
if event_type == "model_request": if event_type == "model_request":
provider = event.get("provider") return "Думаю над ответом."
model = event.get("model")
if provider and model:
return f"Думаю над ответом через {provider}/{model}"
if provider:
return f"Думаю над ответом через {provider}"
return "Думаю над ответом"
if event_type == "tool_call": if event_type == "tool_call":
return f"Вызываю инструмент: {event.get('name')}" tool_name = str(event.get("name") or "")
if tool_name in {"read_file", "list_files", "glob_search", "grep_text", "stat_path"}:
return "Просматриваю файлы и контекст."
if tool_name in {"git_status", "git_diff"}:
return "Проверяю изменения в проекте."
if tool_name in {"exec_command"}:
return "Проверяю окружение и команды."
if tool_name in {"web_search"}:
return "Ищу нужную информацию."
return "Проверяю детали."
if event_type == "tool_result": if event_type == "tool_result":
result = event.get("result", {}) result = event.get("result", {})
if isinstance(result, dict) and "error" in result: if isinstance(result, dict) and "error" in result:
return f"Инструмент {event.get('name')} завершился с ошибкой" return "Наткнулся на проблему, перепроверяю."
return f"Инструмент {event.get('name')} завершён" return "Собрал нужные данные."
if event_type == "error": if event_type == "error":
return f"Ошибка: {event.get('message')}" return f"Ошибка: {event.get('message')}"
if event_type == "approval_result": if event_type == "approval_result":
status = event.get("status") status = event.get("status")
tool_name = event.get("tool_name") tool_name = event.get("tool_name")
if status == "approved": if status == "approved":
return f"Подтверждение получено для {tool_name}" return f"Получил подтверждение для {tool_name}."
return f"Подтверждение отклонено для {tool_name}" return f"Подтверждение для {tool_name} отклонено."
return None return None
@ -183,6 +248,83 @@ def format_approval_keyboard(approval_id: str) -> dict[str, Any]:
} }
def format_provider_keyboard(catalog: dict[str, Any]) -> dict[str, Any]:
    """Inline keyboard with one button per provider, packed two per row.

    Providers without a name are skipped; callback data is "provider:<name>".
    """
    buttons = [
        {"text": name, "callback_data": f"provider:{name}"}
        for name in (
            str(item.get("name") or "") for item in catalog.get("providers", [])
        )
        if name
    ]
    rows = [buttons[i : i + 2] for i in range(0, len(buttons), 2)]
    return {"inline_keyboard": rows}
def format_model_keyboard(provider: str, models: list[dict[str, str]]) -> dict[str, Any]:
    """Inline keyboard with one model button per row.

    Entries without an "id" are skipped; the button label prefers the
    human-readable "name" and falls back to the id.
    """
    rows: list[list[dict[str, str]]] = []
    for entry in models:
        model_id = entry.get("id")
        if not model_id:
            continue
        label = str(entry.get("name") or model_id or "-")
        rows.append([{"text": label, "callback_data": f"model:{provider}:{model_id}"}])
    return {"inline_keyboard": rows}
def set_chat_provider(
    state: dict[str, Any],
    chat_id: int,
    provider: str | None,
    *,
    reset_model: bool = True,
) -> dict[str, Any]:
    """Persist the chat's provider choice and return the updated prefs.

    By default the pinned model is cleared so it gets re-derived from the
    new provider's catalog on next use.
    """
    prefs = get_chat_preferences(state, chat_id)
    updates: dict[str, Any] = {"provider": provider}
    if reset_model:
        updates["model"] = None
    prefs.update(updates)
    return prefs
def set_chat_model(state: dict[str, Any], chat_id: int, model: str | None) -> dict[str, Any]:
    """Pin an explicit model for this chat and return the updated prefs."""
    prefs = get_chat_preferences(state, chat_id)
    prefs.update({"model": model})
    return prefs
def format_provider_selection(
catalog: dict[str, Any],
selected_provider: str | None,
) -> str:
lines = ["Выбери провайдера."]
for item in catalog.get("providers", []):
name = str(item.get("name") or "")
if not name:
continue
marker = "" if name == selected_provider else " "
availability = "ready" if item.get("available") else f"unavailable: {item.get('reason') or 'unknown'}"
lines.append(f"{marker} {name} ({availability})")
return "\n".join(lines)
def format_model_selection(
provider: str,
models: list[dict[str, str]],
selected_model: str | None,
) -> str:
lines = [f"Выбери модель для {provider}."]
for item in models:
model_id = str(item.get("id") or "")
if not model_id:
continue
marker = "" if model_id == selected_model else " "
description = str(item.get("description") or "").strip()
suffix = f" ({description})" if description else ""
lines.append(f"{marker} {model_id}{suffix}")
return "\n".join(lines)
def format_approval_request(event: dict[str, Any]) -> str: def format_approval_request(event: dict[str, Any]) -> str:
return ( return (
"Нужно подтверждение инструмента.\n" "Нужно подтверждение инструмента.\n"
@ -290,7 +432,9 @@ def start_chat_job(
delayed: bool = False, delayed: bool = False,
) -> None: ) -> None:
session_id = state.setdefault("sessions", {}).get(session_key) session_id = state.setdefault("sessions", {}).get(session_key)
prefix = "Обрабатываю отложенный запрос..." if delayed else "Обрабатываю запрос..." catalog = get_provider_catalog(config)
provider, model = get_selected_provider_and_model(state, chat_id, catalog)
prefix = "Возвращаюсь к вашему сообщению." if delayed else "Смотрю, что можно сделать."
status_message_id = api.send_message(chat_id, prefix) status_message_id = api.send_message(chat_id, prefix)
start_result = post_json( start_result = post_json(
f"{config.server_url}/api/v1/chat/start", f"{config.server_url}/api/v1/chat/start",
@ -298,6 +442,8 @@ def start_chat_job(
"session_id": session_id, "session_id": session_id,
"user_id": user_id, "user_id": user_id,
"message": text, "message": text,
"provider": provider,
"model": model,
}, },
) )
state["sessions"][session_key] = start_result["session_id"] state["sessions"][session_key] = start_result["session_id"]
@ -310,6 +456,12 @@ def start_chat_job(
"seen_seq": 0, "seen_seq": 0,
"status_message_id": status_message_id, "status_message_id": status_message_id,
"status_message_text": prefix, "status_message_text": prefix,
"status_body_text": prefix,
"requested_provider": provider,
"requested_model": model,
"provider": start_result.get("provider") or provider,
"model": start_result.get("model") or model,
"fallback_notified": False,
"last_typing_at": 0.0, "last_typing_at": 0.0,
} }
state.setdefault("chat_active_jobs", {})[str(chat_id)] = start_result["job_id"] state.setdefault("chat_active_jobs", {})[str(chat_id)] = start_result["job_id"]
@ -501,8 +653,25 @@ def process_active_jobs(
reply_markup=format_approval_keyboard(str(event["approval_id"])), reply_markup=format_approval_keyboard(str(event["approval_id"])),
) )
continue continue
if event.get("type") == "model_request":
selection_reason = str(event.get("selection_reason") or "")
if (
"fell back to" in selection_reason
and not job_state.get("fallback_notified")
):
requested_provider = job_state.get("requested_provider") or "unknown"
api.send_message(
int(job_state["chat_id"]),
"Выбранный провайдер не ответил, поэтому продолжаю через запасной вариант.\n"
f"Изначально был выбран: {requested_provider}",
)
job_state["fallback_notified"] = True
if event.get("provider"):
job_state["provider"] = event.get("provider")
if event.get("model"):
job_state["model"] = event.get("model")
summary = summarize_event(event) summary = summarize_event(event)
if summary and summary != job_state.get("status_message_text"): if summary and summary != job_state.get("status_body_text"):
update_status_message(api, job_state, summary) update_status_message(api, job_state, summary)
status = poll_result.get("status") status = poll_result.get("status")
@ -590,7 +759,7 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
if text == "/start": if text == "/start":
api.send_message( api.send_message(
chat_id, chat_id,
"new-qwen bot готов.\nКоманды: /help, /auth, /status, /session, /cancel, /clear, /approve, /reject.", "new-qwen bot готов.\nКоманды доступны через Telegram menu: /help, /auth, /status, /provider, /model, /session, /cancel, /clear.",
) )
return return
@ -601,6 +770,8 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
"/auth - начать Qwen OAuth\n" "/auth - начать Qwen OAuth\n"
"/auth_check [flow_id] - проверить авторизацию\n" "/auth_check [flow_id] - проверить авторизацию\n"
"/status - статус OAuth и сервера\n" "/status - статус OAuth и сервера\n"
"/provider [name] - выбрать провайдера\n"
"/model [id] - выбрать модель\n"
"/session - показать текущую сессию\n" "/session - показать текущую сессию\n"
"/cancel - отменить активный запрос и очистить очередь\n" "/cancel - отменить активный запрос и очистить очередь\n"
"/approve [approval_id] - подтвердить инструмент\n" "/approve [approval_id] - подтвердить инструмент\n"
@ -652,10 +823,16 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
if text == "/status": if text == "/status":
status = get_json(f"{config.server_url}/api/v1/auth/status") status = get_json(f"{config.server_url}/api/v1/auth/status")
models_status = get_json(f"{config.server_url}/api/v1/models") catalog = get_provider_catalog(config)
queue_size = len(state.setdefault("chat_queues", {}).get(str(chat_id), [])) queue_size = len(state.setdefault("chat_queues", {}).get(str(chat_id), []))
active_job = state.setdefault("chat_active_jobs", {}).get(str(chat_id)) active_job = state.setdefault("chat_active_jobs", {}).get(str(chat_id))
models_info = "\n".join([f" - {m['id']}: {m['name']} ({m['description']})" for m in models_status.get("models", [])]) or " Нет доступных моделей" provider, model = get_selected_provider_and_model(state, chat_id, catalog)
providers_info = "\n".join(
[
f" - {item['name']}: model={item.get('model')}, available={item.get('available')}"
for item in catalog.get("providers", [])
]
) or " Нет доступных провайдеров"
send_text_chunks( send_text_chunks(
api, api,
chat_id, chat_id,
@ -665,15 +842,61 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
f"available_providers: {', '.join(status.get('available_providers') or []) or '-'}\n" f"available_providers: {', '.join(status.get('available_providers') or []) or '-'}\n"
f"default_provider: {status.get('default_provider')}\n" f"default_provider: {status.get('default_provider')}\n"
f"fallback_providers: {', '.join(status.get('fallback_providers') or []) or '-'}\n" f"fallback_providers: {', '.join(status.get('fallback_providers') or []) or '-'}\n"
f"selected_provider: {provider}\n"
f"selected_model: {model}\n"
f"resource_url: {status.get('resource_url')}\n" f"resource_url: {status.get('resource_url')}\n"
f"expires_at: {status.get('expires_at')}\n" f"expires_at: {status.get('expires_at')}\n"
f"tool_policy: {status.get('tool_policy')}\n" f"tool_policy: {status.get('tool_policy')}\n"
f"Доступные Qwen OAuth модели:\n{models_info}\n" f"Провайдеры и модели:\n{providers_info}\n"
f"active_job: {active_job}\n" f"active_job: {active_job}\n"
f"queued_messages: {queue_size}", f"queued_messages: {queue_size}",
) )
return return
if text.startswith("/provider"):
catalog = get_provider_catalog(config)
parts = text.split(maxsplit=1)
if len(parts) == 1:
provider, _ = get_selected_provider_and_model(state, chat_id, catalog)
api.send_message(
chat_id,
format_provider_selection(catalog, provider),
reply_markup=format_provider_keyboard(catalog),
)
return
provider_name = parts[1].strip()
provider_names = {str(item.get("name") or "") for item in catalog.get("providers", [])}
if provider_name not in provider_names:
api.send_message(chat_id, f"Неизвестный provider: {provider_name}")
return
set_chat_provider(state, chat_id, provider_name, reset_model=True)
_, model = get_selected_provider_and_model(state, chat_id, catalog)
api.send_message(chat_id, f"Выбран provider: {provider_name}\nmodel: {model or '-'}")
return
if text.startswith("/model"):
catalog = get_provider_catalog(config)
provider, selected_model = get_selected_provider_and_model(state, chat_id, catalog)
provider_map = {str(item.get("name") or ""): item for item in catalog.get("providers", [])}
provider_info = provider_map.get(provider or "")
models = provider_info.get("models", []) if provider_info else []
parts = text.split(maxsplit=1)
if len(parts) == 1:
api.send_message(
chat_id,
format_model_selection(provider or "-", models, selected_model),
reply_markup=format_model_keyboard(provider or "-", models),
)
return
model_id = parts[1].strip()
valid_model_ids = {str(item.get("id") or "") for item in models}
if model_id not in valid_model_ids:
api.send_message(chat_id, f"Неизвестная модель для {provider}: {model_id}")
return
set_chat_model(state, chat_id, model_id)
api.send_message(chat_id, f"Выбрана модель: {model_id}\nprovider: {provider}")
return
if text == "/cancel": if text == "/cancel":
canceled = cancel_chat_work( canceled = cancel_chat_work(
api, api,
@ -750,6 +973,44 @@ def handle_callback_query(
action, approval_id = data.split(":", 1) action, approval_id = data.split(":", 1)
if action not in {"approve", "reject"}: if action not in {"approve", "reject"}:
if action == "provider":
catalog = get_provider_catalog(config)
provider_names = {str(item.get("name") or "") for item in catalog.get("providers", [])}
if approval_id not in provider_names:
api.answer_callback_query(callback_query_id, "Неизвестный провайдер")
return
set_chat_provider(state, chat_id, approval_id, reset_model=True)
provider, model = get_selected_provider_and_model(state, chat_id, catalog)
api.edit_message_text(
chat_id,
message_id,
f"Выбран provider: {provider}\nmodel: {model or '-'}",
reply_markup={"inline_keyboard": []},
)
api.answer_callback_query(callback_query_id, "Провайдер переключен")
return
if action == "model":
provider, model_id = approval_id.split(":", 1) if ":" in approval_id else ("", "")
catalog = get_provider_catalog(config)
provider_map = {str(item.get("name") or ""): item for item in catalog.get("providers", [])}
provider_info = provider_map.get(provider)
model_ids = {
str(item.get("id") or "")
for item in (provider_info.get("models", []) if provider_info else [])
}
if model_id not in model_ids:
api.answer_callback_query(callback_query_id, "Неизвестная модель")
return
set_chat_provider(state, chat_id, provider, reset_model=False)
set_chat_model(state, chat_id, model_id)
api.edit_message_text(
chat_id,
message_id,
f"Выбрана модель: {model_id}\nprovider: {provider}",
reply_markup={"inline_keyboard": []},
)
api.answer_callback_query(callback_query_id, "Модель переключена")
return
api.answer_callback_query(callback_query_id, "Неизвестное действие") api.answer_callback_query(callback_query_id, "Неизвестное действие")
return return
@ -782,6 +1043,8 @@ def main() -> None:
config = BotConfig.load() config = BotConfig.load()
api = TelegramAPI(config.token, proxy_url=config.proxy_url) api = TelegramAPI(config.token, proxy_url=config.proxy_url)
state = load_state() state = load_state()
ensure_bot_commands(api, state)
save_state(state)
print("new-qwen bot polling started") print("new-qwen bot polling started")
while True: while True:
try: try:

View File

@ -119,3 +119,6 @@ class TelegramAPI:
if text: if text:
payload["text"] = text payload["text"] = text
self._post("answerCallbackQuery", payload) self._post("answerCallbackQuery", payload)
def set_my_commands(self, commands: list[dict[str, str]]) -> None:
    """Register the bot's command menu via Telegram's setMyCommands method.

    Each entry must carry "command" and "description" keys.
    """
    self._post("setMyCommands", {"commands": commands})

View File

@ -139,6 +139,13 @@ class AppState:
"pending_approvals": len(self.approvals.list_pending()), "pending_approvals": len(self.approvals.list_pending()),
} }
def provider_catalog(self) -> dict[str, Any]:
    """Snapshot of configured defaults plus per-provider status entries.

    "providers" comes from the registry's statuses(); presumably each entry
    carries name/model/availability — verify against ProviderRegistry.
    """
    return {
        "default_provider": self.config.default_provider,
        "default_model": self.config.model,
        "providers": self.providers.statuses(),
    }
class RequestHandler(BaseHTTPRequestHandler): class RequestHandler(BaseHTTPRequestHandler):
server_version = "new-qwen-serv/0.1" server_version = "new-qwen-serv/0.1"
@ -180,6 +187,9 @@ class RequestHandler(BaseHTTPRequestHandler):
models = self.app.oauth.get_available_models() models = self.app.oauth.get_available_models()
self._send(HTTPStatus.OK, {"models": models}) self._send(HTTPStatus.OK, {"models": models})
return return
if self.path == "/api/v1/provider-catalog":
self._send(HTTPStatus.OK, self.app.provider_catalog())
return
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"}) self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
def _run_chat_job( def _run_chat_job(
@ -189,6 +199,7 @@ class RequestHandler(BaseHTTPRequestHandler):
user_id: str, user_id: str,
message: str, message: str,
preferred_provider: str | None = None, preferred_provider: str | None = None,
preferred_model: str | None = None,
) -> None: ) -> None:
try: try:
if self.app.jobs.is_cancel_requested(job_id): if self.app.jobs.is_cancel_requested(job_id):
@ -217,6 +228,7 @@ class RequestHandler(BaseHTTPRequestHandler):
), ),
is_cancelled=lambda: self.app.jobs.is_cancel_requested(job_id), is_cancelled=lambda: self.app.jobs.is_cancel_requested(job_id),
preferred_provider=preferred_provider, preferred_provider=preferred_provider,
preferred_model=preferred_model,
) )
if self.app.jobs.is_cancel_requested(job_id): if self.app.jobs.is_cancel_requested(job_id):
reason = "Job canceled by operator" reason = "Job canceled by operator"
@ -338,12 +350,14 @@ class RequestHandler(BaseHTTPRequestHandler):
user_id = str(body.get("user_id") or "anonymous") user_id = str(body.get("user_id") or "anonymous")
message = body["message"] message = body["message"]
preferred_provider = body.get("provider") preferred_provider = body.get("provider")
preferred_model = body.get("model")
session = self.app.sessions.load(session_id) session = self.app.sessions.load(session_id)
history = session.get("messages", []) history = session.get("messages", [])
result = self.app.agent.run( result = self.app.agent.run(
history, history,
message, message,
preferred_provider=preferred_provider, preferred_provider=preferred_provider,
preferred_model=preferred_model,
) )
persisted_messages = result["messages"][1:] persisted_messages = result["messages"][1:]
self.app.sessions.save( self.app.sessions.save(
@ -375,10 +389,11 @@ class RequestHandler(BaseHTTPRequestHandler):
user_id = str(body.get("user_id") or "anonymous") user_id = str(body.get("user_id") or "anonymous")
message = body["message"] message = body["message"]
preferred_provider = body.get("provider") preferred_provider = body.get("provider")
preferred_model = body.get("model")
job = self.app.jobs.create(session_id, user_id, message) job = self.app.jobs.create(session_id, user_id, message)
thread = threading.Thread( thread = threading.Thread(
target=self._run_chat_job, target=self._run_chat_job,
args=(job["job_id"], session_id, user_id, message, preferred_provider), args=(job["job_id"], session_id, user_id, message, preferred_provider, preferred_model),
daemon=True, daemon=True,
) )
thread.start() thread.start()
@ -389,6 +404,7 @@ class RequestHandler(BaseHTTPRequestHandler):
"session_id": session_id, "session_id": session_id,
"status": "queued", "status": "queued",
"provider": preferred_provider or self.app.config.default_provider, "provider": preferred_provider or self.app.config.default_provider,
"model": preferred_model or self.app.config.model,
}, },
) )
return return

View File

@ -55,6 +55,14 @@ class GigaChatAuthManager:
encoding="utf-8", encoding="utf-8",
) )
@staticmethod
def _normalize_expires_at(value: Any) -> int:
    """Coerce an `expires_at` value to unix seconds.

    Falsy input (None, 0, "") normalizes to 0, i.e. "already expired".
    """
    expires_at = int(value or 0)
    # Sber can return unix time in milliseconds.
    if expires_at > 10_000_000_000:
        return expires_at // 1000
    return expires_at
def fetch_token(self) -> dict[str, Any]: def fetch_token(self) -> dict[str, Any]:
data = parse.urlencode({"scope": self.config.gigachat_scope}).encode("utf-8") data = parse.urlencode({"scope": self.config.gigachat_scope}).encode("utf-8")
req = request.Request( req = request.Request(
@ -76,7 +84,7 @@ class GigaChatAuthManager:
raise GigaChatError(f"GigaChat token request failed with HTTP {exc.code}: {body}") from exc raise GigaChatError(f"GigaChat token request failed with HTTP {exc.code}: {body}") from exc
token = { token = {
"access_token": payload["access_token"], "access_token": payload["access_token"],
"expires_at": int(payload["expires_at"]), "expires_at": self._normalize_expires_at(payload.get("expires_at")),
} }
self.save_token(token) self.save_token(token)
return token return token
@ -86,7 +94,8 @@ class GigaChatAuthManager:
raise GigaChatError("GigaChat auth key is not configured") raise GigaChatError("GigaChat auth key is not configured")
token = self.load_token() token = self.load_token()
now = int(time.time()) now = int(time.time())
if token and int(token.get("expires_at", 0)) - now > 30: expires_at = self._normalize_expires_at(token.get("expires_at", 0)) if token else 0
if token and expires_at - now > 30:
return str(token["access_token"]) return str(token["access_token"])
refreshed = self.fetch_token() refreshed = self.fetch_token()
return str(refreshed["access_token"]) return str(refreshed["access_token"])

View File

@ -30,6 +30,7 @@ class QwenAgent:
approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None, approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
is_cancelled: Callable[[], bool] | None = None, is_cancelled: Callable[[], bool] | None = None,
preferred_provider: str | None = None, preferred_provider: str | None = None,
preferred_model: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
emit = on_event or (lambda _event: None) emit = on_event or (lambda _event: None)
cancel_check = is_cancelled or (lambda: False) cancel_check = is_cancelled or (lambda: False)
@ -54,6 +55,7 @@ class QwenAgent:
tools=self.tools.schemas(), tools=self.tools.schemas(),
tool_choice="auto", tool_choice="auto",
preferred_provider=preferred_provider, preferred_provider=preferred_provider,
preferred_model=preferred_model,
require_tools=True, require_tools=True,
) )
) )

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from dataclasses import replace
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
import uuid import uuid
@ -28,6 +29,7 @@ class CompletionRequest:
tools: list[dict[str, Any]] tools: list[dict[str, Any]]
tool_choice: str = "auto" tool_choice: str = "auto"
preferred_provider: str | None = None preferred_provider: str | None = None
preferred_model: str | None = None
require_tools: bool = True require_tools: bool = True
@ -53,9 +55,14 @@ class BaseModelProvider:
def unavailable_reason(self) -> str | None: def unavailable_reason(self) -> str | None:
raise NotImplementedError raise NotImplementedError
def model_name(self) -> str: def model_name(self, preferred_model: str | None = None) -> str:
del preferred_model
raise NotImplementedError raise NotImplementedError
def list_models(self) -> list[dict[str, str]]:
    """Default catalog: a single entry for this provider's current model.

    Providers with richer catalogs (Qwen, GigaChat) override this.
    """
    model = self.model_name()
    return [{"id": model, "name": model, "description": ""}]
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
raise NotImplementedError raise NotImplementedError
@ -83,7 +90,8 @@ class UnavailableModelProvider(BaseModelProvider):
def unavailable_reason(self) -> str | None: def unavailable_reason(self) -> str | None:
return self._reason return self._reason
def model_name(self) -> str: def model_name(self, preferred_model: str | None = None) -> str:
del preferred_model
return self._model_name return self._model_name
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
@ -114,8 +122,15 @@ class QwenModelProvider(BaseModelProvider):
return None return None
return "Qwen OAuth is not configured" return "Qwen OAuth is not configured"
def model_name(self) -> str: def model_name(self, preferred_model: str | None = None) -> str:
return self._model_name if preferred_model and preferred_model not in QWEN_OAUTH_ALLOWED_MODELS:
raise ModelProviderError(
f"Provider {self.name} does not support model '{preferred_model}'"
)
return self.oauth.get_model_name_for_id(preferred_model or self._model_id)
def list_models(self) -> list[dict[str, str]]:
return self.oauth.get_available_models()
@staticmethod @staticmethod
def _normalize_content(value: Any) -> list[dict[str, str]]: def _normalize_content(value: Any) -> list[dict[str, str]]:
@ -140,8 +155,9 @@ class QwenModelProvider(BaseModelProvider):
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]: def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
creds = self.oauth.get_valid_credentials() creds = self.oauth.get_valid_credentials()
base_url = self.oauth.get_openai_base_url(creds) base_url = self.oauth.get_openai_base_url(creds)
model_name = self.model_name(completion_request.preferred_model)
payload = { payload = {
"model": self.model_name(), "model": model_name,
"messages": self._normalize_messages(completion_request.messages), "messages": self._normalize_messages(completion_request.messages),
"max_tokens": 8000, "max_tokens": 8000,
"metadata": { "metadata": {
@ -203,8 +219,18 @@ class GigaChatModelProvider(BaseModelProvider):
return None return None
return "GigaChat auth key is not configured" return "GigaChat auth key is not configured"
def model_name(self) -> str: def model_name(self, preferred_model: str | None = None) -> str:
return self.config.gigachat_model return (preferred_model or self.config.gigachat_model).strip() or self.config.gigachat_model
def list_models(self) -> list[dict[str, str]]:
    """Expose only the single GigaChat model taken from configuration."""
    model = self.config.gigachat_model
    return [
        {
            "id": model,
            "name": model,
            "description": "Configured GigaChat model",
        }
    ]
@staticmethod @staticmethod
def _convert_tool_schema(tool: dict[str, Any]) -> dict[str, Any]: def _convert_tool_schema(tool: dict[str, Any]) -> dict[str, Any]:
@ -294,7 +320,7 @@ class GigaChatModelProvider(BaseModelProvider):
raise ModelProviderError(str(exc)) from exc raise ModelProviderError(str(exc)) from exc
api_base = self.config.gigachat_api_base_url.rstrip("/") api_base = self.config.gigachat_api_base_url.rstrip("/")
payload: dict[str, Any] = { payload: dict[str, Any] = {
"model": self.model_name(), "model": self.model_name(completion_request.preferred_model),
"messages": self._convert_messages(completion_request.messages), "messages": self._convert_messages(completion_request.messages),
} }
if completion_request.tools: if completion_request.tools:
@ -340,6 +366,7 @@ class ProviderRegistry:
{ {
"name": provider.name, "name": provider.name,
"model": provider.model_name(), "model": provider.model_name(),
"models": provider.list_models(),
"available": provider.is_available(), "available": provider.is_available(),
"reason": provider.unavailable_reason(), "reason": provider.unavailable_reason(),
"capabilities": { "capabilities": {
@ -400,13 +427,16 @@ class ModelRouter:
elif name == self.config.default_provider: elif name == self.config.default_provider:
selection_reason = "selected default provider" selection_reason = "selected default provider"
try: try:
payload = provider.complete(completion_request) provider_request = completion_request
if completion_request.preferred_provider and name != completion_request.preferred_provider:
provider_request = replace(completion_request, preferred_model=None)
payload = provider.complete(provider_request)
except ModelProviderError as exc: except ModelProviderError as exc:
reasons.append(f"{name}: request failed: {exc}") reasons.append(f"{name}: request failed: {exc}")
continue continue
return CompletionResponse( return CompletionResponse(
provider_name=name, provider_name=name,
model_name=provider.model_name(), model_name=provider.model_name(provider_request.preferred_model),
payload=payload, payload=payload,
selection_reason=selection_reason, selection_reason=selection_reason,
attempted=attempted, attempted=attempted,