Этап 3: Память на ChromaDB
This commit is contained in:
parent
85e702ce25
commit
6c2f17e37a
|
|
@ -4,6 +4,7 @@ from telegram import Update
|
|||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
||||
from config.config import get_settings
|
||||
from src.tools.tool_runner import ToolRunner
|
||||
from src.memory.memory import Memory
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
|
|
@ -13,6 +14,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
settings = get_settings()
|
||||
tool_runner = ToolRunner()
|
||||
memory = Memory()
|
||||
|
||||
|
||||
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
|
|
@ -30,6 +32,7 @@ async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
|||
"/help - Показать эту справку\n"
|
||||
"/qwen <текст> - Задать вопрос qwen-code\n"
|
||||
"/open <текст> - Задать вопрос opencode\n"
|
||||
"/forget - Очистить историю чата\n"
|
||||
)
|
||||
await update.message.reply_text(help_text)
|
||||
|
||||
|
|
@ -40,8 +43,13 @@ async def qwen_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
|||
await update.message.reply_text("Использование: /qwen <текст>")
|
||||
return
|
||||
|
||||
chat_id = update.effective_chat.id
|
||||
memory.add_message(chat_id, "user", prompt)
|
||||
|
||||
await update.message.reply_text("Думаю...")
|
||||
result, success = await tool_runner.run_qwen(prompt)
|
||||
|
||||
memory.add_message(chat_id, "assistant", result)
|
||||
await update.message.reply_text(result[:4096] if len(result) > 4096 else result)
|
||||
|
||||
|
||||
|
|
@ -51,11 +59,22 @@ async def open_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
|||
await update.message.reply_text("Использование: /open <текст>")
|
||||
return
|
||||
|
||||
chat_id = update.effective_chat.id
|
||||
memory.add_message(chat_id, "user", prompt)
|
||||
|
||||
await update.message.reply_text("Думаю...")
|
||||
result, success = await tool_runner.run_opencode(prompt)
|
||||
|
||||
memory.add_message(chat_id, "assistant", result)
|
||||
await update.message.reply_text(result[:4096] if len(result) > 4096 else result)
|
||||
|
||||
|
||||
async def forget_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle /forget: wipe the stored history of the current chat.

    The chat is identified by ``update.effective_chat.id``; after the
    history is cleared a confirmation message is sent back to the user.
    """
    memory.clear_chat(update.effective_chat.id)
    confirmation = "История чата очищена."
    await update.message.reply_text(confirmation)
|
||||
|
||||
|
||||
async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Reply to a plain text message with the very same text."""
    incoming = update.message
    await incoming.reply_text(incoming.text)
