telegram-cli-bot/bot/base_ai_provider.py

301 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Base AI Provider Protocol - универсальный интерфейс для всех AI-провайдеров.
Определяет общий протокол который должен реализовать каждый AI-провайдер
для работы с инструментами (tools).
"""
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Callable, List, AsyncGenerator
from dataclasses import dataclass, field
from enum import Enum
class ToolCallStatus(Enum):
"""Статус выполнения инструмента."""
SUCCESS = "success"
ERROR = "error"
PENDING = "pending"
@dataclass
class ToolCall:
"""Вызов инструмента."""
tool_name: str
tool_args: Dict[str, Any]
tool_call_id: Optional[str] = None
status: ToolCallStatus = ToolCallStatus.PENDING
result: Optional[Any] = None
error: Optional[str] = None
@dataclass
class AIMessage:
"""Сообщение от AI-провайдера."""
content: str
tool_calls: List[ToolCall] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
is_streaming: bool = False
@dataclass
class ProviderResponse:
"""Ответ от AI-провайдера."""
success: bool
message: Optional[AIMessage] = None
error: Optional[str] = None
provider_name: str = ""
usage: Optional[Dict[str, Any]] = None
raw_response: Optional[Any] = None
class BaseAIProvider(ABC):
"""
Базовый класс для всех AI-провайдеров.
Каждый провайдер (Qwen, GigaChat, OpenAI, etc.) должен реализовать
этот интерфейс для поддержки инструментов и единого формата ответов.
"""
@property
@abstractmethod
def provider_name(self) -> str:
"""Название провайдера (например, 'Qwen Code', 'GigaChat')."""
pass
@property
@abstractmethod
def supports_tools(self) -> bool:
"""Поддерживает ли провайдер инструменты нативно."""
pass
@property
@abstractmethod
def supports_streaming(self) -> bool:
"""Поддерживает ли провайдер потоковый вывод."""
pass
@abstractmethod
async def chat(
self,
prompt: str,
system_prompt: Optional[str] = None,
context: Optional[List[Dict[str, str]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
on_chunk: Optional[Callable[[str], Any]] = None,
**kwargs
) -> ProviderResponse:
"""
Отправить запрос AI-провайдеру.
Args:
prompt: Запрос пользователя
system_prompt: Системный промпт
context: История диалога
tools: Доступные инструменты (схема)
on_chunk: Callback для потокового вывода
**kwargs: Дополнительные параметры
Returns:
ProviderResponse с ответом и возможными вызовами инструментов
"""
pass
@abstractmethod
async def execute_tool(
self,
tool_name: str,
tool_args: Dict[str, Any],
tool_call_id: Optional[str] = None,
**kwargs
) -> ToolCall:
"""
Выполнить инструмент (если провайдер поддерживает нативно).
Для провайдеров без нативной поддержки инструментов,
этот метод может быть заглушкой.
Args:
tool_name: Имя инструмента
tool_args: Аргументы инструмента
tool_call_id: ID вызова
Returns:
ToolCall с результатом выполнения
"""
pass
def is_available(self) -> bool:
"""Проверить доступность провайдера."""
return True
def get_tools_schema(self, tools_registry: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Получить схему инструментов для промпта.
По умолчанию возвращает описание всех доступных инструментов.
Провайдеры могут переопределить для кастомизации.
Args:
tools_registry: Словарь инструментов {name: tool_instance}
Returns:
Список схем инструментов
"""
schema = []
for name, tool in tools_registry.items():
if hasattr(tool, 'get_schema'):
schema.append(tool.get_schema())
elif hasattr(tool, 'description'):
schema.append({
"name": name,
"description": tool.description,
"parameters": getattr(tool, 'parameters', {})
})
return schema
async def process_with_tools(
self,
prompt: str,
system_prompt: Optional[str] = None,
context: Optional[List[Dict[str, str]]] = None,
tools_registry: Optional[Dict[str, Any]] = None,
on_chunk: Optional[Callable[[str], Any]] = None,
max_iterations: int = 5,
**kwargs
) -> ProviderResponse:
"""
Универсальный метод для обработки запросов с инструментами.
Реализует цикл:
1. Отправить запрос провайдеру
2. Если есть вызовы инструментов - выполнить их
3. Отправить результаты обратно провайдеру
4. Повторить пока не будет финального ответа
Args:
prompt: Запрос пользователя
system_prompt: Системный промпт
context: История диалога
tools_registry: Словарь инструментов
on_chunk: Callback для потокового вывода
max_iterations: Максимум итераций цикла
Returns:
ProviderResponse с финальным ответом
"""
if not tools_registry:
# Без инструментов - простой запрос
return await self.chat(
prompt=prompt,
system_prompt=system_prompt,
context=context,
on_chunk=on_chunk,
**kwargs
)
messages = []
if context:
messages.extend(context)
messages.append({"role": "user", "content": prompt})
tools_schema = self.get_tools_schema(tools_registry) if self.supports_tools else None
for iteration in range(max_iterations):
# Отправляем запрос провайдеру
response = await self.chat(
prompt=None, # Уже в messages
system_prompt=system_prompt,
context=messages if iteration == 0 else None,
tools=tools_schema,
on_chunk=on_chunk,
**kwargs
)
if not response.success:
return response
message = response.message
if not message:
return ProviderResponse(
success=False,
error="Пустой ответ от провайдера",
provider_name=self.provider_name
)
# Если нет вызовов инструментов - возвращаем ответ
if not message.tool_calls:
return response
# Выполняем инструменты
tool_results = []
for tool_call in message.tool_calls:
if tool_call.tool_name in tools_registry:
tool = tools_registry[tool_call.tool_name]
try:
if hasattr(tool, 'execute'):
result = await tool.execute(
**tool_call.tool_args,
user_id=kwargs.get('user_id')
)
elif hasattr(tool, '__call__'):
result = await tool(**tool_call.tool_args)
else:
result = f"Инструмент {tool_call.tool_name} не имеет метода execute"
tool_call.result = result
tool_call.status = ToolCallStatus.SUCCESS
except Exception as e:
tool_call.error = str(e)
tool_call.status = ToolCallStatus.ERROR
result = f"Ошибка: {e}"
tool_results.append({
"tool": tool_call.tool_name,
"args": tool_call.tool_args,
"result": result,
"status": tool_call.status.value
})
else:
tool_call.error = f"Инструмент {tool_call.tool_name} не найден"
tool_call.status = ToolCallStatus.ERROR
tool_results.append({
"tool": tool_call.tool_name,
"error": tool_call.error
})
# Добавляем результаты в контекст для следующей итерации
messages.append({
"role": "assistant",
"content": message.content,
"tool_calls": [
{
"id": tc.tool_call_id,
"name": tc.tool_name,
"arguments": tc.tool_args
}
for tc in message.tool_calls
]
})
messages.append({
"role": "tool",
"content": str(tool_results)
})
# Обновляем системный промпт для следующей итерации
system_prompt = system_prompt or ""
# Достигли максимума итераций
return ProviderResponse(
success=True,
message=AIMessage(
content=message.content + "\n\n[Достигнут максимум итераций выполнения инструментов]",
metadata={"iterations": max_iterations}
),
provider_name=self.provider_name,
usage=response.usage
)