telegram-cli-bot/vector_memory.py

541 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Векторная память для ИИ-чата на основе ChromaDB + sentence-transformers.
Обеспечивает семантический поиск по истории диалогов.
Используется вместе с SQLiteMemoryStorage из memory_system.py
Модель: all-MiniLM-L6-v2 (90MB, 384 измерения) — быстрая и лёгкая.
"""
import logging
from pathlib import Path
from datetime import datetime
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
# Импортируем модели из memory_system.py
from memory_system import Message, Fact, FactType, SQLiteMemoryStorage, MEMORY_DB_PATH
# ============================================================================
# ChromaDB хранилище
# ============================================================================
class VectorMemoryStorage:
"""
Векторное хранилище на основе ChromaDB.
Модель: all-MiniLM-L6-v2
- Размер: 90MB
- Измерения: 384
- Скорость: ~1000 эмбеддингов/сек на CPU
"""
def __init__(self, persist_directory: str = None, model_name: str = "all-MiniLM-L6-v2"):
"""
Инициализация ChromaDB и модели эмбеддингов.
"""
self.persist_directory = persist_directory
self.model_name = model_name
self._client = None
self._collection = None
self._embedding_model = None
self._init_db()
def _init_db(self):
"""Инициализация клиента ChromaDB и модели."""
import chromadb
from chromadb.config import Settings
# Инициализация клиента
if self.persist_directory:
self._client = chromadb.PersistentClient(
path=self.persist_directory,
settings=Settings(
anonymized_telemetry=False,
allow_reset=True
)
)
logger.info(f"ChromaDB инициализирован (persistent): {self.persist_directory}")
else:
self._client = chromadb.EphemeralClient()
logger.info("ChromaDB инициализирован (in-memory)")
# Создаём коллекцию
self._collection = self._client.get_or_create_collection(
name="telegram_messages",
metadata={"description": "История диалогов Telegram бота"}
)
logger.info(f"Коллекция готова: {self._collection.name}")
def _get_embedding_model(self):
"""Ленивая загрузка модели эмбеддингов."""
if self._embedding_model is None:
from sentence_transformers import SentenceTransformer
self._embedding_model = SentenceTransformer(self.model_name)
logger.info(f"Модель эмбеддингов загружена: {self.model_name}")
return self._embedding_model
def _compute_embedding(self, text: str) -> List[float]:
"""Вычислить эмбеддинг текста."""
model = self._get_embedding_model()
embedding = model.encode(text, convert_to_numpy=True)
return embedding.tolist()
def add_message(self, message: Message) -> str:
"""Добавить сообщение в векторное хранилище."""
import uuid
doc_id = str(uuid.uuid4())
embedding = self._compute_embedding(message.content)
metadata = {
"user_id": str(message.user_id),
"role": message.role,
"timestamp": message.timestamp.isoformat() if message.timestamp else datetime.now().isoformat(),
"session_id": message.session_id or "unknown"
}
self._collection.add(
ids=[doc_id],
embeddings=[embedding],
documents=[message.content],
metadatas=[metadata]
)
logger.debug(f"Добавлено сообщение в векторную БД: user={message.user_id}, len={len(message.content)}")
return doc_id
def add_messages_batch(self, messages: List[Message]) -> List[str]:
"""Добавить пакет сообщений."""
import uuid
if not messages:
return []
ids = [str(uuid.uuid4()) for _ in messages]
documents = [msg.content for msg in messages]
# Вычисляем эмбеддинги батчем (быстрее)
model = self._get_embedding_model()
embeddings = model.encode(documents, convert_to_numpy=True).tolist()
metadatas = [
{
"user_id": str(msg.user_id),
"role": msg.role,
"timestamp": msg.timestamp.isoformat() if msg.timestamp else datetime.now().isoformat(),
"session_id": msg.session_id or "unknown"
}
for msg in messages
]
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=documents,
metadatas=metadatas
)
logger.info(f"Добавлено {len(messages)} сообщений в векторную БД")
return ids
def search_similar(
self,
user_id: int,
query: str,
limit: int = 5,
role_filter: Optional[str] = None
) -> List[Tuple[Message, float]]:
"""Семантический поиск похожих сообщений."""
# Вычисляем эмбеддинг запроса
query_embedding = self._compute_embedding(query)
# Фильтр по пользователю
where_filter = {"user_id": str(user_id)}
if role_filter:
where_filter = {"$and": [{"user_id": str(user_id)}, {"role": role_filter}]}
# Поиск
results = self._collection.query(
query_embeddings=[query_embedding],
n_results=limit,
where=where_filter,
include=["documents", "metadatas", "distances"]
)
# Преобразуем результаты
found_messages = []
if results and results['ids'] and results['ids'][0]:
for i, doc_id in enumerate(results['ids'][0]):
doc_text = results['documents'][0][i]
metadata = results['metadatas'][0][i]
distance = results['distances'][0][i] if results['distances'] else 0.0
message = Message(
id=None,
user_id=int(metadata['user_id']),
role=metadata['role'],
content=doc_text,
timestamp=datetime.fromisoformat(metadata['timestamp']),
session_id=metadata.get('session_id')
)
found_messages.append((message, distance))
logger.debug(f"Векторный поиск: query='{query[:30]}...', found={len(found_messages)}")
return found_messages
def search_by_session(
self,
session_id: str,
query: str = None,
limit: int = 20
) -> List[Message]:
"""Получить сообщения из сессии."""
where_filter = {"session_id": session_id}
if query:
query_embedding = self._compute_embedding(query)
results = self._collection.query(
query_embeddings=[query_embedding],
n_results=limit,
where=where_filter,
include=["documents", "metadatas"]
)
else:
# Получаем все сообщения сессии
results = self._collection.get(
where=where_filter,
include=["documents", "metadatas"],
limit=limit
)
messages = []
if results and results.get('ids') and results['ids'][0]:
for i, doc_id in enumerate(results['ids'][0]):
doc_text = results['documents'][0][i] if 'documents' in results else ""
metadata = results['metadatas'][0][i] if 'metadatas' in results else {}
message = Message(
id=None,
user_id=int(metadata.get('user_id', 0)),
role=metadata.get('role', 'user'),
content=doc_text,
timestamp=datetime.fromisoformat(metadata.get('timestamp', datetime.now().isoformat())),
session_id=metadata.get('session_id')
)
messages.append(message)
return messages
def get_stats(self) -> Dict[str, Any]:
"""Получить статистику коллекции."""
count = self._collection.count()
return {
"total_documents": count,
"collection_name": self._collection.name,
"model": self.model_name
}
def delete_user_data(self, user_id: int) -> int:
"""Удалить все данные пользователя."""
results = self._collection.get(
where={"user_id": str(user_id)},
include=[]
)
if results and results.get('ids'):
count = len(results['ids'])
self._collection.delete(ids=results['ids'])
logger.info(f"Удалено {count} документов пользователя {user_id}")
return count
return 0
# ============================================================================
# Гибридный менеджер памяти (SQLite + Vector)
# ============================================================================
class HybridMemoryManager:
"""
Гибридный менеджер памяти.
Объединяет:
- SQLiteMemoryStorage для хранения фактов и истории
- VectorMemoryStorage для семантического поиска
"""
def __init__(
self,
sqlite_storage: SQLiteMemoryStorage,
vector_storage: VectorMemoryStorage = None,
ai_client=None
):
self.sqlite = sqlite_storage
self.vector = vector_storage
self.ai_client = ai_client
self._active_sessions: Dict[int, str] = {}
def start_session(self, user_id: int) -> str:
"""Начать новую сессию."""
import uuid
session_id = str(uuid.uuid4())
from memory_system import DialogSession
session = DialogSession(id=session_id, user_id=user_id)
self.sqlite.create_session(session)
self._active_sessions[user_id] = session_id
logger.info(f"Начата новая сессия {session_id} для пользователя {user_id}")
return session_id
def end_session(self, user_id: int, summary: str = None):
"""Завершить сессию."""
session_id = self._active_sessions.pop(user_id, None)
if session_id:
self.sqlite.close_session(session_id, summary)
logger.info(f"Завершена сессия {session_id} для пользователя {user_id}")
def get_session_id(self, user_id: int) -> Optional[str]:
"""Получить ID текущей сессии."""
if user_id in self._active_sessions:
return self._active_sessions[user_id]
session = self.sqlite.get_active_session(user_id)
if session:
self._active_sessions[user_id] = session.id
return session.id
return self.start_session(user_id)
def add_message(self, user_id: int, role: str, content: str) -> int:
"""Добавить сообщение в оба хранилища."""
from memory_system import Message
session_id = self.get_session_id(user_id)
message = Message(
id=None,
user_id=user_id,
role=role,
content=content,
session_id=session_id
)
# Сохраняем в SQLite
sqlite_id = self.sqlite.save_message(message)
# Сохраняем в векторную БД
if self.vector:
try:
self.vector.add_message(message)
except Exception as e:
logger.error(f"Ошибка сохранения в векторную БД: {e}")
return sqlite_id
def get_context(self, user_id: int, max_messages: int = 10) -> List[Message]:
"""Получить контекст для ИИ (последние сообщения)."""
return self.sqlite.get_recent_messages(user_id, max_messages)
def search_relevant(
self,
user_id: int,
query: str,
max_results: int = 5,
use_vector: bool = True
) -> List[Tuple[Message, float]]:
"""Найти релевантные сообщения."""
# Приоритет векторному поиску
if use_vector and self.vector:
try:
results = self.vector.search_similar(
user_id=user_id,
query=query,
limit=max_results
)
logger.info(f"Векторный поиск: найдено {len(results)} результатов")
return results
except Exception as e:
logger.error(f"Ошибка векторного поиска, используем SQLite: {e}")
# Фоллбэк на SQLite LIKE поиск
messages = self.sqlite.search_messages(user_id, query, max_results)
return [(msg, 0.5) for msg in messages]
def get_user_profile(self, user_id: int) -> Dict[str, List[str]]:
"""Получить профиль пользователя (факты)."""
facts = self.sqlite.get_facts(user_id)
profile = {}
for fact in facts:
type_name = fact.fact_type.value
if type_name not in profile:
profile[type_name] = []
profile[type_name].append(fact.content)
return profile
def extract_and_save_facts(self, user_id: int, message: str, response: str = None):
"""Извлечь факты из сообщения и сохранить."""
import re
from memory_system import Fact, FactType
extracted = []
message_lower = message.lower()
# Имя
if "меня зовут" in message_lower:
parts = message.split("меня зовут")
if len(parts) > 1:
name = parts[1].strip().split()[0]
fact = Fact(
id=None,
user_id=user_id,
fact_type=FactType.PERSONAL,
content=f"Пользователя зовут {name}",
source_message=message,
confidence=0.8
)
self.sqlite.save_fact(fact)
extracted.append(fact)
# Технологии
tech_patterns = [
(r"я (люблю|предпочитаю|использую)\s+(\w+)", FactType.TECHNICAL),
(r"мой (язык|стек)\s+(\w+)", FactType.TECHNICAL),
]
for pattern, fact_type in tech_patterns:
match = re.search(pattern, message_lower)
if match:
tech = match.group(2) if len(match.groups()) > 1 else match.group(1)
fact = Fact(
id=None,
user_id=user_id,
fact_type=fact_type,
content=f"Использует {tech}",
source_message=message,
confidence=0.6
)
self.sqlite.save_fact(fact)
extracted.append(fact)
if extracted:
logger.info(f"Извлечено {len(extracted)} фактов для пользователя {user_id}")
def format_context_for_ai(self, user_id: int, query: str = None) -> str:
"""Сформировать контекст для передачи ИИ."""
parts = []
# Профиль
profile = self.get_user_profile(user_id)
if profile:
parts.append("📋 ПРОФИЛЬ ПОЛЬЗОВАТЕЛЯ:")
for fact_type, facts in profile.items():
parts.append(f" [{fact_type}]:")
for f in facts:
parts.append(f" - {f}")
# Последние сообщения
recent = self.get_context(user_id, 5)
if recent:
parts.append("\n💬 ПОСЛЕДНИЕ СООБЩЕНИЯ:")
for msg in recent:
role_ru = "Пользователь" if msg.role == "user" else "Ассистент"
preview = msg.content[:100].replace('\n', ' ')
parts.append(f" {role_ru}: {preview}...")
# Релевантный поиск
if query:
relevant = self.search_relevant(user_id, query, max_results=3)
if relevant:
parts.append("\n🔍 РЕЛЕВАНТНЫЕ СООБЩЕНИЯ:")
for msg, score in relevant:
preview = msg.content[:100].replace('\n', ' ')
parts.append(f" [{score:.2f}] {preview}...")
return "\n".join(parts)
def get_stats(self, user_id: int) -> Dict[str, Any]:
"""Получить статистику памяти пользователя."""
sqlite_stats = self.sqlite.get_user_stats(user_id)
stats = {
**sqlite_stats,
"hybrid_mode": self.vector is not None
}
if self.vector:
try:
vector_stats = self.vector.get_stats()
stats["vector_documents"] = vector_stats.get("total_documents", 0)
stats["vector_model"] = vector_stats.get("model", "unknown")
except Exception as e:
logger.error(f"Ошибка получения статистики векторной БД: {e}")
stats["vector_documents"] = "N/A"
stats["vector_model"] = "N/A"
return stats
# ============================================================================
# Глобальные экземпляры
# ============================================================================
VECTOR_DB_PATH = str(Path(__file__).parent / "vector_db")
# Создаём гибридный менеджер
sqlite_storage = SQLiteMemoryStorage(MEMORY_DB_PATH)
vector_storage = VectorMemoryStorage(VECTOR_DB_PATH)
hybrid_memory_manager = HybridMemoryManager(
sqlite_storage=sqlite_storage,
vector_storage=vector_storage
)
# ============================================================================
# Хелперы для бота
# ============================================================================
def save_message(user_id: int, role: str, content: str):
"""Сохранить сообщение в гибридную память."""
if hybrid_memory_manager:
hybrid_memory_manager.add_message(user_id, role, content)
if role == "user":
hybrid_memory_manager.extract_and_save_facts(user_id, content)
def get_context(user_id: int, query: str = None) -> str:
"""Получить форматированный контекст для ИИ."""
if hybrid_memory_manager:
return hybrid_memory_manager.format_context_for_ai(user_id, query)
return ""
def search_memory(user_id: int, query: str, limit: int = 5) -> List[Tuple[Message, float]]:
"""Поиск в памяти."""
if hybrid_memory_manager:
return hybrid_memory_manager.search_relevant(user_id, query, limit)
return []
def get_profile(user_id: int) -> Dict[str, List[str]]:
"""Получить профиль пользователя."""
if hybrid_memory_manager:
return hybrid_memory_manager.get_user_profile(user_id)
return {}
def get_memory_stats(user_id: int) -> Dict[str, Any]:
"""Получить статистику памяти."""
if hybrid_memory_manager:
return hybrid_memory_manager.get_stats(user_id)
return {}