Этап 3: Память на ChromaDB
This commit is contained in:
parent
85e702ce25
commit
6c2f17e37a
|
|
@ -4,6 +4,7 @@ from telegram import Update
|
||||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
||||||
from config.config import get_settings
|
from config.config import get_settings
|
||||||
from src.tools.tool_runner import ToolRunner
|
from src.tools.tool_runner import ToolRunner
|
||||||
|
from src.memory.memory import Memory
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
|
@ -13,6 +14,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
tool_runner = ToolRunner()
|
tool_runner = ToolRunner()
|
||||||
|
memory = Memory()
|
||||||
|
|
||||||
|
|
||||||
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
|
|
@ -30,6 +32,7 @@ async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
"/help - Показать эту справку\n"
|
"/help - Показать эту справку\n"
|
||||||
"/qwen <текст> - Задать вопрос qwen-code\n"
|
"/qwen <текст> - Задать вопрос qwen-code\n"
|
||||||
"/open <текст> - Задать вопрос opencode\n"
|
"/open <текст> - Задать вопрос opencode\n"
|
||||||
|
"/forget - Очистить историю чата\n"
|
||||||
)
|
)
|
||||||
await update.message.reply_text(help_text)
|
await update.message.reply_text(help_text)
|
||||||
|
|
||||||
|
|
@ -40,8 +43,13 @@ async def qwen_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
await update.message.reply_text("Использование: /qwen <текст>")
|
await update.message.reply_text("Использование: /qwen <текст>")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
chat_id = update.effective_chat.id
|
||||||
|
memory.add_message(chat_id, "user", prompt)
|
||||||
|
|
||||||
await update.message.reply_text("Думаю...")
|
await update.message.reply_text("Думаю...")
|
||||||
result, success = await tool_runner.run_qwen(prompt)
|
result, success = await tool_runner.run_qwen(prompt)
|
||||||
|
|
||||||
|
memory.add_message(chat_id, "assistant", result)
|
||||||
await update.message.reply_text(result[:4096] if len(result) > 4096 else result)
|
await update.message.reply_text(result[:4096] if len(result) > 4096 else result)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -51,11 +59,22 @@ async def open_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
await update.message.reply_text("Использование: /open <текст>")
|
await update.message.reply_text("Использование: /open <текст>")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
chat_id = update.effective_chat.id
|
||||||
|
memory.add_message(chat_id, "user", prompt)
|
||||||
|
|
||||||
await update.message.reply_text("Думаю...")
|
await update.message.reply_text("Думаю...")
|
||||||
result, success = await tool_runner.run_opencode(prompt)
|
result, success = await tool_runner.run_opencode(prompt)
|
||||||
|
|
||||||
|
memory.add_message(chat_id, "assistant", result)
|
||||||
await update.message.reply_text(result[:4096] if len(result) > 4096 else result)
|
await update.message.reply_text(result[:4096] if len(result) > 4096 else result)
|
||||||
|
|
||||||
|
|
||||||
|
async def forget_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle the /forget command.

    Clears the stored history of the chat the command was sent from and
    then confirms the action to the user.
    """
    memory.clear_chat(update.effective_chat.id)
    await update.message.reply_text("История чата очищена.")
|
||||||
|
|
||||||
|
|
||||||
async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Reply to a message with the very same text it contained."""
    incoming = update.message
    await incoming.reply_text(incoming.text)
