from __future__ import annotations

from typing import Any, ClassVar, Literal

from app.core.contracts import CriticScore, MemoryEntry


class MemoryWritePolicy:
    """Decide whether a candidate memory entry should be persisted.

    Combines hard gates (safety, the critic's ``memory_store`` flag, a
    minimum usefulness, a per-session entry cap), duplicate handling, and
    per-memory-type score tiers into one of four outcomes: ``"store"``,
    ``"store_with_weight"``, ``"skip"`` or ``"merge"``.
    """

    # Additive weight boost per memory type, consumed by ``_adjust_weight``.
    # Kept at class level so the table is not rebuilt on every call; types
    # not listed here receive no boost.
    _TYPE_BOOST: ClassVar[dict[str, float]] = {
        "fact": 0.15,
        "plan": 0.1,
        "summary": 0.1,
        "user_preference": 0.2,
        "tool_result": 0.05,
        "critique": 0.05,
    }

    def __init__(
        self,
        store_threshold: float = 0.7,
        min_usefulness: float = 0.3,
        max_entries_per_session: int = 50,
    ) -> None:
        """Configure the policy thresholds.

        Args:
            store_threshold: Bar used in two places. ``decide`` compares the
                critic weight (raw and boosted) against it. ``_evaluate_scores``
                uses it as the "store" tier for ``tool_result``/``critique``.
            min_usefulness: Entries whose critic usefulness is below this
                value are skipped.
            max_entries_per_session: When a session id is given and
                ``current_session_count`` has reached this value, the entry
                is skipped.
        """
        self._store_threshold = store_threshold
        self._min_usefulness = min_usefulness
        self._max_entries_per_session = max_entries_per_session

    def decide(
        self,
        critic_score: CriticScore,
        memory_type: MemoryEntry.Kind,
        session_id: str | None = None,
        has_duplicate: bool = False,
        current_session_count: int = 0,
    ) -> Literal["store", "store_with_weight", "skip", "merge"]:
        """Return the write decision for one candidate memory entry.

        Checks run in a fixed order and the first one that fires wins:

        1. ``critic_score.safety < 0.5`` -> ``"skip"`` (even for duplicates).
        2. ``has_duplicate`` -> ``"merge"``. This runs before the store-flag
           and usefulness gates, so any duplicate that passed the safety
           gate is merged.
        3. Falsy ``critic_score.memory_store`` -> ``"skip"``.
        4. Usefulness below ``min_usefulness`` -> ``"skip"``.
        5. Truthy ``session_id`` with ``current_session_count`` at or above
           ``max_entries_per_session`` -> ``"skip"``.
        6. Otherwise, the per-type tier from ``_evaluate_scores``. A
           ``"store"`` tier may become ``"store_with_weight"`` through the
           weight check below.

        Args:
            critic_score: Critic output. Reads ``safety``, ``memory_store``,
                ``usefulness``, ``correctness`` and ``weight``.
            memory_type: Kind of memory. It is compared against the string
                literals ``"fact"``, ``"plan"``, ``"summary"``,
                ``"tool_result"``, ``"critique"`` and ``"user_preference"``.
            session_id: When truthy, the per-session cap is enforced.
            has_duplicate: Caller-supplied flag. When True, the entry is
                merged (provided it passed the safety gate).
            current_session_count: Caller-supplied count compared against
                ``max_entries_per_session``. Presumably this is the number of
                entries already written for ``session_id``; confirm with
                callers.

        Returns:
            One of ``"store"``, ``"store_with_weight"``, ``"skip"``,
            ``"merge"``.
        """
        if critic_score.safety < 0.5:
            return "skip"
        if has_duplicate:
            return "merge"
        if not critic_score.memory_store:
            return "skip"
        if critic_score.usefulness < self._min_usefulness:
            return "skip"
        if session_id and current_session_count >= self._max_entries_per_session:
            return "skip"

        base_decision = self._evaluate_scores(critic_score, memory_type)

        if base_decision == "store" and critic_score.weight < self._store_threshold:
            # NOTE(review): this branch looks inverted. A "store" entry whose
            # boosted weight reaches the threshold is turned into
            # "store_with_weight" (the tier _evaluate_scores gives mid-range
            # scores). An entry whose boosted weight still falls short keeps
            # "store". Current behavior is preserved; confirm intent.
            if self._adjust_weight(critic_score, memory_type) >= self._store_threshold:
                return "store_with_weight"

        return base_decision

    def _evaluate_scores(
        self,
        critic_score: CriticScore,
        memory_type: MemoryEntry.Kind,
    ) -> Literal["store", "store_with_weight", "skip"]:
        """Map the mean of correctness, usefulness and safety to a tier.

        Tiers by ``memory_type``:

        - ``fact`` / ``plan`` / ``summary``: >= 0.8 is ``"store"``;
          >= 0.6 is ``"store_with_weight"``.
        - ``tool_result`` / ``critique``: >= ``store_threshold`` is
          ``"store"``; >= 0.5 is ``"store_with_weight"``.
        - ``user_preference``: >= 0.5 is ``"store"``.

        Anything else, including an unrecognised type, yields ``"skip"``.
        """
        avg_score = (
            critic_score.correctness + critic_score.usefulness + critic_score.safety
        ) / 3.0

        # The type groups are disjoint, so an elif chain states the intent.
        if memory_type in ("fact", "plan", "summary"):
            if avg_score >= 0.8:
                return "store"
            if avg_score >= 0.6:
                return "store_with_weight"
        elif memory_type in ("tool_result", "critique"):
            if avg_score >= self._store_threshold:
                return "store"
            if avg_score >= 0.5:
                return "store_with_weight"
        elif memory_type == "user_preference":
            if avg_score >= 0.5:
                return "store"
        return "skip"

    def _adjust_weight(
        self,
        critic_score: CriticScore,
        memory_type: MemoryEntry.Kind,
    ) -> float:
        """Return the critic weight plus type and safety boosts, capped at 1.0.

        The type boost comes from ``_TYPE_BOOST`` (0.0 for unlisted types).
        An extra 0.1 is added when ``critic_score.safety >= 0.9``. Only the
        upper bound is clamped; the boosts are non-negative, so the result
        is never below the raw weight.
        """
        type_boost = self._TYPE_BOOST.get(memory_type, 0.0)
        safety_boost = 0.1 if critic_score.safety >= 0.9 else 0.0
        return min(critic_score.weight + type_boost + safety_boost, 1.0)