from __future__ import annotations

import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal

import numpy as np

from app.core.contracts import MemoryEntry
from app.memory.store import MemoryStore
from app.memory.vector_index import VectorIndex
from app.models.embeddings import EmbeddingsAdapter


class MemoryInterface:
    """Facade over a MemoryStore, a VectorIndex and an EmbeddingsAdapter.

    The store holds ``MemoryEntry`` records together with their raw embedding
    bytes; the vector index holds embeddings keyed by entry id for similarity
    search; the embeddings adapter turns text into vectors. The mutating
    methods (``insert``, ``delete``, ``cleanup``) update both the store and
    the vector index.
    """

    def __init__(
        self,
        store: MemoryStore,
        vector_index: VectorIndex,
        embeddings: EmbeddingsAdapter,
    ) -> None:
        self._store = store
        self._vector_index = vector_index
        self._embeddings = embeddings

    def insert(
        self,
        text: str,
        kind: Literal["tool_result", "plan", "critique", "fact", "summary", "user_preference"],
        source: Literal["tool", "critic", "user", "system"],
        task_id: str | None = None,
        session_id: str | None = None,
        weight: float = 0.5,
        metadata: dict[str, Any] | None = None,
    ) -> MemoryEntry:
        """Embed *text*, persist it as a new memory entry and index it.

        The embedding is written to the store as float32 bytes and added to
        the vector index, which is saved immediately. ``cleanup()`` then runs
        with its default limits; when the store is over capacity the entries
        with the lowest effective weight are evicted, which can include the
        entry just inserted if its weight is low enough.

        Args:
            text: Content to remember; this is also the text that is embedded.
            kind: Category recorded on the entry.
            source: Origin recorded on the entry.
            task_id: Optional task id recorded on the entry.
            session_id: Optional session id recorded on the entry.
            weight: Base weight that ``cleanup()`` uses when ranking entries
                for eviction.
            metadata: Extra data stored on the entry; ``None`` becomes ``{}``.

        Returns:
            The newly created ``MemoryEntry``.
        """
        entry = MemoryEntry(
            text=text,
            kind=kind,
            source=source,
            weight=weight,
            task_id=task_id,
            session_id=session_id,
            metadata=metadata or {},
            # Record which adapter produced the vector so a model change can
            # be detected later (e.g. before calling reindex()).
            embedding_model=self._embeddings.__class__.__name__,
            embedding_dim=self._embeddings.embedding_dim,
        )
        embedding = self._embeddings.encode(text)
        embedding_bytes = embedding.astype("float32").tobytes()
        self._store.insert(entry, embedding_bytes)
        self._vector_index.insert(entry.id, embedding)
        self._vector_index.save()
        self.cleanup()
        return entry

    def search(
        self,
        query: str,
        top_k: int = 5,
        kind: str | None = None,
        session_id: str | None = None,
    ) -> list[tuple[MemoryEntry, float]]:
        """Return up to *top_k* stored entries most similar to *query*.

        The vector index is asked for ``top_k`` candidates first, and the
        ``kind`` / ``session_id`` filters are applied afterwards. A filtered
        search can therefore return fewer than ``top_k`` results even when
        more matching entries exist. Candidate ids that no longer have an
        entry in the store are skipped.

        Args:
            query: Text to embed and search with.
            top_k: Number of candidates requested from the vector index.
            kind: If given (and non-empty), keep only entries of this kind.
            session_id: If given (and non-empty), keep only entries from this
                session.

        Returns:
            ``(entry, score)`` pairs in the order returned by the vector index.
            The score is passed through from ``VectorIndex.search`` unchanged.
        """
        query_embedding = self._embeddings.encode(query)
        memory_ids, scores = self._vector_index.search(query_embedding, k=top_k)
        results: list[tuple[MemoryEntry, float]] = []
        for memory_id, score in zip(memory_ids, scores):
            entry = self._store.get(memory_id)
            if not entry:
                continue
            if kind and entry.kind != kind:
                continue
            if session_id and entry.session_id != session_id:
                continue
            results.append((entry, score))
        return results[:top_k]

    def get(self, memory_id: str) -> MemoryEntry | None:
        """Return the stored entry with id *memory_id*, or ``None``."""
        return self._store.get(memory_id)

    def delete(self, memory_id: str) -> bool:
        """Delete an entry from both the vector index and the store.

        The vector index is saved afterwards so the removal is persisted,
        consistent with ``insert()``, which saves after every change.

        Returns:
            ``False`` if no entry exists for *memory_id*; otherwise the result
            of ``MemoryStore.delete``.
        """
        entry = self._store.get(memory_id)
        if not entry:
            return False
        self._vector_index.delete(memory_id)
        deleted = self._store.delete(memory_id)
        self._vector_index.save()
        return deleted

    def get_by_task(self, task_id: str) -> list[MemoryEntry]:
        """Return the entries recorded for *task_id* (delegates to the store)."""
        return self._store.get_by_task(task_id)

    def get_by_session(self, session_id: str, limit: int = 100) -> list[MemoryEntry]:
        """Return up to *limit* entries for *session_id* (delegates to the store)."""
        return self._store.get_by_session(session_id, limit)

    def get_recent(self, limit: int = 10) -> list[MemoryEntry]:
        """Return up to *limit* entries from ``MemoryStore.get_all``.

        NOTE(review): "recent" relies on ``get_all`` returning newest entries
        first; confirm against the store implementation.
        """
        return self._store.get_all(limit)

    def count(self) -> int:
        """Return the number of entries in the store."""
        return self._store.count()

    def reindex(self, limit: int | None = None) -> None:
        """Re-embed stored entries and insert them into the vector index.

        Every entry's text is encoded again with the current embeddings
        adapter, inserted into the vector index, and the index is saved once
        at the end.

        Args:
            limit: Maximum number of entries to reindex. Defaults to the
                current store size, so every entry is covered. The previous
                hard-coded cap of 10000 silently skipped the rest.

        NOTE(review): the index is not cleared first. If
        ``VectorIndex.insert`` does not replace existing ids, vectors from
        earlier insertions may remain; confirm against the index
        implementation.
        """
        if limit is None:
            limit = self._store.count()
        for entry in self._store.get_all(limit=limit):
            embedding = self._embeddings.encode(entry.text)
            self._vector_index.insert(entry.id, embedding)
        self._vector_index.save()

    def close(self) -> None:
        """Close the underlying store; the vector index is not touched here."""
        self._store.close()

    def cleanup(self, max_items: int = 750, decay_factor: float = 0.95) -> int:
        """Evict the lowest-weighted entries once the store exceeds *max_items*.

        Each entry's effective weight is ``weight * max(0.1, decay_factor **
        age_days)``, with age measured from ``created_at`` and clamped at 0 so
        that future-dated entries are not boosted. Entries without
        ``created_at`` keep their raw weight. The ``count - max_items`` entries
        with the smallest effective weight are removed from both the vector
        index and the store. The index is saved if anything was removed.

        Args:
            max_items: Number of entries the store may hold before eviction.
            decay_factor: Per-day multiplicative decay applied to weights.

        Returns:
            Number of entries the store reported as deleted.
        """
        current_count = self._store.count()
        if current_count <= max_items:
            return 0

        entries_to_remove = current_count - max_items
        all_entries = self._store.get_all(limit=current_count)
        # One reference time for the whole ranking, so every entry is aged
        # against the same instant.
        now = datetime.now(timezone.utc)

        def effective_weight(entry: MemoryEntry) -> float:
            created_at = entry.created_at
            if not created_at:
                return entry.weight
            if created_at.tzinfo is None:
                # NOTE(review): naive timestamps are assumed to be UTC;
                # subtracting a naive from an aware datetime would raise.
                created_at = created_at.replace(tzinfo=timezone.utc)
            age_days = max(0.0, (now - created_at).total_seconds() / 86400)
            freshness_factor = max(0.1, decay_factor ** age_days)
            return entry.weight * freshness_factor

        removed = 0
        for entry in sorted(all_entries, key=effective_weight)[:entries_to_remove]:
            # Remove from the index as well, otherwise stale ids keep
            # occupying top_k slots in search().
            self._vector_index.delete(entry.id)
            if self._store.delete(entry.id):
                removed += 1
        if removed:
            self._vector_index.save()
        return removed