from __future__ import annotations

import json
import logging
import math
from typing import Any

from pydantic import BaseModel

logger = logging.getLogger(__name__)


class MemoryDecision(BaseModel):
    """Result of classifying one task transcript for long-term memory."""

    # True when the transcript should be persisted.
    should_store: bool = False
    # Category label; the policy prompt asks for one of:
    # fact, preference, lesson, decision, event, note.
    memory_type: str = "note"
    # Text to store (the LLM's summary, or the raw transcript on fallback).
    summary: str = ""
    # Relevance score; the policy clamps LLM output to [0.0, 1.0].
    importance: float = 0.0
    # Visibility label; the policy prompt asks for one of:
    # global, workspace, conversation.
    scope: str = "workspace"
    # Free-form string pairs. The policy always sets "task_id" and "source".
    # NOTE: pydantic copies mutable field defaults per instance, so the
    # literal `{}` is not shared between models.
    metadata: dict[str, str] = {}


class MemoryPolicy:
    """Decides whether task output should be stored in memory.

    When *model_client* is provided, uses an LLM call to classify the task
    transcript. Falls back to a safe default (should_store=False) on any
    error so the runtime is never blocked by policy failures.
    """

    _PROMPT_SYSTEM = (
        "You are DuckLM memory policy. Decide whether the given task transcript "
        "contains information worth storing in long-term memory.\n\n"
        "Return ONLY valid JSON with these keys:\n"
        " should_store: boolean — true if this is worth remembering\n"
        " memory_type: string — one of: fact, preference, lesson, decision, event, note\n"
        " summary: string — concise one-sentence summary (max 200 chars)\n"
        " importance: number — 0.0 to 1.0\n"
        " scope: string — one of: global, workspace, conversation\n"
        " metadata: object — optional extra key-value pairs\n\n"
        "Rules:\n"
        "- Store user preferences, important decisions, reusable lessons, key facts.\n"
        "- Do NOT store routine tool calls, temporary state, or trivial observations.\n"
        "- importance >= 0.7 for preferences and lessons, >= 0.4 for facts, < 0.4 for events.\n"
        "- scope='global' for user preferences and system-wide facts.\n"
        "- scope='workspace' for project-specific information.\n"
        "- scope='conversation' for chat-specific context.\n"
    )

    _RESPONSE_SCHEMA = {
        "type": "object",
        "required": ["should_store", "memory_type", "summary", "importance", "scope", "metadata"],
        "additionalProperties": False,
        "properties": {
            "should_store": {"type": "boolean"},
            "memory_type": {
                "type": "string",
                "enum": ["fact", "preference", "lesson", "decision", "event", "note"],
            },
            "summary": {"type": "string", "maxLength": 300},
            "importance": {"type": "number", "minimum": 0.0, "maximum": 1.0},
            "scope": {"type": "string", "enum": ["global", "workspace", "conversation"]},
            "metadata": {"type": "object", "additionalProperties": {"type": "string"}},
        },
    }

    # Hard cap applied to the stored summary; mirrors the schema's maxLength.
    _SUMMARY_MAX_CHARS = 300

    def __init__(
        self,
        model_client: Any | None = None,
        role: str = "memory_policy",
    ):
        """Create a policy.

        Args:
            model_client: Object exposing an awaitable
                ``chat(role, messages, response_format=...)`` whose result
                has a ``content`` attribute. ``None`` selects the stub
                behaviour (never store).
            role: First positional argument forwarded to ``chat``.
        """
        self._model_client = model_client
        self._role = role

    async def classify(self, summary: str, task_id: str) -> MemoryDecision:
        """Classify whether *summary* from *task_id* should be stored in memory.

        If no model client is configured, returns the safe default
        (should_store=False) — the old stub behaviour.
        """
        if self._model_client is None:
            return self._fallback(summary, task_id, source="stub_policy")
        return await self._classify_with_llm(summary, task_id)

    async def _classify_with_llm(self, summary: str, task_id: str) -> MemoryDecision:
        """Ask the model client for a decision; never raises on client failure.

        Any exception raised by the ``chat`` call is logged and converted into
        the fallback decision (``source="llm_policy_fallback"``).
        """
        messages = [
            # The classification rules live in the system prompt; without this
            # message the model only sees the raw transcript.
            {"role": "system", "content": self._PROMPT_SYSTEM},
            {
                "role": "user",
                "content": f"Task ID: {task_id}\n\nTranscript:\n{summary}",
            },
        ]
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "memory_decision",
                "schema": self._RESPONSE_SCHEMA,
                "strict": True,
            },
        }
        try:
            response = await self._model_client.chat(
                self._role,
                messages,
                response_format=response_format,
            )
        except Exception as exc:  # policy boundary: must not block the runtime
            logger.warning("MemoryPolicy LLM call failed for %s: %s", task_id, exc)
            return self._fallback(summary, task_id)

        # A response object without ``content`` is treated like invalid JSON
        # (the parser falls back on None) instead of raising AttributeError.
        return self._parse_response(getattr(response, "content", None), summary, task_id)

    def _parse_response(self, content: str, summary: str, task_id: str) -> MemoryDecision:
        """Turn the raw LLM reply into a ``MemoryDecision``.

        Returns the fallback decision (should_store=False, original
        *summary*) when *content* is not JSON, is not a JSON object, or lacks
        any of the required keys. Otherwise values are coerced defensively:
        importance is clamped to [0.0, 1.0], the summary is truncated, and
        metadata is stringified.
        """
        try:
            data = json.loads(content)
        except (ValueError, TypeError):
            # ``content`` may be None or another non-sliceable object here, so
            # build the log preview without assuming it is a string.
            preview = content[:200] if isinstance(content, str) else repr(content)[:200]
            logger.warning("MemoryPolicy: invalid JSON for %s: %s", task_id, preview)
            return self._fallback(summary, task_id)

        # Valid JSON can still be a list, string, number or null.
        if not isinstance(data, dict):
            logger.warning(
                "MemoryPolicy: expected a JSON object for %s, got %s",
                task_id,
                type(data).__name__,
            )
            return self._fallback(summary, task_id)

        required = ("should_store", "memory_type", "summary", "importance", "scope")
        if any(key not in data for key in required):
            logger.warning("MemoryPolicy: missing fields for %s: %s", task_id, list(data))
            return self._fallback(summary, task_id)

        # A null summary would otherwise be stored as the literal text "None".
        llm_summary = data["summary"]
        if llm_summary is None:
            llm_summary = summary

        # ``metadata`` is optional and may be null or a non-object.
        raw_metadata = data.get("metadata")
        extra = (
            {str(key): str(value) for key, value in raw_metadata.items()}
            if isinstance(raw_metadata, dict)
            else {}
        )

        return MemoryDecision(
            should_store=bool(data["should_store"]),
            memory_type=str(data["memory_type"]),
            summary=str(llm_summary)[: self._SUMMARY_MAX_CHARS],
            importance=self._coerce_importance(data["importance"]),
            scope=str(data["scope"]),
            # Trusted provenance keys go last so model-supplied metadata
            # cannot overwrite them.
            metadata={**extra, "task_id": task_id, "source": "llm_policy"},
        )

    @staticmethod
    def _fallback(
        summary: str,
        task_id: str,
        source: str = "llm_policy_fallback",
    ) -> MemoryDecision:
        """Build the safe "do not store" decision, tagged with its *source*."""
        return MemoryDecision(
            should_store=False,
            memory_type="event",
            summary=summary,
            importance=0.0,
            metadata={"task_id": task_id, "source": source},
        )

    @staticmethod
    def _coerce_importance(raw: Any) -> float:
        """Convert *raw* to a float in [0.0, 1.0]; unusable values become 0.0."""
        try:
            value = float(raw)
        except (TypeError, ValueError, OverflowError):
            return 0.0
        # json.loads accepts the non-standard NaN literal by default.
        if math.isnan(value):
            return 0.0
        return max(0.0, min(value, 1.0))