diff --git a/src/bot/main.py b/src/bot/main.py index 4686b18..18d6597 100644 --- a/src/bot/main.py +++ b/src/bot/main.py @@ -4,6 +4,7 @@ from telegram import Update from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from config.config import get_settings from src.tools.tool_runner import ToolRunner +from src.memory.memory import Memory logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", @@ -13,6 +14,7 @@ logger = logging.getLogger(__name__) settings = get_settings() tool_runner = ToolRunner() +memory = Memory() async def start(update: Update, context: ContextTypes.DEFAULT_TYPE): @@ -30,6 +32,7 @@ async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE): "/help - Показать эту справку\n" "/qwen <текст> - Задать вопрос qwen-code\n" "/open <текст> - Задать вопрос opencode\n" + "/forget - Очистить историю чата\n" ) await update.message.reply_text(help_text) @@ -40,8 +43,13 @@ async def qwen_command(update: Update, context: ContextTypes.DEFAULT_TYPE): await update.message.reply_text("Использование: /qwen <текст>") return + chat_id = update.effective_chat.id + memory.add_message(chat_id, "user", prompt) + await update.message.reply_text("Думаю...") result, success = await tool_runner.run_qwen(prompt) + + memory.add_message(chat_id, "assistant", result) await update.message.reply_text(result[:4096] if len(result) > 4096 else result) @@ -51,11 +59,22 @@ async def open_command(update: Update, context: ContextTypes.DEFAULT_TYPE): await update.message.reply_text("Использование: /open <текст>") return + chat_id = update.effective_chat.id + memory.add_message(chat_id, "user", prompt) + await update.message.reply_text("Думаю...") result, success = await tool_runner.run_opencode(prompt) + + memory.add_message(chat_id, "assistant", result) await update.message.reply_text(result[:4096] if len(result) > 4096 else result) +async def forget_command(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + memory.clear_chat(chat_id) + await update.message.reply_text("История чата очищена.") + + async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE): await update.message.reply_text(update.message.text) @@ -74,6 +93,7 @@ def main(): application.add_handler(CommandHandler("help", help_command)) application.add_handler(CommandHandler("qwen", qwen_command)) application.add_handler(CommandHandler("open", open_command)) + application.add_handler(CommandHandler("forget", forget_command)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo)) logger.info("Бот запущен") diff --git a/src/memory/__init__.py b/src/memory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/memory/memory.py b/src/memory/memory.py new file mode 100644 index 0000000..4c124fc --- /dev/null +++ b/src/memory/memory.py @@ -0,0 +1,102 @@ +import logging +from typing import List, Dict, Optional +import chromadb +from chromadb.config import Settings as ChromaSettings +from sentence_transformers import SentenceTransformer +from config.config import get_settings + +logger = logging.getLogger(__name__) +settings = get_settings() + + +class Memory: + def __init__(self): + self.client = chromadb.PersistentClient( + path=settings.chroma_persist_dir, + settings=ChromaSettings(anonymized_telemetry=False) + ) + self.collection = self.client.get_or_create_collection( + name="chat_history", + metadata={"hnsw:space": "cosine"} + ) + try: + self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') + except Exception as e: + logger.warning(f"Не удалось загрузить модель эмбеддингов: {e}") + self.embedding_model = None + + def _get_chat_id(self, chat_id: int) -> str: + return str(chat_id) + + def add_message(self, chat_id: int, role: str, content: str): + doc_id = f"{chat_id}_{self.collection.count()}" + self.collection.add( + documents=[content], + ids=[doc_id], + metadatas=[{"chat_id": str(chat_id), "role": role}] + ) + logger.info(f"Добавлено сообщение в чат {chat_id}: {role}") + + def get_recent_messages(self, chat_id: int, limit: int = None) -> List[Dict]: + if limit is None: + limit = settings.memory_messages_count + + chat_id_str = self._get_chat_id(chat_id) + + results = self.collection.get( + where={"chat_id": chat_id_str} + ) + + if not results or not results.get("ids"): + return [] + + messages = [] + for i, doc_id in enumerate(results["ids"]): + messages.append({ + "role": results["metadatas"][i].get("role", "user"), + "content": results["documents"][i] + }) + + messages.sort(key=lambda x: x.get("timestamp", 0)) + return messages[-limit:] + + def search_similar(self, chat_id: int, query: str, limit: int = 3) -> List[str]: + if not self.embedding_model: + return [] + + try: + query_embedding = self.embedding_model.encode([query]) + + results = self.collection.query( + query_embeddings=query_embedding.tolist(), + n_results=limit, + where={"chat_id": str(chat_id)} + ) + + if results and results.get("documents"): + return results["documents"][0] + except Exception as e: + logger.error(f"Ошибка поиска: {e}") + + return [] + + def clear_chat(self, chat_id: int): + chat_id_str = self._get_chat_id(chat_id) + + results = self.collection.get(where={"chat_id": chat_id_str}) + + if results and results.get("ids"): + self.collection.delete(ids=results["ids"]) + logger.info(f"Очищена история чата {chat_id}") + + def get_context_for_prompt(self, chat_id: int) -> str: + messages = self.get_recent_messages(chat_id) + if not messages: + return "" + + context_parts = [] + for msg in messages: + role = "Пользователь" if msg["role"] == "user" else "Ассистент" + context_parts.append(f"{role}: {msg['content']}") + + return "\n".join(context_parts)