98 lines
2.9 KiB
Python
98 lines
2.9 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Literal
|
|
|
|
from app.core.contracts import CriticScore, MemoryEntry
|
|
|
|
|
|
class MemoryWritePolicy:
    """Decide whether and how a candidate memory entry should be persisted.

    The decision is driven by a critic's scores (``correctness``,
    ``usefulness``, ``safety``, ``weight`` and the ``memory_store`` flag),
    the kind of memory being written, and per-session bookkeeping that the
    caller supplies.

    Possible decisions:

    * ``"skip"`` -- do not persist the entry.
    * ``"merge"`` -- the caller reported a duplicate; merge instead of adding.
    * ``"store"`` -- persist the entry (strongest outcome).
    * ``"store_with_weight"`` -- persist the entry at the weaker tier; it is
      returned for mid-range scores and for entries whose weight does not
      clear ``store_threshold``. How the weight is applied is up to the
      caller.
    """

    # Per-kind bonus added to the critic weight in ``_adjust_weight``.
    # Kinds not listed here receive no bonus. Read-only; shared by instances.
    _TYPE_BOOST = {
        "fact": 0.15,
        "plan": 0.1,
        "summary": 0.1,
        "user_preference": 0.2,
        "tool_result": 0.05,
        "critique": 0.05,
    }

    def __init__(
        self,
        store_threshold: float = 0.7,
        min_usefulness: float = 0.3,
        max_entries_per_session: int = 50,
        min_safety: float = 0.5,
    ) -> None:
        """Configure the policy thresholds.

        Args:
            store_threshold: Weight an entry must reach (possibly after the
                boosts in ``_adjust_weight``) to keep a ``"store"`` decision;
                also the average-score cut-off for ``"store"`` on
                ``"tool_result"`` and ``"critique"`` entries.
            min_usefulness: Entries with a lower usefulness score are skipped.
            max_entries_per_session: Once the caller-reported count for an
                identified session reaches this value, new entries are
                skipped.
            min_safety: Entries with a lower safety score are always skipped.
                Defaults to 0.5.
        """
        self._store_threshold = store_threshold
        self._min_usefulness = min_usefulness
        self._max_entries_per_session = max_entries_per_session
        self._min_safety = min_safety

    def decide(
        self,
        critic_score: CriticScore,
        memory_type: MemoryEntry.Kind,
        session_id: str | None = None,
        has_duplicate: bool = False,
        current_session_count: int = 0,
    ) -> Literal["store", "store_with_weight", "skip", "merge"]:
        """Return the write decision for one candidate memory entry.

        Gates are applied in order; the first one that fires wins:

        1. safety below ``min_safety`` -> ``"skip"``
        2. ``has_duplicate`` -> ``"merge"``
        3. the critic's ``memory_store`` flag is falsy -> ``"skip"``
        4. usefulness below ``min_usefulness`` -> ``"skip"``
        5. a truthy ``session_id`` whose ``current_session_count`` has
           reached ``max_entries_per_session`` -> ``"skip"``

        Otherwise the per-kind score tiers decide (see ``_evaluate_scores``).
        A resulting ``"store"`` whose raw critic weight is below
        ``store_threshold`` stays ``"store"`` only if the type/safety boosts
        from ``_adjust_weight`` lift it to the threshold; otherwise it is
        downgraded to ``"store_with_weight"``.

        Args:
            critic_score: Critic output for the candidate entry.
            memory_type: Kind of memory being written.
            session_id: Session the entry belongs to; the per-session cap is
                only enforced when this is truthy.
            has_duplicate: Whether the caller found an existing duplicate.
            current_session_count: Number of entries already stored for
                ``session_id``, as tracked by the caller.

        Returns:
            One of ``"store"``, ``"store_with_weight"``, ``"skip"`` or
            ``"merge"``.
        """
        if critic_score.safety < self._min_safety:
            return "skip"

        # NOTE(review): duplicates are merged before the ``memory_store`` and
        # usefulness gates, so a duplicate the critic rejected still merges --
        # confirm this ordering is intended.
        if has_duplicate:
            return "merge"

        if not critic_score.memory_store:
            return "skip"

        if critic_score.usefulness < self._min_usefulness:
            return "skip"

        # The per-session cap only applies when a session is identified.
        if session_id and current_session_count >= self._max_entries_per_session:
            return "skip"

        decision = self._evaluate_scores(critic_score, memory_type)

        if decision == "store" and critic_score.weight < self._store_threshold:
            # Keep the decision monotone in weight: the boosts can rescue a
            # low raw weight into a full "store", but if even the boosted
            # weight misses the threshold the entry drops to the weaker tier.
            adjusted_weight = self._adjust_weight(critic_score, memory_type)
            if adjusted_weight < self._store_threshold:
                return "store_with_weight"

        return decision

    def _evaluate_scores(
        self,
        critic_score: CriticScore,
        memory_type: MemoryEntry.Kind,
    ) -> Literal["store", "store_with_weight", "skip", "merge"]:
        """Map the mean of correctness/usefulness/safety to a per-kind tier.

        * ``fact``/``plan``/``summary``: >= 0.8 ``"store"``,
          >= 0.6 ``"store_with_weight"``
        * ``tool_result``/``critique``: >= ``store_threshold`` ``"store"``,
          >= 0.5 ``"store_with_weight"``
        * ``user_preference``: >= 0.5 ``"store"``

        Everything else, including kinds not listed above, is ``"skip"``.
        This method never returns ``"merge"``.
        """
        avg_score = (
            critic_score.correctness
            + critic_score.usefulness
            + critic_score.safety
        ) / 3.0

        if memory_type in ("fact", "plan", "summary"):
            if avg_score >= 0.8:
                return "store"
            if avg_score >= 0.6:
                return "store_with_weight"

        if memory_type in ("tool_result", "critique"):
            if avg_score >= self._store_threshold:
                return "store"
            if avg_score >= 0.5:
                return "store_with_weight"

        if memory_type == "user_preference":
            if avg_score >= 0.5:
                return "store"

        return "skip"

    def _adjust_weight(
        self,
        critic_score: CriticScore,
        memory_type: MemoryEntry.Kind,
    ) -> float:
        """Return the critic weight plus kind and safety bonuses, capped at 1.0.

        The kind bonus comes from ``_TYPE_BOOST`` (0.0 for unlisted kinds);
        entries with safety >= 0.9 get an extra 0.1.
        """
        type_boost = self._TYPE_BOOST.get(memory_type, 0.0)
        safety_boost = 0.1 if critic_score.safety >= 0.9 else 0.0
        return min(critic_score.weight + type_boost + safety_boost, 1.0)