Improve provider selection UX and fix GigaChat token refresh
This commit is contained in:
parent
b8c55a8194
commit
99d0be151f
305
bot/app.py
305
bot/app.py
|
|
@ -14,6 +14,16 @@ from telegram_api import TelegramAPI
|
||||||
|
|
||||||
STATE_FILE = Path(__file__).resolve().parent.parent / ".new-qwen" / "telegram-state.json"
|
STATE_FILE = Path(__file__).resolve().parent.parent / ".new-qwen" / "telegram-state.json"
|
||||||
TYPING_INTERVAL_SECONDS = 4.0
|
TYPING_INTERVAL_SECONDS = 4.0
|
||||||
|
BOT_COMMANDS = [
|
||||||
|
{"command": "help", "description": "Список команд"},
|
||||||
|
{"command": "status", "description": "Статус сервера и чата"},
|
||||||
|
{"command": "auth", "description": "Запустить Qwen OAuth"},
|
||||||
|
{"command": "provider", "description": "Выбрать провайдера"},
|
||||||
|
{"command": "model", "description": "Выбрать модель"},
|
||||||
|
{"command": "session", "description": "Показать сессию"},
|
||||||
|
{"command": "clear", "description": "Очистить контекст"},
|
||||||
|
{"command": "cancel", "description": "Отменить активный job"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def load_state() -> dict[str, Any]:
|
def load_state() -> dict[str, Any]:
|
||||||
|
|
@ -26,6 +36,7 @@ def load_state() -> dict[str, Any]:
|
||||||
"chat_active_jobs": {},
|
"chat_active_jobs": {},
|
||||||
"chat_queues": {},
|
"chat_queues": {},
|
||||||
"pending_approvals": {},
|
"pending_approvals": {},
|
||||||
|
"chat_preferences": {},
|
||||||
}
|
}
|
||||||
state = json.loads(STATE_FILE.read_text(encoding="utf-8"))
|
state = json.loads(STATE_FILE.read_text(encoding="utf-8"))
|
||||||
state.setdefault("sessions", {})
|
state.setdefault("sessions", {})
|
||||||
|
|
@ -34,6 +45,7 @@ def load_state() -> dict[str, Any]:
|
||||||
state.setdefault("chat_active_jobs", {})
|
state.setdefault("chat_active_jobs", {})
|
||||||
state.setdefault("chat_queues", {})
|
state.setdefault("chat_queues", {})
|
||||||
state.setdefault("pending_approvals", {})
|
state.setdefault("pending_approvals", {})
|
||||||
|
state.setdefault("chat_preferences", {})
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -66,6 +78,49 @@ def send_text_chunks(api: TelegramAPI, chat_id: int, text: str) -> None:
|
||||||
api.send_message(chat_id, normalized[start : start + chunk_size])
|
api.send_message(chat_id, normalized[start : start + chunk_size])
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_bot_commands(api: TelegramAPI, state: dict[str, Any]) -> None:
|
||||||
|
api.set_my_commands(BOT_COMMANDS)
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_catalog(config: BotConfig) -> dict[str, Any]:
|
||||||
|
return get_json(f"{config.server_url}/api/v1/provider-catalog")
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_preferences(state: dict[str, Any], chat_id: int) -> dict[str, Any]:
|
||||||
|
prefs = state.setdefault("chat_preferences", {}).setdefault(str(chat_id), {})
|
||||||
|
prefs.setdefault("provider", None)
|
||||||
|
prefs.setdefault("model", None)
|
||||||
|
return prefs
|
||||||
|
|
||||||
|
|
||||||
|
def get_selected_provider_and_model(
|
||||||
|
state: dict[str, Any],
|
||||||
|
chat_id: int,
|
||||||
|
catalog: dict[str, Any],
|
||||||
|
) -> tuple[str | None, str | None]:
|
||||||
|
prefs = get_chat_preferences(state, chat_id)
|
||||||
|
provider = prefs.get("provider") or catalog.get("default_provider")
|
||||||
|
model = prefs.get("model")
|
||||||
|
providers = {item.get("name"): item for item in catalog.get("providers", [])}
|
||||||
|
provider_info = providers.get(provider or "")
|
||||||
|
available_models = provider_info.get("models", []) if provider_info else []
|
||||||
|
if not model and available_models:
|
||||||
|
model = available_models[0].get("id")
|
||||||
|
return provider, model
|
||||||
|
|
||||||
|
|
||||||
|
def format_status_text(job_state: dict[str, Any], text: str) -> str:
|
||||||
|
del job_state
|
||||||
|
return (text or "Пустой ответ.")[:4000]
|
||||||
|
|
||||||
|
|
||||||
|
def format_final_signature(job_state: dict[str, Any]) -> str:
|
||||||
|
model = job_state.get("model") if isinstance(job_state.get("model"), str) else None
|
||||||
|
if not model:
|
||||||
|
return ""
|
||||||
|
return f"\n\n<i>{html.escape(model)}</i>"
|
||||||
|
|
||||||
|
|
||||||
def render_markdownish_html(text: str) -> str:
|
def render_markdownish_html(text: str) -> str:
|
||||||
normalized = (text or "Пустой ответ.").replace("\r\n", "\n")
|
normalized = (text or "Пустой ответ.").replace("\r\n", "\n")
|
||||||
fence_pattern = re.compile(r"```(?:[^\n`]*)\n(.*?)```", re.DOTALL)
|
fence_pattern = re.compile(r"```(?:[^\n`]*)\n(.*?)```", re.DOTALL)
|
||||||
|
|
@ -91,7 +146,7 @@ def update_status_message(
|
||||||
job_state: dict[str, Any],
|
job_state: dict[str, Any],
|
||||||
text: str,
|
text: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
normalized = (text or "Пустой ответ.")[:4000]
|
normalized = format_status_text(job_state, text)
|
||||||
chat_id = int(job_state["chat_id"])
|
chat_id = int(job_state["chat_id"])
|
||||||
message_id = int(job_state.get("status_message_id") or 0)
|
message_id = int(job_state.get("status_message_id") or 0)
|
||||||
if message_id:
|
if message_id:
|
||||||
|
|
@ -100,6 +155,7 @@ def update_status_message(
|
||||||
message_id = api.send_message(chat_id, normalized)
|
message_id = api.send_message(chat_id, normalized)
|
||||||
job_state["status_message_id"] = message_id
|
job_state["status_message_id"] = message_id
|
||||||
job_state["status_message_text"] = normalized
|
job_state["status_message_text"] = normalized
|
||||||
|
job_state["status_body_text"] = text or "Пустой ответ."
|
||||||
|
|
||||||
|
|
||||||
def set_final_message(
|
def set_final_message(
|
||||||
|
|
@ -109,8 +165,9 @@ def set_final_message(
|
||||||
) -> None:
|
) -> None:
|
||||||
normalized = text or "Пустой ответ."
|
normalized = text or "Пустой ответ."
|
||||||
chunk_size = 3800
|
chunk_size = 3800
|
||||||
|
signature = format_final_signature(job_state)
|
||||||
first_chunk = normalized[:chunk_size]
|
first_chunk = normalized[:chunk_size]
|
||||||
rendered = render_markdownish_html(first_chunk)
|
rendered = render_markdownish_html(first_chunk) + signature
|
||||||
chat_id = int(job_state["chat_id"])
|
chat_id = int(job_state["chat_id"])
|
||||||
message_id = int(job_state.get("status_message_id") or 0)
|
message_id = int(job_state.get("status_message_id") or 0)
|
||||||
if message_id:
|
if message_id:
|
||||||
|
|
@ -145,30 +202,38 @@ def keep_job_typing(
|
||||||
def summarize_event(event: dict[str, Any]) -> str | None:
|
def summarize_event(event: dict[str, Any]) -> str | None:
|
||||||
event_type = event.get("type")
|
event_type = event.get("type")
|
||||||
if event_type == "job_status":
|
if event_type == "job_status":
|
||||||
return event.get("message")
|
message = str(event.get("message") or "").strip()
|
||||||
|
if message == "Запрос принят сервером":
|
||||||
|
return "Смотрю, что можно сделать."
|
||||||
|
if message == "Ответ готов":
|
||||||
|
return "Формулирую ответ."
|
||||||
|
return message or None
|
||||||
if event_type == "model_request":
|
if event_type == "model_request":
|
||||||
provider = event.get("provider")
|
return "Думаю над ответом."
|
||||||
model = event.get("model")
|
|
||||||
if provider and model:
|
|
||||||
return f"Думаю над ответом через {provider}/{model}"
|
|
||||||
if provider:
|
|
||||||
return f"Думаю над ответом через {provider}"
|
|
||||||
return "Думаю над ответом"
|
|
||||||
if event_type == "tool_call":
|
if event_type == "tool_call":
|
||||||
return f"Вызываю инструмент: {event.get('name')}"
|
tool_name = str(event.get("name") or "")
|
||||||
|
if tool_name in {"read_file", "list_files", "glob_search", "grep_text", "stat_path"}:
|
||||||
|
return "Просматриваю файлы и контекст."
|
||||||
|
if tool_name in {"git_status", "git_diff"}:
|
||||||
|
return "Проверяю изменения в проекте."
|
||||||
|
if tool_name in {"exec_command"}:
|
||||||
|
return "Проверяю окружение и команды."
|
||||||
|
if tool_name in {"web_search"}:
|
||||||
|
return "Ищу нужную информацию."
|
||||||
|
return "Проверяю детали."
|
||||||
if event_type == "tool_result":
|
if event_type == "tool_result":
|
||||||
result = event.get("result", {})
|
result = event.get("result", {})
|
||||||
if isinstance(result, dict) and "error" in result:
|
if isinstance(result, dict) and "error" in result:
|
||||||
return f"Инструмент {event.get('name')} завершился с ошибкой"
|
return "Наткнулся на проблему, перепроверяю."
|
||||||
return f"Инструмент {event.get('name')} завершён"
|
return "Собрал нужные данные."
|
||||||
if event_type == "error":
|
if event_type == "error":
|
||||||
return f"Ошибка: {event.get('message')}"
|
return f"Ошибка: {event.get('message')}"
|
||||||
if event_type == "approval_result":
|
if event_type == "approval_result":
|
||||||
status = event.get("status")
|
status = event.get("status")
|
||||||
tool_name = event.get("tool_name")
|
tool_name = event.get("tool_name")
|
||||||
if status == "approved":
|
if status == "approved":
|
||||||
return f"Подтверждение получено для {tool_name}"
|
return f"Получил подтверждение для {tool_name}."
|
||||||
return f"Подтверждение отклонено для {tool_name}"
|
return f"Подтверждение для {tool_name} отклонено."
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -183,6 +248,83 @@ def format_approval_keyboard(approval_id: str) -> dict[str, Any]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def format_provider_keyboard(catalog: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
rows: list[list[dict[str, str]]] = []
|
||||||
|
row: list[dict[str, str]] = []
|
||||||
|
for item in catalog.get("providers", []):
|
||||||
|
name = str(item.get("name") or "")
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
row.append({"text": name, "callback_data": f"provider:{name}"})
|
||||||
|
if len(row) == 2:
|
||||||
|
rows.append(row)
|
||||||
|
row = []
|
||||||
|
if row:
|
||||||
|
rows.append(row)
|
||||||
|
return {"inline_keyboard": rows}
|
||||||
|
|
||||||
|
|
||||||
|
def format_model_keyboard(provider: str, models: list[dict[str, str]]) -> dict[str, Any]:
|
||||||
|
rows = [
|
||||||
|
[{"text": str(item.get("name") or item.get("id") or "-"), "callback_data": f"model:{provider}:{item.get('id')}"}]
|
||||||
|
for item in models
|
||||||
|
if item.get("id")
|
||||||
|
]
|
||||||
|
return {"inline_keyboard": rows}
|
||||||
|
|
||||||
|
|
||||||
|
def set_chat_provider(
|
||||||
|
state: dict[str, Any],
|
||||||
|
chat_id: int,
|
||||||
|
provider: str | None,
|
||||||
|
*,
|
||||||
|
reset_model: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
prefs = get_chat_preferences(state, chat_id)
|
||||||
|
prefs["provider"] = provider
|
||||||
|
if reset_model:
|
||||||
|
prefs["model"] = None
|
||||||
|
return prefs
|
||||||
|
|
||||||
|
|
||||||
|
def set_chat_model(state: dict[str, Any], chat_id: int, model: str | None) -> dict[str, Any]:
|
||||||
|
prefs = get_chat_preferences(state, chat_id)
|
||||||
|
prefs["model"] = model
|
||||||
|
return prefs
|
||||||
|
|
||||||
|
|
||||||
|
def format_provider_selection(
|
||||||
|
catalog: dict[str, Any],
|
||||||
|
selected_provider: str | None,
|
||||||
|
) -> str:
|
||||||
|
lines = ["Выбери провайдера."]
|
||||||
|
for item in catalog.get("providers", []):
|
||||||
|
name = str(item.get("name") or "")
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
marker = "•" if name == selected_provider else " "
|
||||||
|
availability = "ready" if item.get("available") else f"unavailable: {item.get('reason') or 'unknown'}"
|
||||||
|
lines.append(f"{marker} {name} ({availability})")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def format_model_selection(
|
||||||
|
provider: str,
|
||||||
|
models: list[dict[str, str]],
|
||||||
|
selected_model: str | None,
|
||||||
|
) -> str:
|
||||||
|
lines = [f"Выбери модель для {provider}."]
|
||||||
|
for item in models:
|
||||||
|
model_id = str(item.get("id") or "")
|
||||||
|
if not model_id:
|
||||||
|
continue
|
||||||
|
marker = "•" if model_id == selected_model else " "
|
||||||
|
description = str(item.get("description") or "").strip()
|
||||||
|
suffix = f" ({description})" if description else ""
|
||||||
|
lines.append(f"{marker} {model_id}{suffix}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def format_approval_request(event: dict[str, Any]) -> str:
|
def format_approval_request(event: dict[str, Any]) -> str:
|
||||||
return (
|
return (
|
||||||
"Нужно подтверждение инструмента.\n"
|
"Нужно подтверждение инструмента.\n"
|
||||||
|
|
@ -290,7 +432,9 @@ def start_chat_job(
|
||||||
delayed: bool = False,
|
delayed: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
session_id = state.setdefault("sessions", {}).get(session_key)
|
session_id = state.setdefault("sessions", {}).get(session_key)
|
||||||
prefix = "Обрабатываю отложенный запрос..." if delayed else "Обрабатываю запрос..."
|
catalog = get_provider_catalog(config)
|
||||||
|
provider, model = get_selected_provider_and_model(state, chat_id, catalog)
|
||||||
|
prefix = "Возвращаюсь к вашему сообщению." if delayed else "Смотрю, что можно сделать."
|
||||||
status_message_id = api.send_message(chat_id, prefix)
|
status_message_id = api.send_message(chat_id, prefix)
|
||||||
start_result = post_json(
|
start_result = post_json(
|
||||||
f"{config.server_url}/api/v1/chat/start",
|
f"{config.server_url}/api/v1/chat/start",
|
||||||
|
|
@ -298,6 +442,8 @@ def start_chat_job(
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"message": text,
|
"message": text,
|
||||||
|
"provider": provider,
|
||||||
|
"model": model,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
state["sessions"][session_key] = start_result["session_id"]
|
state["sessions"][session_key] = start_result["session_id"]
|
||||||
|
|
@ -310,6 +456,12 @@ def start_chat_job(
|
||||||
"seen_seq": 0,
|
"seen_seq": 0,
|
||||||
"status_message_id": status_message_id,
|
"status_message_id": status_message_id,
|
||||||
"status_message_text": prefix,
|
"status_message_text": prefix,
|
||||||
|
"status_body_text": prefix,
|
||||||
|
"requested_provider": provider,
|
||||||
|
"requested_model": model,
|
||||||
|
"provider": start_result.get("provider") or provider,
|
||||||
|
"model": start_result.get("model") or model,
|
||||||
|
"fallback_notified": False,
|
||||||
"last_typing_at": 0.0,
|
"last_typing_at": 0.0,
|
||||||
}
|
}
|
||||||
state.setdefault("chat_active_jobs", {})[str(chat_id)] = start_result["job_id"]
|
state.setdefault("chat_active_jobs", {})[str(chat_id)] = start_result["job_id"]
|
||||||
|
|
@ -501,8 +653,25 @@ def process_active_jobs(
|
||||||
reply_markup=format_approval_keyboard(str(event["approval_id"])),
|
reply_markup=format_approval_keyboard(str(event["approval_id"])),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
if event.get("type") == "model_request":
|
||||||
|
selection_reason = str(event.get("selection_reason") or "")
|
||||||
|
if (
|
||||||
|
"fell back to" in selection_reason
|
||||||
|
and not job_state.get("fallback_notified")
|
||||||
|
):
|
||||||
|
requested_provider = job_state.get("requested_provider") or "unknown"
|
||||||
|
api.send_message(
|
||||||
|
int(job_state["chat_id"]),
|
||||||
|
"Выбранный провайдер не ответил, поэтому продолжаю через запасной вариант.\n"
|
||||||
|
f"Изначально был выбран: {requested_provider}",
|
||||||
|
)
|
||||||
|
job_state["fallback_notified"] = True
|
||||||
|
if event.get("provider"):
|
||||||
|
job_state["provider"] = event.get("provider")
|
||||||
|
if event.get("model"):
|
||||||
|
job_state["model"] = event.get("model")
|
||||||
summary = summarize_event(event)
|
summary = summarize_event(event)
|
||||||
if summary and summary != job_state.get("status_message_text"):
|
if summary and summary != job_state.get("status_body_text"):
|
||||||
update_status_message(api, job_state, summary)
|
update_status_message(api, job_state, summary)
|
||||||
|
|
||||||
status = poll_result.get("status")
|
status = poll_result.get("status")
|
||||||
|
|
@ -590,7 +759,7 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
||||||
if text == "/start":
|
if text == "/start":
|
||||||
api.send_message(
|
api.send_message(
|
||||||
chat_id,
|
chat_id,
|
||||||
"new-qwen bot готов.\nКоманды: /help, /auth, /status, /session, /cancel, /clear, /approve, /reject.",
|
"new-qwen bot готов.\nКоманды доступны через Telegram menu: /help, /auth, /status, /provider, /model, /session, /cancel, /clear.",
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -601,6 +770,8 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
||||||
"/auth - начать Qwen OAuth\n"
|
"/auth - начать Qwen OAuth\n"
|
||||||
"/auth_check [flow_id] - проверить авторизацию\n"
|
"/auth_check [flow_id] - проверить авторизацию\n"
|
||||||
"/status - статус OAuth и сервера\n"
|
"/status - статус OAuth и сервера\n"
|
||||||
|
"/provider [name] - выбрать провайдера\n"
|
||||||
|
"/model [id] - выбрать модель\n"
|
||||||
"/session - показать текущую сессию\n"
|
"/session - показать текущую сессию\n"
|
||||||
"/cancel - отменить активный запрос и очистить очередь\n"
|
"/cancel - отменить активный запрос и очистить очередь\n"
|
||||||
"/approve [approval_id] - подтвердить инструмент\n"
|
"/approve [approval_id] - подтвердить инструмент\n"
|
||||||
|
|
@ -652,10 +823,16 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
||||||
|
|
||||||
if text == "/status":
|
if text == "/status":
|
||||||
status = get_json(f"{config.server_url}/api/v1/auth/status")
|
status = get_json(f"{config.server_url}/api/v1/auth/status")
|
||||||
models_status = get_json(f"{config.server_url}/api/v1/models")
|
catalog = get_provider_catalog(config)
|
||||||
queue_size = len(state.setdefault("chat_queues", {}).get(str(chat_id), []))
|
queue_size = len(state.setdefault("chat_queues", {}).get(str(chat_id), []))
|
||||||
active_job = state.setdefault("chat_active_jobs", {}).get(str(chat_id))
|
active_job = state.setdefault("chat_active_jobs", {}).get(str(chat_id))
|
||||||
models_info = "\n".join([f" - {m['id']}: {m['name']} ({m['description']})" for m in models_status.get("models", [])]) or " Нет доступных моделей"
|
provider, model = get_selected_provider_and_model(state, chat_id, catalog)
|
||||||
|
providers_info = "\n".join(
|
||||||
|
[
|
||||||
|
f" - {item['name']}: model={item.get('model')}, available={item.get('available')}"
|
||||||
|
for item in catalog.get("providers", [])
|
||||||
|
]
|
||||||
|
) or " Нет доступных провайдеров"
|
||||||
send_text_chunks(
|
send_text_chunks(
|
||||||
api,
|
api,
|
||||||
chat_id,
|
chat_id,
|
||||||
|
|
@ -665,15 +842,61 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
||||||
f"available_providers: {', '.join(status.get('available_providers') or []) or '-'}\n"
|
f"available_providers: {', '.join(status.get('available_providers') or []) or '-'}\n"
|
||||||
f"default_provider: {status.get('default_provider')}\n"
|
f"default_provider: {status.get('default_provider')}\n"
|
||||||
f"fallback_providers: {', '.join(status.get('fallback_providers') or []) or '-'}\n"
|
f"fallback_providers: {', '.join(status.get('fallback_providers') or []) or '-'}\n"
|
||||||
|
f"selected_provider: {provider}\n"
|
||||||
|
f"selected_model: {model}\n"
|
||||||
f"resource_url: {status.get('resource_url')}\n"
|
f"resource_url: {status.get('resource_url')}\n"
|
||||||
f"expires_at: {status.get('expires_at')}\n"
|
f"expires_at: {status.get('expires_at')}\n"
|
||||||
f"tool_policy: {status.get('tool_policy')}\n"
|
f"tool_policy: {status.get('tool_policy')}\n"
|
||||||
f"Доступные Qwen OAuth модели:\n{models_info}\n"
|
f"Провайдеры и модели:\n{providers_info}\n"
|
||||||
f"active_job: {active_job}\n"
|
f"active_job: {active_job}\n"
|
||||||
f"queued_messages: {queue_size}",
|
f"queued_messages: {queue_size}",
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if text.startswith("/provider"):
|
||||||
|
catalog = get_provider_catalog(config)
|
||||||
|
parts = text.split(maxsplit=1)
|
||||||
|
if len(parts) == 1:
|
||||||
|
provider, _ = get_selected_provider_and_model(state, chat_id, catalog)
|
||||||
|
api.send_message(
|
||||||
|
chat_id,
|
||||||
|
format_provider_selection(catalog, provider),
|
||||||
|
reply_markup=format_provider_keyboard(catalog),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
provider_name = parts[1].strip()
|
||||||
|
provider_names = {str(item.get("name") or "") for item in catalog.get("providers", [])}
|
||||||
|
if provider_name not in provider_names:
|
||||||
|
api.send_message(chat_id, f"Неизвестный provider: {provider_name}")
|
||||||
|
return
|
||||||
|
set_chat_provider(state, chat_id, provider_name, reset_model=True)
|
||||||
|
_, model = get_selected_provider_and_model(state, chat_id, catalog)
|
||||||
|
api.send_message(chat_id, f"Выбран provider: {provider_name}\nmodel: {model or '-'}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if text.startswith("/model"):
|
||||||
|
catalog = get_provider_catalog(config)
|
||||||
|
provider, selected_model = get_selected_provider_and_model(state, chat_id, catalog)
|
||||||
|
provider_map = {str(item.get("name") or ""): item for item in catalog.get("providers", [])}
|
||||||
|
provider_info = provider_map.get(provider or "")
|
||||||
|
models = provider_info.get("models", []) if provider_info else []
|
||||||
|
parts = text.split(maxsplit=1)
|
||||||
|
if len(parts) == 1:
|
||||||
|
api.send_message(
|
||||||
|
chat_id,
|
||||||
|
format_model_selection(provider or "-", models, selected_model),
|
||||||
|
reply_markup=format_model_keyboard(provider or "-", models),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
model_id = parts[1].strip()
|
||||||
|
valid_model_ids = {str(item.get("id") or "") for item in models}
|
||||||
|
if model_id not in valid_model_ids:
|
||||||
|
api.send_message(chat_id, f"Неизвестная модель для {provider}: {model_id}")
|
||||||
|
return
|
||||||
|
set_chat_model(state, chat_id, model_id)
|
||||||
|
api.send_message(chat_id, f"Выбрана модель: {model_id}\nprovider: {provider}")
|
||||||
|
return
|
||||||
|
|
||||||
if text == "/cancel":
|
if text == "/cancel":
|
||||||
canceled = cancel_chat_work(
|
canceled = cancel_chat_work(
|
||||||
api,
|
api,
|
||||||
|
|
@ -750,6 +973,44 @@ def handle_callback_query(
|
||||||
|
|
||||||
action, approval_id = data.split(":", 1)
|
action, approval_id = data.split(":", 1)
|
||||||
if action not in {"approve", "reject"}:
|
if action not in {"approve", "reject"}:
|
||||||
|
if action == "provider":
|
||||||
|
catalog = get_provider_catalog(config)
|
||||||
|
provider_names = {str(item.get("name") or "") for item in catalog.get("providers", [])}
|
||||||
|
if approval_id not in provider_names:
|
||||||
|
api.answer_callback_query(callback_query_id, "Неизвестный провайдер")
|
||||||
|
return
|
||||||
|
set_chat_provider(state, chat_id, approval_id, reset_model=True)
|
||||||
|
provider, model = get_selected_provider_and_model(state, chat_id, catalog)
|
||||||
|
api.edit_message_text(
|
||||||
|
chat_id,
|
||||||
|
message_id,
|
||||||
|
f"Выбран provider: {provider}\nmodel: {model or '-'}",
|
||||||
|
reply_markup={"inline_keyboard": []},
|
||||||
|
)
|
||||||
|
api.answer_callback_query(callback_query_id, "Провайдер переключен")
|
||||||
|
return
|
||||||
|
if action == "model":
|
||||||
|
provider, model_id = approval_id.split(":", 1) if ":" in approval_id else ("", "")
|
||||||
|
catalog = get_provider_catalog(config)
|
||||||
|
provider_map = {str(item.get("name") or ""): item for item in catalog.get("providers", [])}
|
||||||
|
provider_info = provider_map.get(provider)
|
||||||
|
model_ids = {
|
||||||
|
str(item.get("id") or "")
|
||||||
|
for item in (provider_info.get("models", []) if provider_info else [])
|
||||||
|
}
|
||||||
|
if model_id not in model_ids:
|
||||||
|
api.answer_callback_query(callback_query_id, "Неизвестная модель")
|
||||||
|
return
|
||||||
|
set_chat_provider(state, chat_id, provider, reset_model=False)
|
||||||
|
set_chat_model(state, chat_id, model_id)
|
||||||
|
api.edit_message_text(
|
||||||
|
chat_id,
|
||||||
|
message_id,
|
||||||
|
f"Выбрана модель: {model_id}\nprovider: {provider}",
|
||||||
|
reply_markup={"inline_keyboard": []},
|
||||||
|
)
|
||||||
|
api.answer_callback_query(callback_query_id, "Модель переключена")
|
||||||
|
return
|
||||||
api.answer_callback_query(callback_query_id, "Неизвестное действие")
|
api.answer_callback_query(callback_query_id, "Неизвестное действие")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -782,6 +1043,8 @@ def main() -> None:
|
||||||
config = BotConfig.load()
|
config = BotConfig.load()
|
||||||
api = TelegramAPI(config.token, proxy_url=config.proxy_url)
|
api = TelegramAPI(config.token, proxy_url=config.proxy_url)
|
||||||
state = load_state()
|
state = load_state()
|
||||||
|
ensure_bot_commands(api, state)
|
||||||
|
save_state(state)
|
||||||
print("new-qwen bot polling started")
|
print("new-qwen bot polling started")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -119,3 +119,6 @@ class TelegramAPI:
|
||||||
if text:
|
if text:
|
||||||
payload["text"] = text
|
payload["text"] = text
|
||||||
self._post("answerCallbackQuery", payload)
|
self._post("answerCallbackQuery", payload)
|
||||||
|
|
||||||
|
def set_my_commands(self, commands: list[dict[str, str]]) -> None:
|
||||||
|
self._post("setMyCommands", {"commands": commands})
|
||||||
|
|
|
||||||
18
serv/app.py
18
serv/app.py
|
|
@ -139,6 +139,13 @@ class AppState:
|
||||||
"pending_approvals": len(self.approvals.list_pending()),
|
"pending_approvals": len(self.approvals.list_pending()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def provider_catalog(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"default_provider": self.config.default_provider,
|
||||||
|
"default_model": self.config.model,
|
||||||
|
"providers": self.providers.statuses(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler(BaseHTTPRequestHandler):
|
class RequestHandler(BaseHTTPRequestHandler):
|
||||||
server_version = "new-qwen-serv/0.1"
|
server_version = "new-qwen-serv/0.1"
|
||||||
|
|
@ -180,6 +187,9 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
models = self.app.oauth.get_available_models()
|
models = self.app.oauth.get_available_models()
|
||||||
self._send(HTTPStatus.OK, {"models": models})
|
self._send(HTTPStatus.OK, {"models": models})
|
||||||
return
|
return
|
||||||
|
if self.path == "/api/v1/provider-catalog":
|
||||||
|
self._send(HTTPStatus.OK, self.app.provider_catalog())
|
||||||
|
return
|
||||||
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
||||||
|
|
||||||
def _run_chat_job(
|
def _run_chat_job(
|
||||||
|
|
@ -189,6 +199,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
user_id: str,
|
user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
preferred_provider: str | None = None,
|
preferred_provider: str | None = None,
|
||||||
|
preferred_model: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
if self.app.jobs.is_cancel_requested(job_id):
|
if self.app.jobs.is_cancel_requested(job_id):
|
||||||
|
|
@ -217,6 +228,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
),
|
),
|
||||||
is_cancelled=lambda: self.app.jobs.is_cancel_requested(job_id),
|
is_cancelled=lambda: self.app.jobs.is_cancel_requested(job_id),
|
||||||
preferred_provider=preferred_provider,
|
preferred_provider=preferred_provider,
|
||||||
|
preferred_model=preferred_model,
|
||||||
)
|
)
|
||||||
if self.app.jobs.is_cancel_requested(job_id):
|
if self.app.jobs.is_cancel_requested(job_id):
|
||||||
reason = "Job canceled by operator"
|
reason = "Job canceled by operator"
|
||||||
|
|
@ -338,12 +350,14 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
user_id = str(body.get("user_id") or "anonymous")
|
user_id = str(body.get("user_id") or "anonymous")
|
||||||
message = body["message"]
|
message = body["message"]
|
||||||
preferred_provider = body.get("provider")
|
preferred_provider = body.get("provider")
|
||||||
|
preferred_model = body.get("model")
|
||||||
session = self.app.sessions.load(session_id)
|
session = self.app.sessions.load(session_id)
|
||||||
history = session.get("messages", [])
|
history = session.get("messages", [])
|
||||||
result = self.app.agent.run(
|
result = self.app.agent.run(
|
||||||
history,
|
history,
|
||||||
message,
|
message,
|
||||||
preferred_provider=preferred_provider,
|
preferred_provider=preferred_provider,
|
||||||
|
preferred_model=preferred_model,
|
||||||
)
|
)
|
||||||
persisted_messages = result["messages"][1:]
|
persisted_messages = result["messages"][1:]
|
||||||
self.app.sessions.save(
|
self.app.sessions.save(
|
||||||
|
|
@ -375,10 +389,11 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
user_id = str(body.get("user_id") or "anonymous")
|
user_id = str(body.get("user_id") or "anonymous")
|
||||||
message = body["message"]
|
message = body["message"]
|
||||||
preferred_provider = body.get("provider")
|
preferred_provider = body.get("provider")
|
||||||
|
preferred_model = body.get("model")
|
||||||
job = self.app.jobs.create(session_id, user_id, message)
|
job = self.app.jobs.create(session_id, user_id, message)
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=self._run_chat_job,
|
target=self._run_chat_job,
|
||||||
args=(job["job_id"], session_id, user_id, message, preferred_provider),
|
args=(job["job_id"], session_id, user_id, message, preferred_provider, preferred_model),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
@ -389,6 +404,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"status": "queued",
|
"status": "queued",
|
||||||
"provider": preferred_provider or self.app.config.default_provider,
|
"provider": preferred_provider or self.app.config.default_provider,
|
||||||
|
"model": preferred_model or self.app.config.model,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,14 @@ class GigaChatAuthManager:
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_expires_at(value: Any) -> int:
|
||||||
|
expires_at = int(value or 0)
|
||||||
|
# Sber can return unix time in milliseconds.
|
||||||
|
if expires_at > 10_000_000_000:
|
||||||
|
return expires_at // 1000
|
||||||
|
return expires_at
|
||||||
|
|
||||||
def fetch_token(self) -> dict[str, Any]:
|
def fetch_token(self) -> dict[str, Any]:
|
||||||
data = parse.urlencode({"scope": self.config.gigachat_scope}).encode("utf-8")
|
data = parse.urlencode({"scope": self.config.gigachat_scope}).encode("utf-8")
|
||||||
req = request.Request(
|
req = request.Request(
|
||||||
|
|
@ -76,7 +84,7 @@ class GigaChatAuthManager:
|
||||||
raise GigaChatError(f"GigaChat token request failed with HTTP {exc.code}: {body}") from exc
|
raise GigaChatError(f"GigaChat token request failed with HTTP {exc.code}: {body}") from exc
|
||||||
token = {
|
token = {
|
||||||
"access_token": payload["access_token"],
|
"access_token": payload["access_token"],
|
||||||
"expires_at": int(payload["expires_at"]),
|
"expires_at": self._normalize_expires_at(payload.get("expires_at")),
|
||||||
}
|
}
|
||||||
self.save_token(token)
|
self.save_token(token)
|
||||||
return token
|
return token
|
||||||
|
|
@ -86,7 +94,8 @@ class GigaChatAuthManager:
|
||||||
raise GigaChatError("GigaChat auth key is not configured")
|
raise GigaChatError("GigaChat auth key is not configured")
|
||||||
token = self.load_token()
|
token = self.load_token()
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
if token and int(token.get("expires_at", 0)) - now > 30:
|
expires_at = self._normalize_expires_at(token.get("expires_at", 0)) if token else 0
|
||||||
|
if token and expires_at - now > 30:
|
||||||
return str(token["access_token"])
|
return str(token["access_token"])
|
||||||
refreshed = self.fetch_token()
|
refreshed = self.fetch_token()
|
||||||
return str(refreshed["access_token"])
|
return str(refreshed["access_token"])
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ class QwenAgent:
|
||||||
approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
|
approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
|
||||||
is_cancelled: Callable[[], bool] | None = None,
|
is_cancelled: Callable[[], bool] | None = None,
|
||||||
preferred_provider: str | None = None,
|
preferred_provider: str | None = None,
|
||||||
|
preferred_model: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
emit = on_event or (lambda _event: None)
|
emit = on_event or (lambda _event: None)
|
||||||
cancel_check = is_cancelled or (lambda: False)
|
cancel_check = is_cancelled or (lambda: False)
|
||||||
|
|
@ -54,6 +55,7 @@ class QwenAgent:
|
||||||
tools=self.tools.schemas(),
|
tools=self.tools.schemas(),
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
preferred_provider=preferred_provider,
|
preferred_provider=preferred_provider,
|
||||||
|
preferred_model=preferred_model,
|
||||||
require_tools=True,
|
require_tools=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from dataclasses import replace
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import uuid
|
import uuid
|
||||||
|
|
@ -28,6 +29,7 @@ class CompletionRequest:
|
||||||
tools: list[dict[str, Any]]
|
tools: list[dict[str, Any]]
|
||||||
tool_choice: str = "auto"
|
tool_choice: str = "auto"
|
||||||
preferred_provider: str | None = None
|
preferred_provider: str | None = None
|
||||||
|
preferred_model: str | None = None
|
||||||
require_tools: bool = True
|
require_tools: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -53,9 +55,14 @@ class BaseModelProvider:
|
||||||
def unavailable_reason(self) -> str | None:
|
def unavailable_reason(self) -> str | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def model_name(self) -> str:
|
def model_name(self, preferred_model: str | None = None) -> str:
|
||||||
|
del preferred_model
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def list_models(self) -> list[dict[str, str]]:
|
||||||
|
model = self.model_name()
|
||||||
|
return [{"id": model, "name": model, "description": ""}]
|
||||||
|
|
||||||
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -83,7 +90,8 @@ class UnavailableModelProvider(BaseModelProvider):
|
||||||
def unavailable_reason(self) -> str | None:
|
def unavailable_reason(self) -> str | None:
|
||||||
return self._reason
|
return self._reason
|
||||||
|
|
||||||
def model_name(self) -> str:
|
def model_name(self, preferred_model: str | None = None) -> str:
|
||||||
|
del preferred_model
|
||||||
return self._model_name
|
return self._model_name
|
||||||
|
|
||||||
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
||||||
|
|
@ -114,8 +122,15 @@ class QwenModelProvider(BaseModelProvider):
|
||||||
return None
|
return None
|
||||||
return "Qwen OAuth is not configured"
|
return "Qwen OAuth is not configured"
|
||||||
|
|
||||||
def model_name(self) -> str:
|
def model_name(self, preferred_model: str | None = None) -> str:
|
||||||
return self._model_name
|
if preferred_model and preferred_model not in QWEN_OAUTH_ALLOWED_MODELS:
|
||||||
|
raise ModelProviderError(
|
||||||
|
f"Provider {self.name} does not support model '{preferred_model}'"
|
||||||
|
)
|
||||||
|
return self.oauth.get_model_name_for_id(preferred_model or self._model_id)
|
||||||
|
|
||||||
|
def list_models(self) -> list[dict[str, str]]:
|
||||||
|
return self.oauth.get_available_models()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_content(value: Any) -> list[dict[str, str]]:
|
def _normalize_content(value: Any) -> list[dict[str, str]]:
|
||||||
|
|
@ -140,8 +155,9 @@ class QwenModelProvider(BaseModelProvider):
|
||||||
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
def complete(self, completion_request: CompletionRequest) -> dict[str, Any]:
|
||||||
creds = self.oauth.get_valid_credentials()
|
creds = self.oauth.get_valid_credentials()
|
||||||
base_url = self.oauth.get_openai_base_url(creds)
|
base_url = self.oauth.get_openai_base_url(creds)
|
||||||
|
model_name = self.model_name(completion_request.preferred_model)
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model_name(),
|
"model": model_name,
|
||||||
"messages": self._normalize_messages(completion_request.messages),
|
"messages": self._normalize_messages(completion_request.messages),
|
||||||
"max_tokens": 8000,
|
"max_tokens": 8000,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
@ -203,8 +219,18 @@ class GigaChatModelProvider(BaseModelProvider):
|
||||||
return None
|
return None
|
||||||
return "GigaChat auth key is not configured"
|
return "GigaChat auth key is not configured"
|
||||||
|
|
||||||
def model_name(self) -> str:
|
def model_name(self, preferred_model: str | None = None) -> str:
|
||||||
return self.config.gigachat_model
|
return (preferred_model or self.config.gigachat_model).strip() or self.config.gigachat_model
|
||||||
|
|
||||||
|
def list_models(self) -> list[dict[str, str]]:
|
||||||
|
model = self.config.gigachat_model
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": model,
|
||||||
|
"name": model,
|
||||||
|
"description": "Configured GigaChat model",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_tool_schema(tool: dict[str, Any]) -> dict[str, Any]:
|
def _convert_tool_schema(tool: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
|
@ -294,7 +320,7 @@ class GigaChatModelProvider(BaseModelProvider):
|
||||||
raise ModelProviderError(str(exc)) from exc
|
raise ModelProviderError(str(exc)) from exc
|
||||||
api_base = self.config.gigachat_api_base_url.rstrip("/")
|
api_base = self.config.gigachat_api_base_url.rstrip("/")
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"model": self.model_name(),
|
"model": self.model_name(completion_request.preferred_model),
|
||||||
"messages": self._convert_messages(completion_request.messages),
|
"messages": self._convert_messages(completion_request.messages),
|
||||||
}
|
}
|
||||||
if completion_request.tools:
|
if completion_request.tools:
|
||||||
|
|
@ -340,6 +366,7 @@ class ProviderRegistry:
|
||||||
{
|
{
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"model": provider.model_name(),
|
"model": provider.model_name(),
|
||||||
|
"models": provider.list_models(),
|
||||||
"available": provider.is_available(),
|
"available": provider.is_available(),
|
||||||
"reason": provider.unavailable_reason(),
|
"reason": provider.unavailable_reason(),
|
||||||
"capabilities": {
|
"capabilities": {
|
||||||
|
|
@ -400,13 +427,16 @@ class ModelRouter:
|
||||||
elif name == self.config.default_provider:
|
elif name == self.config.default_provider:
|
||||||
selection_reason = "selected default provider"
|
selection_reason = "selected default provider"
|
||||||
try:
|
try:
|
||||||
payload = provider.complete(completion_request)
|
provider_request = completion_request
|
||||||
|
if completion_request.preferred_provider and name != completion_request.preferred_provider:
|
||||||
|
provider_request = replace(completion_request, preferred_model=None)
|
||||||
|
payload = provider.complete(provider_request)
|
||||||
except ModelProviderError as exc:
|
except ModelProviderError as exc:
|
||||||
reasons.append(f"{name}: request failed: {exc}")
|
reasons.append(f"{name}: request failed: {exc}")
|
||||||
continue
|
continue
|
||||||
return CompletionResponse(
|
return CompletionResponse(
|
||||||
provider_name=name,
|
provider_name=name,
|
||||||
model_name=provider.model_name(),
|
model_name=provider.model_name(provider_request.preferred_model),
|
||||||
payload=payload,
|
payload=payload,
|
||||||
selection_reason=selection_reason,
|
selection_reason=selection_reason,
|
||||||
attempted=attempted,
|
attempted=attempted,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue