ducklm/app/memory/recall.py

206 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import logging
from typing import Any
from app.core.contracts import MemoryEntry
from app.memory.interface import MemoryInterface
from app.models.async_adapters import AsyncOrchestratorAdapter
logger = logging.getLogger(__name__)
RECALL_PROMPT_TEMPLATE = """Определи, нужно ли искать в долговременной памяти для ответа на этот запрос.
Запрос: "{task_input}"
ИСКАТЬ в памяти если запрос:
- Содержит вопрос о пользователе (имя, предпочтения, история)
- Содержит отсылки к прошлым разговорам или действиям
- Содержит местоимения без контекста ("он", "это", "тот файл")
- Просит вспомнить, повторить, рассказать о прошлом
- Спрашивает "что ты помнишь", "как меня зовут", "что я говорил"
НЕ ИСКАТЬ если:
- Приветствие или прощание
- Простая команда (ls, pwd, echo)
- Общий вопрос не связанный с прошлым
Ответь ТОЛЬКО JSON:
{{"should_recall": true, "search_query": "поисковый запрос"}}
или
{{"should_recall": false, "reason": "краткая причина"}}"""
class MemoryRecallService:
"""Активное воспоминание: система сама решает, что и когда искать в памяти."""
def __init__(
self,
memory_interface: MemoryInterface | None,
recall_model: AsyncOrchestratorAdapter | None,
) -> None:
self._memory = memory_interface
self._model = recall_model
async def recall(
self,
task_input: str,
top_k: int = 5,
) -> dict[str, Any]:
"""
Определяет необходимость воспоминания и выполняет поиск.
Возвращает:
{
"should_recall": bool,
"reason": str,
"query": str,
"results": list[MemoryEntry],
"summary": str, # краткая сводка для оркестратора
}
"""
if not self._memory or not self._model:
with open("/tmp/recall_debug.log", "a") as f:
f.write(f"SKIP: memory={self._memory is not None}, model={self._model is not None}\n")
return self._empty_result("memory_or_model_unavailable")
# 1. LLM решает, нужно ли искать
decision = await self._classify(task_input)
with open("/tmp/recall_debug.log", "a") as f:
f.write(f"DECISION type={type(decision)} value={decision}\n")
if not isinstance(decision, dict):
return self._empty_result("invalid_decision_type")
if not decision.get("should_recall"):
return self._empty_result(decision.get("reason", "not_needed"))
search_query = decision.get("search_query", task_input)
logger.info(f"Memory recall: query='{search_query}', reason='{decision.get('reason')}'")
# 2. Векторный поиск
try:
raw_results = self._memory.search(query=search_query, top_k=top_k)
except Exception as e:
logger.warning(f"Memory search failed: {e}")
return self._empty_result("search_failed")
# 3. Фильтрация: убираем пустые и слишком нерелевантные
filtered = self._filter(raw_results)
if not filtered:
return self._empty_result("no_relevant_results")
# 4. Сводка для оркестратора
summary = self._summarize(filtered, search_query)
return {
"should_recall": True,
"reason": decision.get("reason", ""),
"query": search_query,
"results": filtered,
"summary": summary,
}
async def _classify(self, task_input: str) -> dict[str, Any]:
"""LLM-классификация: нужно ли искать в памяти."""
prompt = RECALL_PROMPT_TEMPLATE.format(task_input=task_input)
try:
raw = await self._model.generate(prompt, max_tokens=512)
data = self._parse_json(raw)
if "should_recall" in data:
return data
logger.warning(f"Recall classification missing 'should_recall': {raw[:200]}")
return {"should_recall": False, "reason": "parse_error"}
except Exception as e:
logger.warning(f"Recall classification failed: {e}")
return {"should_recall": False, "reason": "classification_error"}
def _filter(
self,
results: list[tuple[MemoryEntry, float]],
min_score: float = 0.3,
) -> list[MemoryEntry]:
"""Фильтрует результаты по score и убирает дубликаты."""
seen_texts: set[str] = set()
filtered: list[MemoryEntry] = []
for entry, score in results:
if score < min_score:
continue
# Нормализуем текст для дедупликации
normalized = entry.text.strip().lower()[:100]
if normalized in seen_texts:
continue
seen_texts.add(normalized)
filtered.append(entry)
return filtered
def _summarize(
self,
results: list[MemoryEntry],
query: str,
) -> str:
"""Краткая сводка найденного для оркестратора."""
parts = [f"По запросу '{query}' найдено {len(results)} записей:"]
for i, entry in enumerate(results[:5], 1):
text_preview = entry.text[:120].replace("\n", " ")
parts.append(f" {i}. [{entry.kind}] {text_preview}")
return "\n".join(parts)
def _parse_json(self, raw: str) -> dict[str, Any]:
"""Извлекает JSON из ответа модели, пропуская рассуждения перед ним."""
try:
json_start = raw.find("{")
json_end = raw.rfind("}") + 1
if json_start < 0 or json_end <= 0:
return {}
# Пробуем весь текст от первого { до последнего }
try:
data = json.loads(raw[json_start:json_end])
if isinstance(data, dict):
return data
except json.JSONDecodeError:
pass
# Ищем все возможные начала JSON
candidates = []
pos = 0
while True:
pos = raw.find("{", pos)
if pos < 0:
break
candidates.append(pos)
pos += 1
# Пробуем каждый candidate с конца
for start in reversed(candidates):
end = raw.rfind("}") + 1
if end <= start:
continue
try:
data = json.loads(raw[start:end])
if isinstance(data, dict):
return data
except json.JSONDecodeError:
continue
return {}
except Exception as e:
with open("/tmp/recall_debug.log", "a") as f:
f.write(f"PARSE ERROR: {e}\n")
return {}
@staticmethod
def _empty_result(reason: str) -> dict[str, Any]:
return {
"should_recall": False,
"reason": reason,
"query": "",
"results": [],
"summary": "",
}