258 lines
10 KiB
Python
258 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
AI Provider Manager - управление переключением между AI-провайдерами.
|
||
|
||
Поддерживаемые провайдеры:
|
||
- qwen: Qwen Code CLI (основной)
|
||
- gigachat: GigaChat API (Сбер)
|
||
"""
|
||
|
||
import logging
|
||
from typing import Optional, Dict, Any, Callable, List
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AIProvider(Enum):
|
||
"""Доступные AI-провайдеры."""
|
||
QWEN = "qwen"
|
||
GIGACHAT = "gigachat"
|
||
|
||
|
||
@dataclass
|
||
class ProviderInfo:
|
||
"""Информация о провайдере."""
|
||
id: str
|
||
name: str
|
||
description: str
|
||
available: bool
|
||
is_active: bool
|
||
|
||
|
||
class AIProviderManager:
|
||
"""
|
||
Менеджер управления AI-провайдерами.
|
||
|
||
Позволяет переключаться между провайдерами и выполнять запросы
|
||
через активного провайдера.
|
||
"""
|
||
|
||
def __init__(self, qwen_manager=None, gigachat_provider=None):
|
||
self._qwen_manager = qwen_manager
|
||
self._gigachat_provider = gigachat_provider
|
||
self._provider_status: Dict[str, bool] = {}
|
||
|
||
# Проверяем доступность провайдеров при инициализации
|
||
self._check_provider_status()
|
||
|
||
def _check_provider_status(self):
|
||
"""Проверка доступности провайдеров."""
|
||
# Проверяем Qwen
|
||
self._provider_status[AIProvider.QWEN.value] = True # Qwen всегда доступен
|
||
|
||
# Проверяем GigaChat
|
||
if self._gigachat_provider:
|
||
self._provider_status[AIProvider.GIGACHAT.value] = self._gigachat_provider.is_available()
|
||
else:
|
||
self._provider_status[AIProvider.GIGACHAT.value] = False
|
||
|
||
def get_available_providers(self) -> List[str]:
|
||
"""Получить список доступных провайдеров."""
|
||
return [
|
||
provider_id
|
||
for provider_id, available in self._provider_status.items()
|
||
if available
|
||
]
|
||
|
||
def is_provider_available(self, provider_id: str) -> bool:
|
||
"""Проверить доступен ли провайдер."""
|
||
return self._provider_status.get(provider_id, False)
|
||
|
||
def get_provider_info(self, provider_id: str, is_active: bool = False) -> ProviderInfo:
|
||
"""Получить информацию о провайдере."""
|
||
providers = {
|
||
AIProvider.QWEN.value: ProviderInfo(
|
||
id=AIProvider.QWEN.value,
|
||
name="Qwen Code",
|
||
description="Alibaba Qwen Code CLI — мощный AI-ассистент с поддержкой инструментов",
|
||
available=self.is_provider_available(AIProvider.QWEN.value),
|
||
is_active=is_active
|
||
),
|
||
AIProvider.GIGACHAT.value: ProviderInfo(
|
||
id=AIProvider.GIGACHAT.value,
|
||
name="GigaChat",
|
||
description="Sber GigaChat API — российская AI-модель от Сбера",
|
||
available=self.is_provider_available(AIProvider.GIGACHAT.value),
|
||
is_active=is_active
|
||
)
|
||
}
|
||
return providers.get(provider_id)
|
||
|
||
def get_all_providers_info(self, active_provider_id: str) -> List[ProviderInfo]:
|
||
"""Получить информацию обо всех провайдерах."""
|
||
return [
|
||
self.get_provider_info(AIProvider.QWEN.value, AIProvider.QWEN.value == active_provider_id),
|
||
self.get_provider_info(AIProvider.GIGACHAT.value, AIProvider.GIGACHAT.value == active_provider_id)
|
||
]
|
||
|
||
def switch_provider(self, user_id: int, provider_id: str, state_manager) -> tuple[bool, str]:
|
||
"""
|
||
Переключить AI-провайдер для пользователя.
|
||
|
||
Args:
|
||
user_id: ID пользователя
|
||
provider_id: ID провайдера ("qwen" или "gigachat")
|
||
state_manager: Менеджер состояний для обновления состояния пользователя
|
||
|
||
Returns:
|
||
(success: bool, message: str)
|
||
"""
|
||
if not self.is_provider_available(provider_id):
|
||
return False, f"❌ Провайдер {provider_id} недоступен"
|
||
|
||
state = state_manager.get(user_id)
|
||
state.current_ai_provider = provider_id
|
||
|
||
provider_info = self.get_provider_info(provider_id)
|
||
|
||
logger.info(f"Пользователь {user_id} переключен на {provider_id}")
|
||
|
||
return True, f"✅ Переключен на {provider_info.name}"
|
||
|
||
def get_current_provider(self, state) -> str:
|
||
"""Получить текущего провайдера пользователя."""
|
||
return state.current_ai_provider
|
||
|
||
async def execute_request(
|
||
self,
|
||
provider_id: str,
|
||
user_id: int,
|
||
prompt: str,
|
||
system_prompt: Optional[str] = None,
|
||
on_output: Optional[Callable[[str], Any]] = None,
|
||
on_chunk: Optional[Callable[[str], Any]] = None,
|
||
on_event: Optional[Callable[[Any], Any]] = None,
|
||
context: Optional[Dict] = None
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
Выполнить запрос через указанного провайдера.
|
||
|
||
Args:
|
||
provider_id: ID провайдера
|
||
user_id: ID пользователя
|
||
prompt: Запрос
|
||
system_prompt: Системный промпт
|
||
on_output: Callback для вывода
|
||
on_chunk: Callback для потокового вывода
|
||
on_event: Callback для событий
|
||
context: Дополнительный контекст
|
||
|
||
Returns:
|
||
Dict с результатом:
|
||
- success: bool
|
||
- content: str
|
||
- error: str (если ошибка)
|
||
- provider: str
|
||
- metadata: dict
|
||
"""
|
||
try:
|
||
if provider_id == AIProvider.QWEN.value:
|
||
if not self._qwen_manager:
|
||
return {
|
||
"success": False,
|
||
"error": "Qwen менеджер не инициализирован",
|
||
"provider": provider_id
|
||
}
|
||
|
||
# Выполняем через Qwen
|
||
result = await self._qwen_manager.run_task(
|
||
user_id=user_id,
|
||
task=prompt,
|
||
on_output=on_output or (lambda x: None),
|
||
on_oauth_url=lambda x: None,
|
||
use_system_prompt=False,
|
||
on_chunk=on_chunk,
|
||
on_event=on_event
|
||
)
|
||
|
||
# Извлекаем текст из результата
|
||
import re
|
||
text_matches = re.findall(r'"text":"([^"]+)"', result)
|
||
content = " ".join(text_matches).replace("\\n", "\n") if text_matches else result
|
||
|
||
return {
|
||
"success": True,
|
||
"content": content,
|
||
"provider": provider_id,
|
||
"metadata": {"raw_result": result}
|
||
}
|
||
|
||
elif provider_id == AIProvider.GIGACHAT.value:
|
||
if not self._gigachat_provider:
|
||
return {
|
||
"success": False,
|
||
"error": "GigaChat провайдер не инициализирован",
|
||
"provider": provider_id
|
||
}
|
||
|
||
# Выполняем через GigaChat
|
||
result = await self._gigachat_provider.chat(
|
||
prompt=prompt,
|
||
system_prompt=system_prompt,
|
||
on_chunk=on_chunk
|
||
)
|
||
|
||
if result.get("success"):
|
||
return {
|
||
"success": True,
|
||
"content": result.get("content", ""),
|
||
"provider": provider_id,
|
||
"metadata": {
|
||
"model": result.get("model", "GigaChat-Pro"),
|
||
"usage": result.get("usage", {})
|
||
}
|
||
}
|
||
else:
|
||
return {
|
||
"success": False,
|
||
"error": result.get("error", "Неизвестная ошибка GigaChat"),
|
||
"provider": provider_id
|
||
}
|
||
|
||
else:
|
||
return {
|
||
"success": False,
|
||
"error": f"Неизвестный провайдер: {provider_id}",
|
||
"provider": provider_id
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"Ошибка выполнения запроса через {provider_id}: {e}")
|
||
return {
|
||
"success": False,
|
||
"error": str(e),
|
||
"provider": provider_id
|
||
}
|
||
|
||
|
||
# Глобальный менеджер (будет инициализирован в bot.py)
|
||
ai_provider_manager: Optional[AIProviderManager] = None
|
||
|
||
|
||
def init_ai_provider_manager(qwen_manager, gigachat_provider) -> AIProviderManager:
|
||
"""Инициализировать глобальный AIProviderManager."""
|
||
global ai_provider_manager
|
||
ai_provider_manager = AIProviderManager(qwen_manager, gigachat_provider)
|
||
logger.info(f"AIProviderManager инициализирован. Доступные провайдеры: {ai_provider_manager.get_available_providers()}")
|
||
return ai_provider_manager
|
||
|
||
|
||
def get_ai_provider_manager() -> AIProviderManager:
|
||
"""Получить глобальный AIProviderManager."""
|
||
global ai_provider_manager
|
||
if ai_provider_manager is None:
|
||
raise RuntimeError("AIProviderManager не инициализирован. Вызовите init_ai_provider_manager().")
|
||
return ai_provider_manager
|