ducklm/app/memory/write_policy.py

98 lines
2.9 KiB
Python

from __future__ import annotations
from typing import Any, Literal
from app.core.contracts import CriticScore, MemoryEntry
class MemoryWritePolicy:
    """Decide whether a critic-scored memory entry is stored, merged, or skipped.

    The decision is driven by the ``CriticScore`` fields read below
    (``safety``, ``memory_store``, ``usefulness``, ``correctness``,
    ``weight``), the memory type, and the per-session entry count supplied by
    the caller. Scores are assumed to lie in ``[0, 1]``; that is not validated
    here (NOTE(review): confirm against ``CriticScore``).
    """

    # Additive weight boost per memory type, applied by ``_adjust_weight``.
    # Built once at class level rather than on every call; unknown types get 0.0.
    _TYPE_BOOSTS = {
        "fact": 0.15,
        "plan": 0.1,
        "summary": 0.1,
        "user_preference": 0.2,
        "tool_result": 0.05,
        "critique": 0.05,
    }

    def __init__(
        self,
        store_threshold: float = 0.7,
        min_usefulness: float = 0.3,
        max_entries_per_session: int = 50,
    ) -> None:
        """Configure the policy thresholds.

        Args:
            store_threshold: Average-score cut-off for a full ``"store"`` of
                ``tool_result``/``critique`` memories, and the weight threshold
                that ``decide`` applies to ``"store"`` candidates.
            min_usefulness: Entries whose ``usefulness`` is below this are skipped.
            max_entries_per_session: Once ``current_session_count`` reaches this
                value, further entries for that session are skipped.
        """
        self._store_threshold = store_threshold
        self._min_usefulness = min_usefulness
        self._max_entries_per_session = max_entries_per_session

    def decide(
        self,
        critic_score: CriticScore,
        memory_type: MemoryEntry.Kind,
        session_id: str | None = None,
        has_duplicate: bool = False,
        current_session_count: int = 0,
    ) -> Literal["store", "store_with_weight", "skip", "merge"]:
        """Return the write decision for one candidate memory.

        Checks run in this order and the first match wins:

        1. ``safety < 0.5`` -> ``"skip"``.
        2. ``has_duplicate`` -> ``"merge"``. This runs before the
           ``memory_store`` and usefulness checks, so a duplicate is merged
           even when the critic did not ask for it to be stored.
        3. ``not memory_store`` -> ``"skip"``.
        4. ``usefulness < min_usefulness`` -> ``"skip"``.
        5. A truthy ``session_id`` whose ``current_session_count`` has reached
           ``max_entries_per_session`` -> ``"skip"``.
        6. Otherwise ``_evaluate_scores`` decides. A ``"store"`` result whose
           critic ``weight`` is below ``store_threshold`` is re-checked against
           the boosted weight from ``_adjust_weight``.

        Args:
            critic_score: Critic output for the candidate memory.
            memory_type: Kind of memory; selects the score bands and weight boost.
            session_id: Session the entry belongs to. The per-session cap is
                enforced only when this is truthy.
            has_duplicate: Whether the caller found an existing duplicate entry.
            current_session_count: Entries already stored for ``session_id``.

        Returns:
            One of ``"store"``, ``"store_with_weight"``, ``"skip"``, ``"merge"``.
        """
        if critic_score.safety < 0.5:
            return "skip"
        if has_duplicate:
            return "merge"
        if not critic_score.memory_store:
            return "skip"
        if critic_score.usefulness < self._min_usefulness:
            return "skip"
        # The per-session cap only applies when a session id is provided.
        if session_id and current_session_count >= self._max_entries_per_session:
            return "skip"

        decision = self._evaluate_scores(critic_score, memory_type)
        if decision != "store" or critic_score.weight >= self._store_threshold:
            return decision

        # Here: a "store" candidate whose raw critic weight is below the threshold.
        # NOTE(review): this branch is non-monotonic in weight. A candidate the
        # boosts can lift to the threshold is downgraded to "store_with_weight",
        # while one too weak to be lifted keeps the full "store". That looks
        # inverted. Confirm the intended outcome for the fall-through (perhaps
        # "skip" or "store_with_weight") before changing it. Behavior is kept
        # as-is here because callers may depend on it.
        if self._adjust_weight(critic_score, memory_type) >= self._store_threshold:
            return "store_with_weight"
        return decision

    def _evaluate_scores(
        self,
        critic_score: CriticScore,
        memory_type: MemoryEntry.Kind,
    ) -> Literal["store", "store_with_weight", "skip"]:
        """Map the mean of correctness/usefulness/safety to a decision by type.

        - ``fact`` / ``plan`` / ``summary``: mean ``>= 0.8`` -> ``"store"``,
          ``>= 0.6`` -> ``"store_with_weight"``.
        - ``tool_result`` / ``critique``: mean ``>= store_threshold`` ->
          ``"store"``, ``>= 0.5`` -> ``"store_with_weight"``.
        - ``user_preference``: mean ``>= 0.5`` -> ``"store"``.

        Everything else, including unlisted memory types, is ``"skip"``.
        This method never returns ``"merge"``.
        """
        avg_score = (
            critic_score.correctness + critic_score.usefulness + critic_score.safety
        ) / 3.0
        if memory_type in ("fact", "plan", "summary"):
            if avg_score >= 0.8:
                return "store"
            if avg_score >= 0.6:
                return "store_with_weight"
        elif memory_type in ("tool_result", "critique"):
            if avg_score >= self._store_threshold:
                return "store"
            if avg_score >= 0.5:
                return "store_with_weight"
        elif memory_type == "user_preference":
            if avg_score >= 0.5:
                return "store"
        return "skip"

    def _adjust_weight(
        self,
        critic_score: CriticScore,
        memory_type: MemoryEntry.Kind,
    ) -> float:
        """Return the critic weight plus the type and safety boosts, capped at 1.0.

        The type boost comes from ``_TYPE_BOOSTS`` (0.0 for unknown types).
        An extra 0.1 is added when ``safety >= 0.9``.
        """
        type_boost = self._TYPE_BOOSTS.get(memory_type, 0.0)
        safety_boost = 0.1 if critic_score.safety >= 0.9 else 0.0
        return min(critic_score.weight + type_boost + safety_boost, 1.0)