telegram-cli-bot/bot/base_ai_provider.py

348 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Base AI Provider Protocol - универсальный интерфейс для всех AI-провайдеров.
Определяет общий протокол который должен реализовать каждый AI-провайдер
для работы с инструментами (tools).
"""
import json
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Callable, List, AsyncGenerator
from dataclasses import dataclass, field
from enum import Enum
class ToolCallStatus(Enum):
"""Статус выполнения инструмента."""
SUCCESS = "success"
ERROR = "error"
PENDING = "pending"
@dataclass
class ToolCall:
"""Вызов инструмента."""
tool_name: str
tool_args: Dict[str, Any]
tool_call_id: Optional[str] = None
status: ToolCallStatus = ToolCallStatus.PENDING
result: Optional[Any] = None
error: Optional[str] = None
@dataclass
class AIMessage:
"""Сообщение от AI-провайдера."""
content: str
tool_calls: List[ToolCall] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
is_streaming: bool = False
@dataclass
class ProviderResponse:
"""Ответ от AI-провайдера."""
success: bool
message: Optional[AIMessage] = None
error: Optional[str] = None
provider_name: str = ""
usage: Optional[Dict[str, Any]] = None
raw_response: Optional[Any] = None
class BaseAIProvider(ABC):
"""
Базовый класс для всех AI-провайдеров.
Каждый провайдер (Qwen, GigaChat, OpenAI, etc.) должен реализовать
этот интерфейс для поддержки инструментов и единого формата ответов.
"""
@property
@abstractmethod
def provider_name(self) -> str:
"""Название провайдера (например, 'Qwen Code', 'GigaChat')."""
pass
@property
@abstractmethod
def supports_tools(self) -> bool:
"""Поддерживает ли провайдер инструменты нативно."""
pass
@property
@abstractmethod
def supports_streaming(self) -> bool:
"""Поддерживает ли провайдер потоковый вывод."""
pass
@abstractmethod
async def chat(
self,
prompt: str,
system_prompt: Optional[str] = None,
context: Optional[List[Dict[str, str]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
on_chunk: Optional[Callable[[str], Any]] = None,
**kwargs
) -> ProviderResponse:
"""
Отправить запрос AI-провайдеру.
Args:
prompt: Запрос пользователя
system_prompt: Системный промпт
context: История диалога
tools: Доступные инструменты (схема)
on_chunk: Callback для потокового вывода
**kwargs: Дополнительные параметры
Returns:
ProviderResponse с ответом и возможными вызовами инструментов
"""
pass
@abstractmethod
async def execute_tool(
self,
tool_name: str,
tool_args: Dict[str, Any],
tool_call_id: Optional[str] = None,
**kwargs
) -> ToolCall:
"""
Выполнить инструмент (если провайдер поддерживает нативно).
Для провайдеров без нативной поддержки инструментов,
этот метод может быть заглушкой.
Args:
tool_name: Имя инструмента
tool_args: Аргументы инструмента
tool_call_id: ID вызова
Returns:
ToolCall с результатом выполнения
"""
pass
def is_available(self) -> bool:
"""Проверить доступность провайдера."""
return True
def get_tools_schema(self, tools_registry: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Получить схему инструментов для промпта.
По умолчанию возвращает описание всех доступных инструментов.
Провайдеры могут переопределить для кастомизации.
Args:
tools_registry: Словарь инструментов {name: tool_instance} или объект реестра
Returns:
Список схем инструментов
"""
schema = []
# Обрабатываем разные типы tools_registry
if tools_registry is None:
return schema
# Если это ToolsRegistry с методом get_all()
if hasattr(tools_registry, 'get_all') and callable(getattr(tools_registry, 'get_all')):
items = tools_registry.get_all().items()
# Если это dict - используем .items()
elif isinstance(tools_registry, dict):
items = tools_registry.items()
# Если это объект с атрибутом tools
elif hasattr(tools_registry, 'tools'):
items = tools_registry.tools.items() if isinstance(tools_registry.tools, dict) else []
# Если это объект поддерживающий .items()
elif hasattr(tools_registry, 'items'):
items = tools_registry.items()
else:
logger.warning(f"Неизвестный тип tools_registry: {type(tools_registry)}")
return schema
for name, tool in items:
if hasattr(tool, 'get_schema'):
schema.append(tool.get_schema())
elif hasattr(tool, 'description'):
schema.append({
"name": name,
"description": tool.description,
"parameters": getattr(tool, 'parameters', {})
})
return schema
async def process_with_tools(
self,
prompt: str,
system_prompt: Optional[str] = None,
context: Optional[List[Dict[str, str]]] = None,
tools_registry: Optional[Dict[str, Any]] = None,
on_chunk: Optional[Callable[[str], Any]] = None,
max_iterations: int = 5,
**kwargs
) -> ProviderResponse:
"""
Универсальный метод для обработки запросов с инструментами.
Реализует цикл:
1. Отправить запрос провайдеру
2. Если есть вызовы инструментов - выполнить их
3. Отправить результаты обратно провайдеру
4. Повторить пока не будет финального ответа
Args:
prompt: Запрос пользователя
system_prompt: Системный промпт
context: История диалога
tools_registry: Словарь инструментов
on_chunk: Callback для потокового вывода
max_iterations: Максимум итераций цикла
Returns:
ProviderResponse с финальным ответом
"""
if not tools_registry:
# Без инструментов - простой запрос
return await self.chat(
prompt=prompt,
system_prompt=system_prompt,
context=context,
on_chunk=on_chunk,
**kwargs
)
# Формируем базовый контекст — БЕЗ system message
# System message будет передаваться отдельным параметром
base_messages = []
if context:
# Фильтруем system messages из context — они будут переданы через system_prompt
for msg in context:
if msg.get("role") != "system":
base_messages.append(msg)
base_messages.append({"role": "user", "content": prompt})
tools_schema = self.get_tools_schema(tools_registry) if self.supports_tools else None
# Копируем сообщения для каждой итерации
messages = base_messages.copy()
for iteration in range(max_iterations):
# Отправляем запрос провайдеру
# system_prompt передаётся всегда — провайдер сам решит как его использовать
response = await self.chat(
prompt=None, # Уже в messages
system_prompt=system_prompt,
context=messages,
tools=tools_schema,
on_chunk=on_chunk,
**kwargs
)
if not response.success:
return response
message = response.message
if not message:
return ProviderResponse(
success=False,
error="Пустой ответ от провайдера",
provider_name=self.provider_name
)
# Если нет вызовов инструментов - возвращаем ответ
if not message.tool_calls:
return response
# Выполняем инструменты
tool_results = []
for tool_call in message.tool_calls:
# Проверяем наличие инструмента через метод .get() для поддержки ToolsRegistry
if hasattr(tools_registry, 'get'):
tool = tools_registry.get(tool_call.tool_name)
elif isinstance(tools_registry, dict):
tool = tools_registry.get(tool_call.tool_name)
else:
tool = None
if tool is not None:
try:
if hasattr(tool, 'execute'):
result = await tool.execute(
**tool_call.tool_args,
user_id=kwargs.get('user_id')
)
elif hasattr(tool, '__call__'):
result = await tool(**tool_call.tool_args)
else:
result = f"Инструмент {tool_call.tool_name} не имеет метода execute"
tool_call.result = result
tool_call.status = ToolCallStatus.SUCCESS
except Exception as e:
tool_call.error = str(e)
tool_call.status = ToolCallStatus.ERROR
result = f"Ошибка: {e}"
# Преобразуем результат в JSON-сериализуемый формат
# ToolResult имеет метод to_dict(), строки оставляем как есть
if hasattr(result, 'to_dict'):
result_serializable = result.to_dict()
else:
result_serializable = result
tool_results.append({
"tool": tool_call.tool_name,
"args": tool_call.tool_args,
"result": result_serializable,
"status": tool_call.status.value
})
else:
tool_call.error = f"Инструмент {tool_call.tool_name} не найден"
tool_call.status = ToolCallStatus.ERROR
tool_results.append({
"tool": tool_call.tool_name,
"error": tool_call.error
})
# Добавляем результаты в контекст для следующей итерации
messages.append({
"role": "assistant",
"content": message.content,
"tool_calls": [
{
"id": tc.tool_call_id,
"name": tc.tool_name,
"arguments": tc.tool_args
}
for tc in message.tool_calls
]
})
# GigaChat требует валидный JSON в tool messages, а не Python repr строку
# Используем json.dumps для корректного форматирования
messages.append({
"role": "tool",
"content": json.dumps(tool_results, ensure_ascii=False)
})
# Обновляем системный промпт для следующей итерации
system_prompt = system_prompt or ""
# Достигли максимума итераций
return ProviderResponse(
success=True,
message=AIMessage(
content=message.content + "\n\n[Достигнут максимум итераций выполнения инструментов]",
metadata={"iterations": max_iterations}
),
provider_name=self.provider_name,
usage=response.usage
)