#!/usr/bin/env python3
"""
Векторная память для ИИ-чата на основе ChromaDB + sentence-transformers.

Обеспечивает семантический поиск по истории диалогов.
Используется вместе с SQLiteMemoryStorage из memory_system.py

Модель: all-MiniLM-L6-v2 (90MB, 384 измерения) — быстрая и лёгкая.
"""
|
||
|
||
import logging
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from typing import Optional, List, Dict, Any, Tuple
|
||
from dataclasses import dataclass, field
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Импортируем модели из memory_system.py
|
||
from memory_system import Message, Fact, FactType, SQLiteMemoryStorage, MEMORY_DB_PATH
|
||
|
||
|
||
# ============================================================================
|
||
# ChromaDB хранилище
|
||
# ============================================================================
|
||
|
||
class VectorMemoryStorage:
    """
    Vector store built on ChromaDB.

    Embedding model: all-MiniLM-L6-v2
    - Size: 90MB
    - Dimensions: 384
    - Speed: ~1000 embeddings/sec on CPU
    """

    def __init__(self, persist_directory: Optional[str] = None, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the ChromaDB client; the embedding model is loaded lazily.

        Args:
            persist_directory: directory for on-disk persistence,
                or None for an in-memory (ephemeral) database.
            model_name: sentence-transformers model name.
        """
        self.persist_directory = persist_directory
        self.model_name = model_name
        self._client = None
        self._collection = None
        self._embedding_model = None  # lazy — see _get_embedding_model()

        self._init_db()

    def _init_db(self):
        """Create the ChromaDB client and the messages collection."""
        import chromadb
        from chromadb.config import Settings

        if self.persist_directory:
            self._client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            logger.info(f"ChromaDB инициализирован (persistent): {self.persist_directory}")
        else:
            self._client = chromadb.EphemeralClient()
            logger.info("ChromaDB инициализирован (in-memory)")

        self._collection = self._client.get_or_create_collection(
            name="telegram_messages",
            metadata={"description": "История диалогов Telegram бота"}
        )
        logger.info(f"Коллекция готова: {self._collection.name}")

    def _get_embedding_model(self):
        """Load the embedding model on first use and cache it."""
        if self._embedding_model is None:
            from sentence_transformers import SentenceTransformer
            self._embedding_model = SentenceTransformer(self.model_name)
            logger.info(f"Модель эмбеддингов загружена: {self.model_name}")
        return self._embedding_model

    def _compute_embedding(self, text: str) -> List[float]:
        """Compute the embedding vector for *text*."""
        model = self._get_embedding_model()
        embedding = model.encode(text, convert_to_numpy=True)
        return embedding.tolist()

    @staticmethod
    def _message_metadata(message: "Message") -> Dict[str, str]:
        """Build the metadata dict stored alongside a message document.

        Shared by add_message() and add_messages_batch() so the two code
        paths cannot drift apart.
        """
        return {
            "user_id": str(message.user_id),
            "role": message.role,
            "timestamp": message.timestamp.isoformat() if message.timestamp else datetime.now().isoformat(),
            "session_id": message.session_id or "unknown"
        }

    def add_message(self, message: "Message") -> str:
        """Add a single message to the vector store; returns the new document id."""
        import uuid

        doc_id = str(uuid.uuid4())
        embedding = self._compute_embedding(message.content)

        self._collection.add(
            ids=[doc_id],
            embeddings=[embedding],
            documents=[message.content],
            metadatas=[self._message_metadata(message)]
        )

        logger.debug(f"Добавлено сообщение в векторную БД: user={message.user_id}, len={len(message.content)}")
        return doc_id

    def add_messages_batch(self, messages: List["Message"]) -> List[str]:
        """Add a batch of messages; returns the generated document ids."""
        import uuid

        if not messages:
            return []

        ids = [str(uuid.uuid4()) for _ in messages]
        documents = [msg.content for msg in messages]

        # One batched encode() call is much faster than per-message calls.
        model = self._get_embedding_model()
        embeddings = model.encode(documents, convert_to_numpy=True).tolist()

        metadatas = [self._message_metadata(msg) for msg in messages]

        self._collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas
        )

        logger.info(f"Добавлено {len(messages)} сообщений в векторную БД")
        return ids

    def search_similar(
        self,
        user_id: int,
        query: str,
        limit: int = 5,
        role_filter: Optional[str] = None
    ) -> List[Tuple["Message", float]]:
        """
        Semantic search over one user's messages.

        Args:
            user_id: restrict results to this user.
            query: free-text query.
            limit: maximum number of results.
            role_filter: optionally restrict to a single role ('user'/'assistant').

        Returns:
            List of (Message, distance) pairs; smaller distance = more similar.
        """
        query_embedding = self._compute_embedding(query)

        # Chroma requires an explicit $and for multiple equality conditions.
        where_filter = {"user_id": str(user_id)}
        if role_filter:
            where_filter = {"$and": [{"user_id": str(user_id)}, {"role": role_filter}]}

        results = self._collection.query(
            query_embeddings=[query_embedding],
            n_results=limit,
            where=where_filter,
            include=["documents", "metadatas", "distances"]
        )

        found_messages = []

        # query() returns NESTED lists: one inner list per query embedding.
        if results and results['ids'] and results['ids'][0]:
            for i in range(len(results['ids'][0])):
                doc_text = results['documents'][0][i]
                metadata = results['metadatas'][0][i]
                distance = results['distances'][0][i] if results['distances'] else 0.0

                message = Message(
                    id=None,
                    user_id=int(metadata['user_id']),
                    role=metadata['role'],
                    content=doc_text,
                    timestamp=datetime.fromisoformat(metadata['timestamp']),
                    session_id=metadata.get('session_id')
                )

                found_messages.append((message, distance))

        logger.debug(f"Векторный поиск: query='{query[:30]}...', found={len(found_messages)}")
        return found_messages

    def search_by_session(
        self,
        session_id: str,
        query: Optional[str] = None,
        limit: int = 20
    ) -> List["Message"]:
        """
        Return messages from a session; with *query*, rank them semantically.

        BUG FIX: collection.get() returns FLAT lists (ids: List[str]) while
        collection.query() returns nested lists (ids: List[List[str]]).
        The old code indexed results['ids'][0] for both, so the no-query
        path iterated the *characters* of the first id. Both shapes are
        now normalized to flat lists before building Messages.
        """
        where_filter = {"session_id": session_id}

        if query:
            query_embedding = self._compute_embedding(query)
            results = self._collection.query(
                query_embeddings=[query_embedding],
                n_results=limit,
                where=where_filter,
                include=["documents", "metadatas"]
            )
            # Un-nest the per-query inner lists ("or" guards None values).
            ids = (results.get('ids') or [[]])[0]
            docs = (results.get('documents') or [[]])[0] or []
            metas = (results.get('metadatas') or [[]])[0] or []
        else:
            # Fetch all session messages (no ranking).
            results = self._collection.get(
                where=where_filter,
                include=["documents", "metadatas"],
                limit=limit
            )
            ids = results.get('ids') or []
            docs = results.get('documents') or []
            metas = results.get('metadatas') or []

        messages = []
        for i in range(len(ids)):
            doc_text = docs[i] if i < len(docs) else ""
            metadata = metas[i] if i < len(metas) else {}

            messages.append(Message(
                id=None,
                user_id=int(metadata.get('user_id', 0)),
                role=metadata.get('role', 'user'),
                content=doc_text,
                timestamp=datetime.fromisoformat(metadata.get('timestamp', datetime.now().isoformat())),
                session_id=metadata.get('session_id')
            ))

        return messages

    def get_stats(self) -> Dict[str, Any]:
        """Return collection-level statistics (document count, name, model)."""
        count = self._collection.count()
        return {
            "total_documents": count,
            "collection_name": self._collection.name,
            "model": self.model_name
        }

    def delete_user_data(self, user_id: int) -> int:
        """Delete all of a user's documents; returns how many were removed."""
        results = self._collection.get(
            where={"user_id": str(user_id)},
            include=[]  # ids only — no need to pull documents/metadata
        )

        if results and results.get('ids'):
            count = len(results['ids'])
            self._collection.delete(ids=results['ids'])
            logger.info(f"Удалено {count} документов пользователя {user_id}")
            return count
        return 0
|
||
|
||
|
||
# ============================================================================
|
||
# Гибридный менеджер памяти (SQLite + Vector)
|
||
# ============================================================================
|
||
|
||
class HybridMemoryManager:
    """
    Hybrid memory manager.

    Combines:
    - SQLiteMemoryStorage for facts and dialog history (source of truth)
    - VectorMemoryStorage for semantic search (optional; may be None)
    """

    # Phrase that introduces the user's name; matched case-insensitively.
    _NAME_PHRASE = "меня зовут"

    def __init__(
        self,
        sqlite_storage: "SQLiteMemoryStorage",
        vector_storage: "VectorMemoryStorage" = None,
        ai_client=None
    ):
        """
        Args:
            sqlite_storage: required persistent SQLite storage.
            vector_storage: optional semantic-search backend (None disables it).
            ai_client: optional AI client, stored for callers; unused here.
        """
        self.sqlite = sqlite_storage
        self.vector = vector_storage
        self.ai_client = ai_client
        # Cache: user_id -> active session id.
        self._active_sessions: Dict[int, str] = {}

    def start_session(self, user_id: int) -> str:
        """Create, persist and register a new dialog session; returns its id."""
        import uuid
        session_id = str(uuid.uuid4())

        from memory_system import DialogSession
        session = DialogSession(id=session_id, user_id=user_id)
        self.sqlite.create_session(session)
        self._active_sessions[user_id] = session_id

        logger.info(f"Начата новая сессия {session_id} для пользователя {user_id}")
        return session_id

    def end_session(self, user_id: int, summary: Optional[str] = None):
        """Close the user's active session, optionally storing a summary."""
        session_id = self._active_sessions.pop(user_id, None)
        if session_id:
            self.sqlite.close_session(session_id, summary)
            logger.info(f"Завершена сессия {session_id} для пользователя {user_id}")

    def get_session_id(self, user_id: int) -> Optional[str]:
        """Return the current session id, resuming or creating one if needed."""
        if user_id in self._active_sessions:
            return self._active_sessions[user_id]

        # Try to resume an active session persisted in SQLite.
        session = self.sqlite.get_active_session(user_id)
        if session:
            self._active_sessions[user_id] = session.id
            return session.id

        return self.start_session(user_id)

    def add_message(self, user_id: int, role: str, content: str) -> int:
        """Store a message in both backends; returns the SQLite row id."""
        from memory_system import Message

        session_id = self.get_session_id(user_id)
        message = Message(
            id=None,
            user_id=user_id,
            role=role,
            content=content,
            session_id=session_id
        )

        # SQLite is the source of truth.
        sqlite_id = self.sqlite.save_message(message)

        # The vector store is best-effort: a failure must not lose the message.
        if self.vector:
            try:
                self.vector.add_message(message)
            except Exception as e:
                logger.error(f"Ошибка сохранения в векторную БД: {e}")

        return sqlite_id

    def get_context(self, user_id: int, max_messages: int = 10) -> List["Message"]:
        """Return the most recent messages for the AI context window."""
        return self.sqlite.get_recent_messages(user_id, max_messages)

    def search_relevant(
        self,
        user_id: int,
        query: str,
        max_results: int = 5,
        use_vector: bool = True
    ) -> List[Tuple["Message", float]]:
        """
        Find messages relevant to *query*.

        Prefers vector search; falls back to SQLite LIKE search (with a
        fixed 0.5 score) when the vector backend is absent or fails.
        """
        if use_vector and self.vector:
            try:
                results = self.vector.search_similar(
                    user_id=user_id,
                    query=query,
                    limit=max_results
                )
                logger.info(f"Векторный поиск: найдено {len(results)} результатов")
                return results
            except Exception as e:
                logger.error(f"Ошибка векторного поиска, используем SQLite: {e}")

        # Fallback: plain SQLite LIKE search.
        messages = self.sqlite.search_messages(user_id, query, max_results)
        return [(msg, 0.5) for msg in messages]

    def get_user_profile(self, user_id: int) -> Dict[str, List[str]]:
        """Return the user's facts grouped by fact type name."""
        profile: Dict[str, List[str]] = {}
        for fact in self.sqlite.get_facts(user_id):
            profile.setdefault(fact.fact_type.value, []).append(fact.content)
        return profile

    def extract_and_save_facts(self, user_id: int, message: str, response: Optional[str] = None):
        """
        Extract simple facts (name, technologies) from *message* via
        pattern matching and persist them to SQLite.

        Args:
            user_id: owner of the extracted facts.
            message: the user's message text.
            response: assistant response (accepted for interface
                compatibility; currently unused).
        """
        import re
        from memory_system import Fact, FactType

        extracted = []
        message_lower = message.lower()

        # Name: "меня зовут X".
        # BUG FIX: the old code split the ORIGINAL message on the lowercase
        # phrase (so "Меня зовут Иван" matched the guard but extracted
        # nothing) and indexed split()[0] without checking that any word
        # follows the phrase (IndexError on "... меня зовут").
        idx = message_lower.find(self._NAME_PHRASE)
        if idx != -1:
            tail_words = message[idx + len(self._NAME_PHRASE):].strip().split()
            if tail_words:
                name = tail_words[0]
                fact = Fact(
                    id=None,
                    user_id=user_id,
                    fact_type=FactType.PERSONAL,
                    content=f"Пользователя зовут {name}",
                    source_message=message,
                    confidence=0.8
                )
                self.sqlite.save_fact(fact)
                extracted.append(fact)

        # Technology preferences.
        tech_patterns = [
            (r"я (люблю|предпочитаю|использую)\s+(\w+)", FactType.TECHNICAL),
            (r"мой (язык|стек)\s+(\w+)", FactType.TECHNICAL),
        ]

        for pattern, fact_type in tech_patterns:
            match = re.search(pattern, message_lower)
            if match:
                # group(2) is the technology; group(1) is only the verb.
                tech = match.group(2) if len(match.groups()) > 1 else match.group(1)
                fact = Fact(
                    id=None,
                    user_id=user_id,
                    fact_type=fact_type,
                    content=f"Использует {tech}",
                    source_message=message,
                    confidence=0.6
                )
                self.sqlite.save_fact(fact)
                extracted.append(fact)

        if extracted:
            logger.info(f"Извлечено {len(extracted)} фактов для пользователя {user_id}")

    def format_context_for_ai(self, user_id: int, query: Optional[str] = None) -> str:
        """Build a human-readable memory context block for the AI prompt."""
        parts = []

        # User profile (facts grouped by type).
        profile = self.get_user_profile(user_id)
        if profile:
            parts.append("📋 ПРОФИЛЬ ПОЛЬЗОВАТЕЛЯ:")
            for fact_type, facts in profile.items():
                parts.append(f"  [{fact_type}]:")
                for f in facts:
                    parts.append(f"    - {f}")

        # Most recent dialog turns.
        recent = self.get_context(user_id, 5)
        if recent:
            parts.append("\n💬 ПОСЛЕДНИЕ СООБЩЕНИЯ:")
            for msg in recent:
                role_ru = "Пользователь" if msg.role == "user" else "Ассистент"
                preview = msg.content[:100].replace('\n', ' ')
                parts.append(f"  {role_ru}: {preview}...")

        # Semantically relevant history for the current query.
        if query:
            relevant = self.search_relevant(user_id, query, max_results=3)
            if relevant:
                parts.append("\n🔍 РЕЛЕВАНТНЫЕ СООБЩЕНИЯ:")
                for msg, score in relevant:
                    preview = msg.content[:100].replace('\n', ' ')
                    parts.append(f"  [{score:.2f}] {preview}...")

        return "\n".join(parts)

    def get_stats(self, user_id: int) -> Dict[str, Any]:
        """Return combined SQLite + vector-store statistics for the user."""
        sqlite_stats = self.sqlite.get_user_stats(user_id)

        stats = {
            **sqlite_stats,
            "hybrid_mode": self.vector is not None
        }

        if self.vector:
            try:
                vector_stats = self.vector.get_stats()
                stats["vector_documents"] = vector_stats.get("total_documents", 0)
                stats["vector_model"] = vector_stats.get("model", "unknown")
            except Exception as e:
                logger.error(f"Ошибка получения статистики векторной БД: {e}")
                stats["vector_documents"] = "N/A"
                stats["vector_model"] = "N/A"

        return stats
|
||
|
||
|
||
# ============================================================================
|
||
# Глобальные экземпляры
|
||
# ============================================================================
|
||
|
||
# On-disk location for the ChromaDB persistence directory.
VECTOR_DB_PATH = str(Path(__file__).parent / "vector_db")

# Build the hybrid manager at import time.
sqlite_storage = SQLiteMemoryStorage(MEMORY_DB_PATH)

# The vector backend is optional: HybridMemoryManager works with
# vector_storage=None (it falls back to SQLite search), so a missing or
# broken chromadb/sentence-transformers install must not break the module.
try:
    vector_storage = VectorMemoryStorage(VECTOR_DB_PATH)
except Exception as e:
    logger.error(f"Не удалось инициализировать векторное хранилище: {e}")
    vector_storage = None

hybrid_memory_manager = HybridMemoryManager(
    sqlite_storage=sqlite_storage,
    vector_storage=vector_storage
)
|
||
|
||
|
||
# ============================================================================
|
||
# Хелперы для бота
|
||
# ============================================================================
|
||
|
||
def save_message(user_id: int, role: str, content: str):
    """Persist a message to the hybrid memory (SQLite + vector store).

    For user messages, simple fact extraction is triggered as well.
    No-op when the memory manager is unavailable.
    """
    manager = hybrid_memory_manager
    if not manager:
        return
    manager.add_message(user_id, role, content)
    if role == "user":
        manager.extract_and_save_facts(user_id, content)
|
||
|
||
|
||
def get_context(user_id: int, query: str = None) -> str:
    """Return the formatted memory context for the AI.

    Returns an empty string when the memory manager is unavailable.
    """
    if not hybrid_memory_manager:
        return ""
    return hybrid_memory_manager.format_context_for_ai(user_id, query)
|
||
|
||
|
||
def search_memory(user_id: int, query: str, limit: int = 5) -> List[Tuple[Message, float]]:
    """Search the user's memory for messages relevant to *query*.

    Returns (message, score) pairs, or an empty list when the memory
    manager is unavailable.
    """
    if not hybrid_memory_manager:
        return []
    return hybrid_memory_manager.search_relevant(user_id, query, limit)
|
||
|
||
|
||
def get_profile(user_id: int) -> Dict[str, List[str]]:
    """Return the user's fact profile grouped by fact type.

    Returns an empty dict when the memory manager is unavailable.
    """
    if not hybrid_memory_manager:
        return {}
    return hybrid_memory_manager.get_user_profile(user_id)
|
||
|
||
|
||
def get_memory_stats(user_id: int) -> Dict[str, Any]:
    """Return combined memory statistics for the user.

    Returns an empty dict when the memory manager is unavailable.
    """
    if not hybrid_memory_manager:
        return {}
    return hybrid_memory_manager.get_stats(user_id)
|