520 lines
20 KiB
Python
520 lines
20 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
GigaChat AI Provider - адаптер GigaChat для работы с инструментами.
|
||
|
||
Реализует интерфейс BaseAIProvider для единой работы с инструментами
|
||
независимо от AI-провайдера.
|
||
|
||
Использует нативный GigaChat Function Calling API:
|
||
https://developers.sber.ru/docs/ru/gigachat/guides/functions/overview
|
||
"""
|
||
|
||
import logging
|
||
from typing import Optional, Dict, Any, Callable, List
|
||
import json
|
||
|
||
from bot.base_ai_provider import (
|
||
BaseAIProvider,
|
||
ProviderResponse,
|
||
AIMessage,
|
||
ToolCall,
|
||
ToolCallStatus,
|
||
)
|
||
from bot.tools.gigachat_tool import GigaChatTool, GigaChatMessage, GigaChatConfig
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class GigaChatProvider(BaseAIProvider):
|
||
"""
|
||
GigaChat AI Provider с нативной поддержкой function calling.
|
||
|
||
Использует официальный GigaChat Function Calling API вместо
|
||
эмуляции через текстовые блоки.
|
||
"""
|
||
|
||
def __init__(self, config: Optional[GigaChatConfig] = None):
|
||
self._tool = GigaChatTool(config)
|
||
self._available: Optional[bool] = None
|
||
self._functions_state_id: Optional[str] = None
|
||
|
||
@property
|
||
def provider_name(self) -> str:
|
||
return "GigaChat"
|
||
|
||
@property
|
||
def supports_tools(self) -> bool:
|
||
# GigaChat поддерживает нативные function calls
|
||
return True
|
||
|
||
@property
|
||
def supports_streaming(self) -> bool:
|
||
return False
|
||
|
||
def is_available(self) -> bool:
|
||
"""Проверить доступность GigaChat."""
|
||
if self._available is not None:
|
||
return self._available
|
||
|
||
try:
|
||
import os
|
||
client_id = os.getenv("GIGACHAT_CLIENT_ID")
|
||
client_secret = os.getenv("GIGACHAT_CLIENT_SECRET")
|
||
|
||
self._available = bool(client_id and client_secret)
|
||
|
||
if not self._available:
|
||
logger.warning("GigaChat недоступен: не настроены GIGACHAT_CLIENT_ID или GIGACHAT_CLIENT_SECRET")
|
||
else:
|
||
logger.info("GigaChat доступен")
|
||
except Exception as e:
|
||
self._available = False
|
||
logger.error(f"Ошибка проверки доступности GigaChat: {e}")
|
||
|
||
return self._available
|
||
|
||
def get_error(self) -> Optional[str]:
|
||
"""Получить последнюю ошибку."""
|
||
if self._available is False:
|
||
return "GigaChat недоступен: проверьте GIGACHAT_CLIENT_ID и GIGACHAT_CLIENT_SECRET"
|
||
return None
|
||
|
||
def get_functions_schema(self, tools_registry: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||
"""
|
||
Получить схему функций для GigaChat API в правильном формате.
|
||
|
||
Формат GigaChat:
|
||
{
|
||
"name": "function_name",
|
||
"description": "Описание функции",
|
||
"parameters": {
|
||
"type": "object",
|
||
"properties": {...},
|
||
"required": [...]
|
||
},
|
||
"return_parameters": {...} # опционально
|
||
}
|
||
"""
|
||
schema = []
|
||
|
||
if tools_registry is None:
|
||
return schema
|
||
|
||
# Обрабатываем разные типы tools_registry
|
||
items = []
|
||
if hasattr(tools_registry, 'get_all') and callable(getattr(tools_registry, 'get_all')):
|
||
items = list(tools_registry.get_all().items())
|
||
elif isinstance(tools_registry, dict):
|
||
items = list(tools_registry.items())
|
||
elif hasattr(tools_registry, 'tools'):
|
||
items = list(tools_registry.tools.items()) if isinstance(tools_registry.tools, dict) else []
|
||
|
||
for name, tool in items:
|
||
if hasattr(tool, 'get_schema'):
|
||
tool_schema = tool.get_schema()
|
||
# Преобразуем в формат GigaChat с гарантией наличия properties
|
||
parameters = tool_schema.get("parameters", {})
|
||
if not parameters:
|
||
parameters = {"type": "object", "properties": {}}
|
||
elif "properties" not in parameters:
|
||
parameters["properties"] = {}
|
||
|
||
giga_schema = {
|
||
"name": name,
|
||
"description": tool_schema.get("description", ""),
|
||
"parameters": parameters
|
||
}
|
||
# Добавляем return_parameters если есть
|
||
if hasattr(tool, 'get_return_schema'):
|
||
giga_schema["return_parameters"] = tool.get_return_schema()
|
||
schema.append(giga_schema)
|
||
elif hasattr(tool, 'description'):
|
||
schema.append({
|
||
"name": name,
|
||
"description": tool.description,
|
||
"parameters": {"type": "object", "properties": {}} # Пустая но валидная схема
|
||
})
|
||
|
||
logger.info(f"📋 GigaChat functions schema: {[f['name'] for f in schema]}")
|
||
return schema
|
||
|
||
def _parse_function_call(self, function_call: Dict[str, Any]) -> ToolCall:
|
||
"""
|
||
Преобразовать function_call из ответа GigaChat в ToolCall.
|
||
|
||
GigaChat возвращает:
|
||
{
|
||
"name": "function_name",
|
||
"arguments": {"arg1": "value1", ...}
|
||
}
|
||
"""
|
||
try:
|
||
# Аргументы могут быть строкой JSON или уже dict
|
||
args = function_call.get("arguments", {})
|
||
if isinstance(args, str):
|
||
args = json.loads(args)
|
||
except (json.JSONDecodeError, TypeError) as e:
|
||
logger.warning(f"Ошибка парсинга аргументов function_call: {e}")
|
||
args = {}
|
||
|
||
return ToolCall(
|
||
tool_name=function_call.get("name", "unknown"),
|
||
tool_args=args,
|
||
tool_call_id=function_call.get("name", "fc_0") # Используем name как ID
|
||
)
|
||
|
||
async def process_with_tools(
|
||
self,
|
||
prompt: str,
|
||
system_prompt: Optional[str] = None,
|
||
context: Optional[List[Dict[str, str]]] = None,
|
||
tools_registry: Optional[Dict[str, Any]] = None,
|
||
on_chunk: Optional[Callable[[str], Any]] = None,
|
||
max_iterations: int = 5,
|
||
**kwargs
|
||
) -> ProviderResponse:
|
||
"""
|
||
Обработка запросов с function calling для GigaChat.
|
||
|
||
Использует нативный GigaChat Function Calling API:
|
||
1. Отправляем запрос с functions массивом
|
||
2. Получаем function_call из ответа
|
||
3. Выполняем инструмент
|
||
4. Отправляем результат с role: "function"
|
||
5. Повторяем пока не будет финального ответа
|
||
|
||
Формат сообщений:
|
||
- user: {"role": "user", "content": "..."}
|
||
- assistant: {"role": "assistant", "function_call": {...}}
|
||
- function: {"role": "function", "name": "...", "content": "..."}
|
||
"""
|
||
if not tools_registry:
|
||
return await self.chat(
|
||
prompt=prompt,
|
||
system_prompt=system_prompt,
|
||
context=context,
|
||
on_chunk=on_chunk,
|
||
**kwargs
|
||
)
|
||
|
||
# Формируем базовые сообщения
|
||
messages = []
|
||
|
||
# Добавляем системный промпт если есть
|
||
if system_prompt:
|
||
messages.append({"role": "system", "content": system_prompt})
|
||
|
||
# Добавляем контекст (историю диалога)
|
||
if context:
|
||
for msg in context:
|
||
role = msg.get("role")
|
||
# Пропускаем system messages — они уже добавлены
|
||
if role == "system":
|
||
continue
|
||
# Преобразуем tool messages в function messages
|
||
if role == "tool":
|
||
role = "function"
|
||
if role in ("user", "assistant", "function"):
|
||
messages.append({
|
||
"role": role,
|
||
"content": msg.get("content", ""),
|
||
"name": msg.get("name") # Для function messages
|
||
})
|
||
|
||
# Добавляем текущий запрос пользователя
|
||
if prompt:
|
||
messages.append({"role": "user", "content": prompt})
|
||
|
||
# Получаем схему функций
|
||
functions = self.get_functions_schema(tools_registry) if self.supports_tools else None
|
||
|
||
logger.info(f"🔍 GigaChat process_with_tools: {len(messages)} сообщений, {len(functions) if functions else 0} функций")
|
||
|
||
for iteration in range(max_iterations):
|
||
logger.info(f"🔄 Итерация {iteration + 1}/{max_iterations}")
|
||
|
||
# Логируем сообщения перед отправкой
|
||
for i, msg in enumerate(messages[-3:]): # Последние 3 сообщения
|
||
content_preview = msg.get("content", "")[:100]
|
||
logger.info(f" 📨 [{i}] role={msg.get('role')}, content='{content_preview}...'")
|
||
|
||
# Отправляем запрос с functions
|
||
response = await self._chat_with_functions(
|
||
messages=messages,
|
||
functions=functions,
|
||
user_id=kwargs.get('user_id'),
|
||
temperature=kwargs.get("temperature", 0.7),
|
||
max_tokens=kwargs.get("max_tokens", 2000),
|
||
)
|
||
|
||
if not response.get("success"):
|
||
return ProviderResponse(
|
||
success=False,
|
||
error=response.get("error", "Неизвестная ошибка"),
|
||
provider_name=self.provider_name
|
||
)
|
||
|
||
# Проверяем наличие function_call
|
||
function_call = response.get("function_call")
|
||
content = response.get("content", "")
|
||
|
||
logger.info(f"📬 Ответ GigaChat: content_len={len(content) if content else 0}, function_call={function_call is not None}")
|
||
|
||
# Если нет function_call — возвращаем финальный ответ
|
||
if not function_call:
|
||
return ProviderResponse(
|
||
success=True,
|
||
message=AIMessage(
|
||
content=content,
|
||
tool_calls=[],
|
||
metadata={
|
||
"model": response.get("model", "GigaChat"),
|
||
"usage": response.get("usage", {}),
|
||
"functions_state_id": response.get("functions_state_id")
|
||
}
|
||
),
|
||
provider_name=self.provider_name,
|
||
usage=response.get("usage")
|
||
)
|
||
|
||
# Есть function_call — парсим и выполняем инструмент
|
||
tool_call = self._parse_function_call(function_call)
|
||
logger.info(f"🛠️ Function call: {tool_call.tool_name}({tool_call.tool_args})")
|
||
|
||
# Выполняем инструмент
|
||
if hasattr(tools_registry, 'get'):
|
||
tool = tools_registry.get(tool_call.tool_name)
|
||
elif isinstance(tools_registry, dict):
|
||
tool = tools_registry.get(tool_call.tool_name)
|
||
else:
|
||
tool = None
|
||
|
||
if tool is not None:
|
||
try:
|
||
if hasattr(tool, 'execute'):
|
||
result = await tool.execute(
|
||
**tool_call.tool_args,
|
||
user_id=kwargs.get('user_id')
|
||
)
|
||
elif hasattr(tool, '__call__'):
|
||
result = await tool(**tool_call.tool_args)
|
||
else:
|
||
result = f"Инструмент {tool_call.tool_name} не имеет метода execute"
|
||
|
||
tool_call.result = result
|
||
tool_call.status = ToolCallStatus.SUCCESS
|
||
except Exception as e:
|
||
logger.exception(f"Ошибка выполнения инструмента {tool_call.tool_name}: {e}")
|
||
tool_call.error = str(e)
|
||
tool_call.status = ToolCallStatus.ERROR
|
||
result = {"error": str(e)}
|
||
else:
|
||
tool_call.error = f"Инструмент {tool_call.tool_name} не найден"
|
||
tool_call.status = ToolCallStatus.ERROR
|
||
result = {"error": tool_call.error}
|
||
|
||
# Сериализуем результат
|
||
if hasattr(result, 'to_dict'):
|
||
result_dict = result.to_dict()
|
||
elif isinstance(result, dict):
|
||
result_dict = result
|
||
else:
|
||
result_dict = {"result": str(result)}
|
||
|
||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||
|
||
# Добавляем assistant message с function_call
|
||
messages.append({
|
||
"role": "assistant",
|
||
"content": "", # Пустой content при function_call
|
||
"function_call": function_call
|
||
})
|
||
|
||
# Добавляем function message с результатом
|
||
messages.append({
|
||
"role": "function",
|
||
"name": tool_call.tool_name,
|
||
"content": result_json
|
||
})
|
||
|
||
logger.info(f"✅ Добавлен function result: {tool_call.tool_name}, result_len={len(result_json)}")
|
||
|
||
# Сохраняем functions_state_id для следующей итерации
|
||
if response.get("functions_state_id"):
|
||
self._functions_state_id = response["functions_state_id"]
|
||
|
||
# Достигли максимума итераций
|
||
return ProviderResponse(
|
||
success=True,
|
||
message=AIMessage(
|
||
content=content + "\n\n[Достигнут максимум итераций выполнения функций]",
|
||
metadata={"iterations": max_iterations}
|
||
),
|
||
provider_name=self.provider_name,
|
||
usage=response.get("usage")
|
||
)
|
||
|
||
async def _chat_with_functions(
|
||
self,
|
||
messages: List[Dict[str, Any]],
|
||
functions: Optional[List[Dict[str, Any]]] = None,
|
||
user_id: Optional[int] = None,
|
||
temperature: float = 0.7,
|
||
max_tokens: int = 2000,
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
Отправить запрос в GigaChat API с поддержкой function calling.
|
||
|
||
Возвращает:
|
||
{
|
||
"success": bool,
|
||
"content": str,
|
||
"function_call": {"name": str, "arguments": dict} или None,
|
||
"model": str,
|
||
"usage": dict,
|
||
"functions_state_id": str или None
|
||
}
|
||
"""
|
||
try:
|
||
# Формируем сообщения в формате GigaChat
|
||
gc_messages = []
|
||
for msg in messages:
|
||
gc_msg = {"role": msg["role"], "content": msg.get("content", "")}
|
||
if msg.get("name"):
|
||
gc_msg["name"] = msg["name"]
|
||
if msg.get("function_call"):
|
||
gc_msg["function_call"] = msg["function_call"]
|
||
gc_messages.append(gc_msg)
|
||
|
||
# Выполняем запрос через GigaChatTool
|
||
result = await self._tool.chat_with_functions(
|
||
messages=gc_messages,
|
||
functions=functions,
|
||
user_id=str(user_id) if user_id else None,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
)
|
||
|
||
# Извлекаем function_call из ответа
|
||
function_call = None
|
||
if result.get("choices"):
|
||
choice = result["choices"][0]
|
||
message = choice.get("message", {})
|
||
function_call = message.get("function_call")
|
||
|
||
return {
|
||
"success": True,
|
||
"content": result.get("content", ""),
|
||
"function_call": function_call,
|
||
"model": result.get("model", "GigaChat"),
|
||
"usage": result.get("usage", {}),
|
||
"functions_state_id": result.get("functions_state_id")
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.exception(f"Ошибка _chat_with_functions: {e}")
|
||
return {
|
||
"success": False,
|
||
"error": str(e),
|
||
"function_call": None
|
||
}
|
||
|
||
async def chat(
|
||
self,
|
||
prompt: str,
|
||
system_prompt: Optional[str] = None,
|
||
context: Optional[List[Dict[str, str]]] = None,
|
||
tools: Optional[List[Dict[str, Any]]] = None,
|
||
on_chunk: Optional[Callable[[str], Any]] = None,
|
||
user_id: Optional[int] = None,
|
||
**kwargs
|
||
) -> ProviderResponse:
|
||
"""
|
||
Отправить запрос GigaChat (без function calling).
|
||
|
||
Используется когда tools не переданы.
|
||
"""
|
||
try:
|
||
# Формируем сообщения
|
||
messages = []
|
||
|
||
if system_prompt:
|
||
messages.append(GigaChatMessage(role="system", content=system_prompt))
|
||
|
||
if context:
|
||
for msg in context:
|
||
role = msg.get("role", "user")
|
||
content = msg.get("content", "")
|
||
if role == "system":
|
||
continue
|
||
if role in ("user", "assistant"):
|
||
messages.append(GigaChatMessage(role=role, content=content))
|
||
|
||
if prompt:
|
||
messages.append(GigaChatMessage(role="user", content=prompt))
|
||
|
||
# Выполняем запрос
|
||
result = await self._tool.chat(
|
||
messages=messages,
|
||
user_id=str(user_id) if user_id else None,
|
||
temperature=kwargs.get("temperature", 0.7),
|
||
max_tokens=kwargs.get("max_tokens", 2000),
|
||
)
|
||
|
||
if not result.get("content"):
|
||
if result.get("error"):
|
||
return ProviderResponse(
|
||
success=False,
|
||
error=result["error"],
|
||
provider_name=self.provider_name
|
||
)
|
||
else:
|
||
return ProviderResponse(
|
||
success=False,
|
||
error="Пустой ответ от GigaChat",
|
||
provider_name=self.provider_name
|
||
)
|
||
|
||
content = result["content"]
|
||
|
||
return ProviderResponse(
|
||
success=True,
|
||
message=AIMessage(
|
||
content=content,
|
||
tool_calls=[],
|
||
metadata={
|
||
"model": result.get("model", "GigaChat"),
|
||
"usage": result.get("usage", {})
|
||
}
|
||
),
|
||
provider_name=self.provider_name,
|
||
usage=result.get("usage")
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"Ошибка GigaChat провайдера: {e}")
|
||
return ProviderResponse(
|
||
success=False,
|
||
error=str(e),
|
||
provider_name=self.provider_name
|
||
)
|
||
|
||
async def execute_tool(
|
||
self,
|
||
tool_name: str,
|
||
tool_args: Dict[str, Any],
|
||
tool_call_id: Optional[str] = None,
|
||
**kwargs
|
||
) -> ToolCall:
|
||
"""
|
||
Выполнить инструмент (заглушка).
|
||
|
||
Инструменты выполняются через process_with_tools.
|
||
"""
|
||
return ToolCall(
|
||
tool_name=tool_name,
|
||
tool_args=tool_args,
|
||
tool_call_id=tool_call_id,
|
||
status=ToolCallStatus.PENDING
|
||
)
|