ducklm/app/memory/interface.py

145 lines
4.6 KiB
Python

from __future__ import annotations
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal
import numpy as np
from app.core.contracts import MemoryEntry
from app.memory.store import MemoryStore
from app.memory.vector_index import VectorIndex
from app.models.embeddings import EmbeddingsAdapter
class MemoryInterface:
    """Facade over a memory store, a vector index, and an embeddings adapter.

    Writes go to both the store (entry + float32 embedding bytes) and the
    vector index; deletions and evictions remove the id from both so the two
    stay in step.
    """

    def __init__(
        self,
        store: MemoryStore,
        vector_index: VectorIndex,
        embeddings: EmbeddingsAdapter,
    ) -> None:
        self._store = store
        self._vector_index = vector_index
        self._embeddings = embeddings

    def insert(
        self,
        text: str,
        kind: Literal["tool_result", "plan", "critique", "fact", "summary", "user_preference"],
        source: Literal["tool", "critic", "user", "system"],
        task_id: str | None = None,
        session_id: str | None = None,
        weight: float = 0.5,
        metadata: dict[str, Any] | None = None,
    ) -> MemoryEntry:
        """Embed *text*, persist it as a new memory, and add it to the index.

        The embedding is written to the store as float32 bytes and inserted
        into the vector index, which is then saved. Afterwards ``cleanup()``
        runs with its default limits and may evict low-ranked entries.

        Returns:
            The newly created entry.
        """
        entry = MemoryEntry(
            text=text,
            kind=kind,
            source=source,
            weight=weight,
            task_id=task_id,
            session_id=session_id,
            metadata=metadata or {},
            embedding_model=self._embeddings.__class__.__name__,
            embedding_dim=self._embeddings.embedding_dim,
        )
        embedding = self._embeddings.encode(text)
        embedding_bytes = embedding.astype("float32").tobytes()
        self._store.insert(entry, embedding_bytes)
        self._vector_index.insert(entry.id, embedding)
        self._vector_index.save()
        self.cleanup()
        return entry

    def search(
        self,
        query: str,
        top_k: int = 5,
        kind: str | None = None,
        session_id: str | None = None,
    ) -> list[tuple[MemoryEntry, float]]:
        """Return up to *top_k* stored entries most similar to *query*.

        The vector index is asked for *top_k* candidates. Candidates are
        dropped if their store row is missing or if they fail the optional
        *kind* / *session_id* filters. Filtering happens after retrieval, so
        fewer than *top_k* results may be returned when filters are given.

        Returns:
            ``(entry, score)`` pairs in the order produced by the index.
        """
        if top_k <= 0:
            return []
        query_embedding = self._embeddings.encode(query)
        memory_ids, scores = self._vector_index.search(query_embedding, k=top_k)
        results: list[tuple[MemoryEntry, float]] = []
        for memory_id, score in zip(memory_ids, scores):
            entry = self._store.get(memory_id)
            if entry is None:
                continue
            if kind and entry.kind != kind:
                continue
            if session_id and entry.session_id != session_id:
                continue
            # Coerce index scores (possibly numpy scalars) to the annotated float.
            results.append((entry, float(score)))
        return results[:top_k]

    def get(self, memory_id: str) -> MemoryEntry | None:
        """Return the stored entry for *memory_id*, or ``None``."""
        return self._store.get(memory_id)

    def delete(self, memory_id: str) -> bool:
        """Delete a memory from both the vector index and the store.

        Returns:
            The store's deletion result, or ``False`` if the id is unknown.
        """
        if self._store.get(memory_id) is None:
            return False
        self._vector_index.delete(memory_id)
        # Persist the index after removal, mirroring insert().
        self._vector_index.save()
        return self._store.delete(memory_id)

    def get_by_task(self, task_id: str) -> list[MemoryEntry]:
        """Return entries associated with *task_id*."""
        return self._store.get_by_task(task_id)

    def get_by_session(self, session_id: str, limit: int = 100) -> list[MemoryEntry]:
        """Return up to *limit* entries associated with *session_id*."""
        return self._store.get_by_session(session_id, limit)

    def get_recent(self, limit: int = 10) -> list[MemoryEntry]:
        """Return up to *limit* entries via the store's ``get_all``.

        NOTE(review): ordering is whatever ``MemoryStore.get_all`` returns;
        presumably most-recent first given the name — confirm.
        """
        return self._store.get_all(limit)

    def count(self) -> int:
        """Return the number of entries in the store."""
        return self._store.count()

    def reindex(self) -> None:
        """Re-encode every stored entry and insert it into the vector index.

        All entries are read (not a fixed cap), and the index is saved once at
        the end.

        NOTE(review): entries are inserted into the existing index without
        clearing it; whether repeated ids overwrite or duplicate depends on
        ``VectorIndex.insert`` — confirm.
        """
        total = self._store.count()
        entries = self._store.get_all(limit=total) if total > 0 else []
        for entry in entries:
            embedding = self._embeddings.encode(entry.text)
            self._vector_index.insert(entry.id, embedding)
        self._vector_index.save()

    def close(self) -> None:
        """Close the underlying store (the vector index is not closed here)."""
        self._store.close()

    def cleanup(self, max_items: int = 750, decay_factor: float = 0.95) -> int:
        """Evict the lowest-ranked entries when the store exceeds *max_items*.

        Entries are ranked by an effective weight: the stored weight times
        ``max(0.1, decay_factor ** age_in_days)``. Entries without a
        ``created_at`` use their raw weight. Stored weights are NOT modified;
        the decay is used only for ranking. Evicted entries are removed from
        both the store and the vector index, and the index is saved if
        anything was evicted.

        Args:
            max_items: Maximum number of entries to keep.
            decay_factor: Per-day multiplicative freshness decay.

        Returns:
            Number of entries the store reported as deleted.
        """
        current_count = self._store.count()
        if current_count <= max_items:
            return 0
        entries_to_remove = current_count - max_items
        all_entries = self._store.get_all(limit=current_count)
        now = datetime.now(timezone.utc)  # computed once for a consistent ranking

        def effective_weight(entry: MemoryEntry) -> float:
            created = entry.created_at
            if not created:
                return entry.weight
            if created.tzinfo is None:
                # Assumption: naive timestamps are UTC. Mixing naive and aware
                # datetimes in subtraction would raise TypeError.
                created = created.replace(tzinfo=timezone.utc)
            # Clamp at 0 so a future timestamp (clock skew) cannot boost weight.
            age_days = max(0.0, (now - created).total_seconds() / 86400)
            return entry.weight * max(0.1, decay_factor ** age_days)

        victims = sorted(all_entries, key=effective_weight)[:entries_to_remove]
        removed = 0
        for entry in victims:
            # Delete the store row first: if the index delete then fails, the
            # next cleanup() will not retry the same entry forever.
            if self._store.delete(entry.id):
                removed += 1
            # Without this, search() keeps returning evicted ids and wastes
            # top_k slots on entries that no longer exist in the store.
            self._vector_index.delete(entry.id)
        if victims:
            self._vector_index.save()
        return removed