206 lines
7.8 KiB
Python
206 lines
7.8 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from typing import Any
|
||
|
||
from app.core.contracts import MemoryEntry
|
||
from app.memory.interface import MemoryInterface
|
||
from app.models.async_adapters import AsyncOrchestratorAdapter
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
RECALL_PROMPT_TEMPLATE = """Определи, нужно ли искать в долговременной памяти для ответа на этот запрос.
|
||
|
||
Запрос: "{task_input}"
|
||
|
||
ИСКАТЬ в памяти если запрос:
|
||
- Содержит вопрос о пользователе (имя, предпочтения, история)
|
||
- Содержит отсылки к прошлым разговорам или действиям
|
||
- Содержит местоимения без контекста ("он", "это", "тот файл")
|
||
- Просит вспомнить, повторить, рассказать о прошлом
|
||
- Спрашивает "что ты помнишь", "как меня зовут", "что я говорил"
|
||
|
||
НЕ ИСКАТЬ если:
|
||
- Приветствие или прощание
|
||
- Простая команда (ls, pwd, echo)
|
||
- Общий вопрос не связанный с прошлым
|
||
|
||
Ответь ТОЛЬКО JSON:
|
||
{{"should_recall": true, "search_query": "поисковый запрос"}}
|
||
или
|
||
{{"should_recall": false, "reason": "краткая причина"}}"""
|
||
|
||
|
||
class MemoryRecallService:
|
||
"""Активное воспоминание: система сама решает, что и когда искать в памяти."""
|
||
|
||
def __init__(
|
||
self,
|
||
memory_interface: MemoryInterface | None,
|
||
recall_model: AsyncOrchestratorAdapter | None,
|
||
) -> None:
|
||
self._memory = memory_interface
|
||
self._model = recall_model
|
||
|
||
async def recall(
|
||
self,
|
||
task_input: str,
|
||
top_k: int = 5,
|
||
) -> dict[str, Any]:
|
||
"""
|
||
Определяет необходимость воспоминания и выполняет поиск.
|
||
|
||
Возвращает:
|
||
{
|
||
"should_recall": bool,
|
||
"reason": str,
|
||
"query": str,
|
||
"results": list[MemoryEntry],
|
||
"summary": str, # краткая сводка для оркестратора
|
||
}
|
||
"""
|
||
if not self._memory or not self._model:
|
||
with open("/tmp/recall_debug.log", "a") as f:
|
||
f.write(f"SKIP: memory={self._memory is not None}, model={self._model is not None}\n")
|
||
return self._empty_result("memory_or_model_unavailable")
|
||
|
||
# 1. LLM решает, нужно ли искать
|
||
decision = await self._classify(task_input)
|
||
with open("/tmp/recall_debug.log", "a") as f:
|
||
f.write(f"DECISION type={type(decision)} value={decision}\n")
|
||
if not isinstance(decision, dict):
|
||
return self._empty_result("invalid_decision_type")
|
||
if not decision.get("should_recall"):
|
||
return self._empty_result(decision.get("reason", "not_needed"))
|
||
|
||
search_query = decision.get("search_query", task_input)
|
||
logger.info(f"Memory recall: query='{search_query}', reason='{decision.get('reason')}'")
|
||
|
||
# 2. Векторный поиск
|
||
try:
|
||
raw_results = self._memory.search(query=search_query, top_k=top_k)
|
||
except Exception as e:
|
||
logger.warning(f"Memory search failed: {e}")
|
||
return self._empty_result("search_failed")
|
||
|
||
# 3. Фильтрация: убираем пустые и слишком нерелевантные
|
||
filtered = self._filter(raw_results)
|
||
|
||
if not filtered:
|
||
return self._empty_result("no_relevant_results")
|
||
|
||
# 4. Сводка для оркестратора
|
||
summary = self._summarize(filtered, search_query)
|
||
|
||
return {
|
||
"should_recall": True,
|
||
"reason": decision.get("reason", ""),
|
||
"query": search_query,
|
||
"results": filtered,
|
||
"summary": summary,
|
||
}
|
||
|
||
async def _classify(self, task_input: str) -> dict[str, Any]:
|
||
"""LLM-классификация: нужно ли искать в памяти."""
|
||
prompt = RECALL_PROMPT_TEMPLATE.format(task_input=task_input)
|
||
|
||
try:
|
||
raw = await self._model.generate(prompt, max_tokens=512)
|
||
data = self._parse_json(raw)
|
||
if "should_recall" in data:
|
||
return data
|
||
logger.warning(f"Recall classification missing 'should_recall': {raw[:200]}")
|
||
return {"should_recall": False, "reason": "parse_error"}
|
||
except Exception as e:
|
||
logger.warning(f"Recall classification failed: {e}")
|
||
return {"should_recall": False, "reason": "classification_error"}
|
||
|
||
def _filter(
|
||
self,
|
||
results: list[tuple[MemoryEntry, float]],
|
||
min_score: float = 0.3,
|
||
) -> list[MemoryEntry]:
|
||
"""Фильтрует результаты по score и убирает дубликаты."""
|
||
seen_texts: set[str] = set()
|
||
filtered: list[MemoryEntry] = []
|
||
|
||
for entry, score in results:
|
||
if score < min_score:
|
||
continue
|
||
# Нормализуем текст для дедупликации
|
||
normalized = entry.text.strip().lower()[:100]
|
||
if normalized in seen_texts:
|
||
continue
|
||
seen_texts.add(normalized)
|
||
filtered.append(entry)
|
||
|
||
return filtered
|
||
|
||
def _summarize(
|
||
self,
|
||
results: list[MemoryEntry],
|
||
query: str,
|
||
) -> str:
|
||
"""Краткая сводка найденного для оркестратора."""
|
||
parts = [f"По запросу '{query}' найдено {len(results)} записей:"]
|
||
for i, entry in enumerate(results[:5], 1):
|
||
text_preview = entry.text[:120].replace("\n", " ")
|
||
parts.append(f" {i}. [{entry.kind}] {text_preview}")
|
||
return "\n".join(parts)
|
||
|
||
def _parse_json(self, raw: str) -> dict[str, Any]:
|
||
"""Извлекает JSON из ответа модели, пропуская рассуждения перед ним."""
|
||
try:
|
||
json_start = raw.find("{")
|
||
json_end = raw.rfind("}") + 1
|
||
|
||
if json_start < 0 or json_end <= 0:
|
||
return {}
|
||
|
||
# Пробуем весь текст от первого { до последнего }
|
||
try:
|
||
data = json.loads(raw[json_start:json_end])
|
||
if isinstance(data, dict):
|
||
return data
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
# Ищем все возможные начала JSON
|
||
candidates = []
|
||
pos = 0
|
||
while True:
|
||
pos = raw.find("{", pos)
|
||
if pos < 0:
|
||
break
|
||
candidates.append(pos)
|
||
pos += 1
|
||
|
||
# Пробуем каждый candidate с конца
|
||
for start in reversed(candidates):
|
||
end = raw.rfind("}") + 1
|
||
if end <= start:
|
||
continue
|
||
try:
|
||
data = json.loads(raw[start:end])
|
||
if isinstance(data, dict):
|
||
return data
|
||
except json.JSONDecodeError:
|
||
continue
|
||
|
||
return {}
|
||
except Exception as e:
|
||
with open("/tmp/recall_debug.log", "a") as f:
|
||
f.write(f"PARSE ERROR: {e}\n")
|
||
return {}
|
||
|
||
@staticmethod
|
||
def _empty_result(reason: str) -> dict[str, Any]:
|
||
return {
|
||
"should_recall": False,
|
||
"reason": reason,
|
||
"query": "",
|
||
"results": [],
|
||
"summary": "",
|
||
}
|