|
||||||
|
|
||||||
|
|
@ -74,6 +93,7 @@ def main():
|
||||||
application.add_handler(CommandHandler("help", help_command))
|
application.add_handler(CommandHandler("help", help_command))
|
||||||
application.add_handler(CommandHandler("qwen", qwen_command))
|
application.add_handler(CommandHandler("qwen", qwen_command))
|
||||||
application.add_handler(CommandHandler("open", open_command))
|
application.add_handler(CommandHandler("open", open_command))
|
||||||
|
application.add_handler(CommandHandler("forget", forget_command))
|
||||||
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo))
|
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo))
|
||||||
|
|
||||||
logger.info("Бот запущен")
|
logger.info("Бот запущен")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,102 @@
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
import chromadb
|
||||||
|
from chromadb.config import Settings as ChromaSettings
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from config.config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
|
||||||
|
class Memory:
    """Per-chat conversation history stored in a ChromaDB collection.

    Every message is saved as one document whose metadata carries the chat
    id (as a string), the author role and a creation timestamp.  The
    timestamp is what makes "recent messages" well defined: the store itself
    gives no ordering guarantee for ``get`` results.
    """

    def __init__(self):
        """Open the persistent collection and try to load the embedding model.

        A failure to load the sentence-transformers model is not fatal: it is
        logged and ``embedding_model`` is set to ``None``, which makes
        ``search_similar`` return an empty list.
        """
        self.client = chromadb.PersistentClient(
            path=settings.chroma_persist_dir,
            settings=ChromaSettings(anonymized_telemetry=False),
        )
        self.collection = self.client.get_or_create_collection(
            name="chat_history",
            metadata={"hnsw:space": "cosine"},
        )
        try:
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            # Best effort: history storage still works without the model.
            logger.warning(f"Не удалось загрузить модель эмбеддингов: {e}")
            self.embedding_model = None

    def _get_chat_id(self, chat_id: int) -> str:
        """Return the chat id in the string form used in document metadata."""
        return str(chat_id)

    def add_message(self, chat_id: int, role: str, content: str):
        """Store one message of a chat.

        Args:
            chat_id: Identifier of the chat the message belongs to.
            role: Author role stored in metadata (callers in this file use
                "user" and "assistant").
            content: Message text, stored as the document body.
        """
        # Local imports keep this block self-contained.
        import time
        import uuid

        # The id must stay unique after deletions.  The collection size
        # shrinks when a chat is cleared, so deriving ids from ``count()``
        # could reproduce an id that still exists; a random suffix cannot.
        doc_id = f"{chat_id}_{uuid.uuid4().hex}"
        self.collection.add(
            documents=[content],
            ids=[doc_id],
            metadatas=[{
                "chat_id": self._get_chat_id(chat_id),
                "role": role,
                # Needed by get_recent_messages to restore chronology.
                "timestamp": time.time(),
            }],
        )
        logger.info(f"Добавлено сообщение в чат {chat_id}: {role}")

    def get_recent_messages(self, chat_id: int, limit: Optional[int] = None) -> List[Dict]:
        """Return the latest messages of a chat, oldest first.

        Args:
            chat_id: Identifier of the chat.
            limit: Maximum number of messages to return.  ``None`` means
                ``settings.memory_messages_count``; zero or a negative value
                yields an empty list.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts ordered by the
            stored timestamp.  Records written without a timestamp are
            treated as the oldest and keep their storage order.
        """
        if limit is None:
            limit = settings.memory_messages_count
        if limit <= 0:
            # ``messages[-0:]`` would return the whole history.
            return []

        results = self.collection.get(
            where={"chat_id": self._get_chat_id(chat_id)}
        )
        if not results or not results.get("ids"):
            return []

        # sorted() is stable, so equal timestamps keep storage order.
        records = sorted(
            zip(results["metadatas"], results["documents"]),
            key=lambda pair: (pair[0] or {}).get("timestamp", 0),
        )
        messages = [
            {"role": (meta or {}).get("role", "user"), "content": doc}
            for meta, doc in records
        ]
        return messages[-limit:]

    def search_similar(self, chat_id: int, query: str, limit: int = 3) -> List[str]:
        """Return stored messages of a chat that are closest to ``query``.

        Args:
            chat_id: Identifier of the chat to search in.
            query: Text to embed and compare against stored documents.
            limit: Maximum number of documents to return.

        Returns:
            The matching document texts, or an empty list when the embedding
            model is unavailable, nothing matches, or the query fails (the
            failure is logged).
        """
        if not self.embedding_model:
            return []

        # NOTE(review): documents are embedded by the collection's own
        # embedding function on add, while the query is embedded here;
        # confirm that both produce compatible vectors.
        try:
            query_embedding = self.embedding_model.encode([query])
            results = self.collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=limit,
                where={"chat_id": self._get_chat_id(chat_id)},
            )
            if results and results.get("documents"):
                # One query was sent, so the first row holds its matches.
                return results["documents"][0]
        except Exception as e:
            logger.error(f"Ошибка поиска: {e}")

        return []

    def clear_chat(self, chat_id: int):
        """Delete every stored message of the given chat."""
        results = self.collection.get(
            where={"chat_id": self._get_chat_id(chat_id)}
        )
        if results and results.get("ids"):
            self.collection.delete(ids=results["ids"])
            logger.info(f"Очищена история чата {chat_id}")

    def get_context_for_prompt(self, chat_id: int) -> str:
        """Render the recent history of a chat as prompt text.

        Returns:
            One line per message in the form "<speaker>: <text>", joined with
            newlines, or an empty string when the chat has no history.
        """
        messages = self.get_recent_messages(chat_id)
        if not messages:
            return ""

        return "\n".join(
            f"{'Пользователь' if msg['role'] == 'user' else 'Ассистент'}: {msg['content']}"
            for msg in messages
        )
|
||||||
Loading…
Reference in New Issue