348 lines
14 KiB
Python
348 lines
14 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Base AI Provider Protocol - универсальный интерфейс для всех AI-провайдеров.
|
||
|
||
Определяет общий протокол который должен реализовать каждый AI-провайдер
|
||
для работы с инструментами (tools).
|
||
"""
|
||
|
||
import json
|
||
from abc import ABC, abstractmethod
|
||
from typing import Optional, Dict, Any, Callable, List, AsyncGenerator
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
|
||
|
||
class ToolCallStatus(Enum):
|
||
"""Статус выполнения инструмента."""
|
||
SUCCESS = "success"
|
||
ERROR = "error"
|
||
PENDING = "pending"
|
||
|
||
|
||
@dataclass
|
||
class ToolCall:
|
||
"""Вызов инструмента."""
|
||
tool_name: str
|
||
tool_args: Dict[str, Any]
|
||
tool_call_id: Optional[str] = None
|
||
status: ToolCallStatus = ToolCallStatus.PENDING
|
||
result: Optional[Any] = None
|
||
error: Optional[str] = None
|
||
|
||
|
||
@dataclass
|
||
class AIMessage:
|
||
"""Сообщение от AI-провайдера."""
|
||
content: str
|
||
tool_calls: List[ToolCall] = field(default_factory=list)
|
||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||
is_streaming: bool = False
|
||
|
||
|
||
@dataclass
|
||
class ProviderResponse:
|
||
"""Ответ от AI-провайдера."""
|
||
success: bool
|
||
message: Optional[AIMessage] = None
|
||
error: Optional[str] = None
|
||
provider_name: str = ""
|
||
usage: Optional[Dict[str, Any]] = None
|
||
raw_response: Optional[Any] = None
|
||
|
||
|
||
class BaseAIProvider(ABC):
|
||
"""
|
||
Базовый класс для всех AI-провайдеров.
|
||
|
||
Каждый провайдер (Qwen, GigaChat, OpenAI, etc.) должен реализовать
|
||
этот интерфейс для поддержки инструментов и единого формата ответов.
|
||
"""
|
||
|
||
@property
|
||
@abstractmethod
|
||
def provider_name(self) -> str:
|
||
"""Название провайдера (например, 'Qwen Code', 'GigaChat')."""
|
||
pass
|
||
|
||
@property
|
||
@abstractmethod
|
||
def supports_tools(self) -> bool:
|
||
"""Поддерживает ли провайдер инструменты нативно."""
|
||
pass
|
||
|
||
@property
|
||
@abstractmethod
|
||
def supports_streaming(self) -> bool:
|
||
"""Поддерживает ли провайдер потоковый вывод."""
|
||
pass
|
||
|
||
@abstractmethod
|
||
async def chat(
|
||
self,
|
||
prompt: str,
|
||
system_prompt: Optional[str] = None,
|
||
context: Optional[List[Dict[str, str]]] = None,
|
||
tools: Optional[List[Dict[str, Any]]] = None,
|
||
on_chunk: Optional[Callable[[str], Any]] = None,
|
||
**kwargs
|
||
) -> ProviderResponse:
|
||
"""
|
||
Отправить запрос AI-провайдеру.
|
||
|
||
Args:
|
||
prompt: Запрос пользователя
|
||
system_prompt: Системный промпт
|
||
context: История диалога
|
||
tools: Доступные инструменты (схема)
|
||
on_chunk: Callback для потокового вывода
|
||
**kwargs: Дополнительные параметры
|
||
|
||
Returns:
|
||
ProviderResponse с ответом и возможными вызовами инструментов
|
||
"""
|
||
pass
|
||
|
||
@abstractmethod
|
||
async def execute_tool(
|
||
self,
|
||
tool_name: str,
|
||
tool_args: Dict[str, Any],
|
||
tool_call_id: Optional[str] = None,
|
||
**kwargs
|
||
) -> ToolCall:
|
||
"""
|
||
Выполнить инструмент (если провайдер поддерживает нативно).
|
||
|
||
Для провайдеров без нативной поддержки инструментов,
|
||
этот метод может быть заглушкой.
|
||
|
||
Args:
|
||
tool_name: Имя инструмента
|
||
tool_args: Аргументы инструмента
|
||
tool_call_id: ID вызова
|
||
|
||
Returns:
|
||
ToolCall с результатом выполнения
|
||
"""
|
||
pass
|
||
|
||
def is_available(self) -> bool:
|
||
"""Проверить доступность провайдера."""
|
||
return True
|
||
|
||
def get_tools_schema(self, tools_registry: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||
"""
|
||
Получить схему инструментов для промпта.
|
||
|
||
По умолчанию возвращает описание всех доступных инструментов.
|
||
Провайдеры могут переопределить для кастомизации.
|
||
|
||
Args:
|
||
tools_registry: Словарь инструментов {name: tool_instance} или объект реестра
|
||
|
||
Returns:
|
||
Список схем инструментов
|
||
"""
|
||
schema = []
|
||
|
||
# Обрабатываем разные типы tools_registry
|
||
if tools_registry is None:
|
||
return schema
|
||
|
||
# Если это ToolsRegistry с методом get_all()
|
||
if hasattr(tools_registry, 'get_all') and callable(getattr(tools_registry, 'get_all')):
|
||
items = tools_registry.get_all().items()
|
||
# Если это dict - используем .items()
|
||
elif isinstance(tools_registry, dict):
|
||
items = tools_registry.items()
|
||
# Если это объект с атрибутом tools
|
||
elif hasattr(tools_registry, 'tools'):
|
||
items = tools_registry.tools.items() if isinstance(tools_registry.tools, dict) else []
|
||
# Если это объект поддерживающий .items()
|
||
elif hasattr(tools_registry, 'items'):
|
||
items = tools_registry.items()
|
||
else:
|
||
logger.warning(f"Неизвестный тип tools_registry: {type(tools_registry)}")
|
||
return schema
|
||
|
||
for name, tool in items:
|
||
if hasattr(tool, 'get_schema'):
|
||
schema.append(tool.get_schema())
|
||
elif hasattr(tool, 'description'):
|
||
schema.append({
|
||
"name": name,
|
||
"description": tool.description,
|
||
"parameters": getattr(tool, 'parameters', {})
|
||
})
|
||
return schema
|
||
|
||
async def process_with_tools(
|
||
self,
|
||
prompt: str,
|
||
system_prompt: Optional[str] = None,
|
||
context: Optional[List[Dict[str, str]]] = None,
|
||
tools_registry: Optional[Dict[str, Any]] = None,
|
||
on_chunk: Optional[Callable[[str], Any]] = None,
|
||
max_iterations: int = 5,
|
||
**kwargs
|
||
) -> ProviderResponse:
|
||
"""
|
||
Универсальный метод для обработки запросов с инструментами.
|
||
|
||
Реализует цикл:
|
||
1. Отправить запрос провайдеру
|
||
2. Если есть вызовы инструментов - выполнить их
|
||
3. Отправить результаты обратно провайдеру
|
||
4. Повторить пока не будет финального ответа
|
||
|
||
Args:
|
||
prompt: Запрос пользователя
|
||
system_prompt: Системный промпт
|
||
context: История диалога
|
||
tools_registry: Словарь инструментов
|
||
on_chunk: Callback для потокового вывода
|
||
max_iterations: Максимум итераций цикла
|
||
|
||
Returns:
|
||
ProviderResponse с финальным ответом
|
||
"""
|
||
if not tools_registry:
|
||
# Без инструментов - простой запрос
|
||
return await self.chat(
|
||
prompt=prompt,
|
||
system_prompt=system_prompt,
|
||
context=context,
|
||
on_chunk=on_chunk,
|
||
**kwargs
|
||
)
|
||
|
||
# Формируем базовый контекст — БЕЗ system message
|
||
# System message будет передаваться отдельным параметром
|
||
base_messages = []
|
||
if context:
|
||
# Фильтруем system messages из context — они будут переданы через system_prompt
|
||
for msg in context:
|
||
if msg.get("role") != "system":
|
||
base_messages.append(msg)
|
||
|
||
base_messages.append({"role": "user", "content": prompt})
|
||
|
||
tools_schema = self.get_tools_schema(tools_registry) if self.supports_tools else None
|
||
|
||
# Копируем сообщения для каждой итерации
|
||
messages = base_messages.copy()
|
||
|
||
for iteration in range(max_iterations):
|
||
# Отправляем запрос провайдеру
|
||
# system_prompt передаётся всегда — провайдер сам решит как его использовать
|
||
response = await self.chat(
|
||
prompt=None, # Уже в messages
|
||
system_prompt=system_prompt,
|
||
context=messages,
|
||
tools=tools_schema,
|
||
on_chunk=on_chunk,
|
||
**kwargs
|
||
)
|
||
|
||
if not response.success:
|
||
return response
|
||
|
||
message = response.message
|
||
if not message:
|
||
return ProviderResponse(
|
||
success=False,
|
||
error="Пустой ответ от провайдера",
|
||
provider_name=self.provider_name
|
||
)
|
||
|
||
# Если нет вызовов инструментов - возвращаем ответ
|
||
if not message.tool_calls:
|
||
return response
|
||
|
||
# Выполняем инструменты
|
||
tool_results = []
|
||
for tool_call in message.tool_calls:
|
||
# Проверяем наличие инструмента через метод .get() для поддержки ToolsRegistry
|
||
if hasattr(tools_registry, 'get'):
|
||
tool = tools_registry.get(tool_call.tool_name)
|
||
elif isinstance(tools_registry, dict):
|
||
tool = tools_registry.get(tool_call.tool_name)
|
||
else:
|
||
tool = None
|
||
|
||
if tool is not None:
|
||
try:
|
||
if hasattr(tool, 'execute'):
|
||
result = await tool.execute(
|
||
**tool_call.tool_args,
|
||
user_id=kwargs.get('user_id')
|
||
)
|
||
elif hasattr(tool, '__call__'):
|
||
result = await tool(**tool_call.tool_args)
|
||
else:
|
||
result = f"Инструмент {tool_call.tool_name} не имеет метода execute"
|
||
|
||
tool_call.result = result
|
||
tool_call.status = ToolCallStatus.SUCCESS
|
||
except Exception as e:
|
||
tool_call.error = str(e)
|
||
tool_call.status = ToolCallStatus.ERROR
|
||
result = f"Ошибка: {e}"
|
||
|
||
# Преобразуем результат в JSON-сериализуемый формат
|
||
# ToolResult имеет метод to_dict(), строки оставляем как есть
|
||
if hasattr(result, 'to_dict'):
|
||
result_serializable = result.to_dict()
|
||
else:
|
||
result_serializable = result
|
||
|
||
tool_results.append({
|
||
"tool": tool_call.tool_name,
|
||
"args": tool_call.tool_args,
|
||
"result": result_serializable,
|
||
"status": tool_call.status.value
|
||
})
|
||
else:
|
||
tool_call.error = f"Инструмент {tool_call.tool_name} не найден"
|
||
tool_call.status = ToolCallStatus.ERROR
|
||
tool_results.append({
|
||
"tool": tool_call.tool_name,
|
||
"error": tool_call.error
|
||
})
|
||
|
||
# Добавляем результаты в контекст для следующей итерации
|
||
messages.append({
|
||
"role": "assistant",
|
||
"content": message.content,
|
||
"tool_calls": [
|
||
{
|
||
"id": tc.tool_call_id,
|
||
"name": tc.tool_name,
|
||
"arguments": tc.tool_args
|
||
}
|
||
for tc in message.tool_calls
|
||
]
|
||
})
|
||
|
||
# GigaChat требует валидный JSON в tool messages, а не Python repr строку
|
||
# Используем json.dumps для корректного форматирования
|
||
messages.append({
|
||
"role": "tool",
|
||
"content": json.dumps(tool_results, ensure_ascii=False)
|
||
})
|
||
|
||
# Обновляем системный промпт для следующей итерации
|
||
system_prompt = system_prompt or ""
|
||
|
||
# Достигли максимума итераций
|
||
return ProviderResponse(
|
||
success=True,
|
||
message=AIMessage(
|
||
content=message.content + "\n\n[Достигнут максимум итераций выполнения инструментов]",
|
||
metadata={"iterations": max_iterations}
|
||
),
|
||
provider_name=self.provider_name,
|
||
usage=response.usage
|
||
)
|