|
||||
|
||||
|
|
@ -74,6 +93,7 @@ def main():
|
|||
application.add_handler(CommandHandler("help", help_command))
|
||||
application.add_handler(CommandHandler("qwen", qwen_command))
|
||||
application.add_handler(CommandHandler("open", open_command))
|
||||
application.add_handler(CommandHandler("forget", forget_command))
|
||||
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo))
|
||||
|
||||
logger.info("Бот запущен")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,102 @@
|
|||
import logging
|
||||
from typing import List, Dict, Optional
|
||||
import chromadb
|
||||
from chromadb.config import Settings as ChromaSettings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from config.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class Memory:
    """Per-chat message history stored in a persistent ChromaDB collection.

    Every message is stored as one document whose metadata carries the
    chat id (as a string), the author role and a creation timestamp.
    """

    def __init__(self):
        """Open the persistent collection and try to load the embedding model.

        A failure to load the sentence-transformers model is not fatal:
        it is logged and ``search_similar`` then returns no results.
        """
        self.client = chromadb.PersistentClient(
            path=settings.chroma_persist_dir,
            settings=ChromaSettings(anonymized_telemetry=False),
        )
        self.collection = self.client.get_or_create_collection(
            name="chat_history",
            metadata={"hnsw:space": "cosine"},
        )
        try:
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            # Best effort: history storage still works without the model.
            logger.warning(f"Не удалось загрузить модель эмбеддингов: {e}")
            self.embedding_model = None

    def _get_chat_id(self, chat_id: int) -> str:
        """Return the string form of *chat_id* used in ids and metadata."""
        return str(chat_id)

    def add_message(self, chat_id: int, role: str, content: str):
        """Store one message of *role* with text *content* for *chat_id*.

        Args:
            chat_id: Identifier of the chat the message belongs to.
            role: Author of the message (the callers use "user"/"assistant").
            content: Message text, stored as the document body.
        """
        # Function-scope imports keep this fix self-contained without
        # touching the module's import block.
        import time
        import uuid

        chat_id_str = self._get_chat_id(chat_id)
        # The id must not be derived from collection.count(): the count is
        # global and shrinks after clear_chat(), so a later message could
        # reuse an id that still exists. A random UUID cannot collide.
        doc_id = f"{chat_id_str}_{uuid.uuid4().hex}"
        self.collection.add(
            documents=[content],
            ids=[doc_id],
            metadatas=[{
                "chat_id": chat_id_str,
                "role": role,
                # Needed by get_recent_messages() to restore chronology.
                "timestamp": time.time(),
            }],
        )
        logger.info(f"Добавлено сообщение в чат {chat_id}: {role}")

    def get_recent_messages(self, chat_id: int, limit: Optional[int] = None) -> List[Dict]:
        """Return the newest *limit* messages of a chat, oldest first.

        Args:
            chat_id: Identifier of the chat.
            limit: Maximum number of messages; when ``None`` the value of
                ``settings.memory_messages_count`` is used. A non-positive
                limit yields an empty list.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts ordered by the
            stored timestamp. Records without a timestamp sort first and
            keep the order in which the collection returned them.
        """
        if limit is None:
            limit = settings.memory_messages_count
        if limit <= 0:
            # messages[-0:] would be the whole history, not "nothing".
            return []

        results = self.collection.get(
            where={"chat_id": self._get_chat_id(chat_id)}
        )
        if not results or not results.get("ids"):
            return []

        documents = results["documents"]
        metadatas = results["metadatas"]
        ordered = []
        for position in range(len(results["ids"])):
            meta = metadatas[position] or {}
            message = {
                "role": meta.get("role", "user"),
                "content": documents[position],
            }
            # The position is a tie-breaker that also keeps the sort from
            # ever comparing the message dicts themselves.
            ordered.append((meta.get("timestamp", 0), position, message))

        ordered.sort(key=lambda item: (item[0], item[1]))
        return [message for _, _, message in ordered[-limit:]]

    def search_similar(self, chat_id: int, query: str, limit: int = 3) -> List[str]:
        """Return up to *limit* stored documents of the chat closest to *query*.

        Returns an empty list when the embedding model is unavailable, when
        nothing matches, or when the lookup fails (the error is logged).
        """
        if not self.embedding_model:
            return []

        try:
            # NOTE(review): documents are embedded by the collection's own
            # embedding function while the query is embedded here; confirm
            # both produce vectors from the same model.
            query_embedding = self.embedding_model.encode([query])
            results = self.collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=limit,
                where={"chat_id": self._get_chat_id(chat_id)},
            )
            if results and results.get("documents"):
                # One query was sent, so the first result list is ours.
                return results["documents"][0]
        except Exception as e:
            logger.error(f"Ошибка поиска: {e}")

        return []

    def clear_chat(self, chat_id: int):
        """Delete every stored message that belongs to *chat_id*."""
        results = self.collection.get(
            where={"chat_id": self._get_chat_id(chat_id)}
        )
        if results and results.get("ids"):
            self.collection.delete(ids=results["ids"])
        logger.info(f"Очищена история чата {chat_id}")

    def get_context_for_prompt(self, chat_id: int) -> str:
        """Render the recent history of a chat as prompt text.

        Returns:
            One "<speaker>: <text>" line per message joined with newlines,
            or an empty string when the chat has no stored messages.
        """
        messages = self.get_recent_messages(chat_id)
        if not messages:
            return ""

        context_parts = []
        for msg in messages:
            role = "Пользователь" if msg["role"] == "user" else "Ассистент"
            context_parts.append(f"{role}: {msg['content']}")

        return "\n".join(context_parts)
|
||||
Loading…
Reference in New Issue