Этап 3: Память на ChromaDB

This commit is contained in:
mirivlad 2026-03-17 03:22:56 +08:00
parent 85e702ce25
commit 6c2f17e37a
3 changed files with 122 additions and 0 deletions

View File

@ -4,6 +4,7 @@ from telegram import Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
from config.config import get_settings from config.config import get_settings
from src.tools.tool_runner import ToolRunner from src.tools.tool_runner import ToolRunner
from src.memory.memory import Memory
logging.basicConfig( logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@ -13,6 +14,7 @@ logger = logging.getLogger(__name__)
settings = get_settings() settings = get_settings()
tool_runner = ToolRunner() tool_runner = ToolRunner()
memory = Memory()
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE): async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
@ -30,6 +32,7 @@ async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
"/help - Показать эту справку\n" "/help - Показать эту справку\n"
"/qwen <текст> - Задать вопрос qwen-code\n" "/qwen <текст> - Задать вопрос qwen-code\n"
"/open <текст> - Задать вопрос opencode\n" "/open <текст> - Задать вопрос opencode\n"
"/forget - Очистить историю чата\n"
) )
await update.message.reply_text(help_text) await update.message.reply_text(help_text)
@ -40,8 +43,13 @@ async def qwen_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
await update.message.reply_text("Использование: /qwen <текст>") await update.message.reply_text("Использование: /qwen <текст>")
return return
chat_id = update.effective_chat.id
memory.add_message(chat_id, "user", prompt)
await update.message.reply_text("Думаю...") await update.message.reply_text("Думаю...")
result, success = await tool_runner.run_qwen(prompt) result, success = await tool_runner.run_qwen(prompt)
memory.add_message(chat_id, "assistant", result)
await update.message.reply_text(result[:4096] if len(result) > 4096 else result) await update.message.reply_text(result[:4096] if len(result) > 4096 else result)
@ -51,11 +59,22 @@ async def open_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
await update.message.reply_text("Использование: /open <текст>") await update.message.reply_text("Использование: /open <текст>")
return return
chat_id = update.effective_chat.id
memory.add_message(chat_id, "user", prompt)
await update.message.reply_text("Думаю...") await update.message.reply_text("Думаю...")
result, success = await tool_runner.run_opencode(prompt) result, success = await tool_runner.run_opencode(prompt)
memory.add_message(chat_id, "assistant", result)
await update.message.reply_text(result[:4096] if len(result) > 4096 else result) await update.message.reply_text(result[:4096] if len(result) > 4096 else result)
async def forget_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
chat_id = update.effective_chat.id
memory.clear_chat(chat_id)
await update.message.reply_text("История чата очищена.")
async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE): async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE):
await update.message.reply_text(update.message.text) await update.message.reply_text(update.message.text)
@ -74,6 +93,7 @@ def main():
application.add_handler(CommandHandler("help", help_command)) application.add_handler(CommandHandler("help", help_command))
application.add_handler(CommandHandler("qwen", qwen_command)) application.add_handler(CommandHandler("qwen", qwen_command))
application.add_handler(CommandHandler("open", open_command)) application.add_handler(CommandHandler("open", open_command))
application.add_handler(CommandHandler("forget", forget_command))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo))
logger.info("Бот запущен") logger.info("Бот запущен")

0
src/memory/__init__.py Normal file
View File

102
src/memory/memory.py Normal file
View File

@ -0,0 +1,102 @@
import logging
from typing import List, Dict, Optional
import chromadb
from chromadb.config import Settings as ChromaSettings
from sentence_transformers import SentenceTransformer
from config.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class Memory:
def __init__(self):
self.client = chromadb.PersistentClient(
path=settings.chroma_persist_dir,
settings=ChromaSettings(anonymized_telemetry=False)
)
self.collection = self.client.get_or_create_collection(
name="chat_history",
metadata={"hnsw:space": "cosine"}
)
try:
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
except Exception as e:
logger.warning(f"Не удалось загрузить модель эмбеддингов: {e}")
self.embedding_model = None
def _get_chat_id(self, chat_id: int) -> str:
return str(chat_id)
def add_message(self, chat_id: int, role: str, content: str):
doc_id = f"{chat_id}_{self.collection.count()}"
self.collection.add(
documents=[content],
ids=[doc_id],
metadatas=[{"chat_id": str(chat_id), "role": role}]
)
logger.info(f"Добавлено сообщение в чат {chat_id}: {role}")
def get_recent_messages(self, chat_id: int, limit: int = None) -> List[Dict]:
if limit is None:
limit = settings.memory_messages_count
chat_id_str = self._get_chat_id(chat_id)
results = self.collection.get(
where={"chat_id": chat_id_str}
)
if not results or not results.get("ids"):
return []
messages = []
for i, doc_id in enumerate(results["ids"]):
messages.append({
"role": results["metadatas"][i].get("role", "user"),
"content": results["documents"][i]
})
messages.sort(key=lambda x: x.get("timestamp", 0))
return messages[-limit:]
def search_similar(self, chat_id: int, query: str, limit: int = 3) -> List[str]:
if not self.embedding_model:
return []
try:
query_embedding = self.embedding_model.encode([query])
results = self.collection.query(
query_embeddings=query_embedding.tolist(),
n_results=limit,
where={"chat_id": str(chat_id)}
)
if results and results.get("documents"):
return results["documents"][0]
except Exception as e:
logger.error(f"Ошибка поиска: {e}")
return []
def clear_chat(self, chat_id: int):
chat_id_str = self._get_chat_id(chat_id)
results = self.collection.get(where={"chat_id": chat_id_str})
if results and results.get("ids"):
self.collection.delete(ids=results["ids"])
logger.info(f"Очищена история чата {chat_id}")
def get_context_for_prompt(self, chat_id: int) -> str:
messages = self.get_recent_messages(chat_id)
if not messages:
return ""
context_parts = []
for msg in messages:
role = "Пользователь" if msg["role"] == "user" else "Ассистент"
context_parts.append(f"{role}: {msg['content']}")
return "\n".join(context_parts)