"""Tests for ``MemoryPolicy``: the stub verdict and LLM-backed classification."""
import json
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
|
|
from duck_core.memory.policy import MemoryPolicy
|
|
from duck_core.model_client import ModelClient, ModelResponse
|
|
|
|
|
|
@pytest.fixture
def mock_model_client():
    """Provide a spec'd ``ModelClient`` mock whose ``chat`` returns a storable verdict.

    ``chat`` is an ``AsyncMock`` that resolves to a ``ModelResponse``. Its
    ``content`` is a JSON object classifying a user preference as worth
    storing (should_store=True, importance 0.9, global scope). Tests can
    override ``chat.return_value`` to simulate other model outputs.
    """
    verdict = {
        "should_store": True,
        "memory_type": "preference",
        "summary": "User prefers concise Russian answers.",
        "importance": 0.9,
        "scope": "global",
        "metadata": {"source": "conversation"},
    }
    canned_reply = ModelResponse(
        role="critic",
        model="local-main",
        content=json.dumps(verdict),
        reasoning_content=None,
        raw={},
        latency_ms=42.0,
    )

    fake_client = AsyncMock(spec=ModelClient)
    # Install an explicit AsyncMock so awaiting fake_client.chat(...) yields
    # the canned reply regardless of how the spec'd attribute was derived.
    fake_client.chat = AsyncMock(return_value=canned_reply)
    return fake_client
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_memory_policy_stub_returns_should_store_false():
    """A MemoryPolicy built without a model client returns the stub verdict."""
    stub_decision = await MemoryPolicy().classify("some summary", "task_123")

    assert stub_decision.should_store is False
    assert stub_decision.memory_type == "event"
    assert stub_decision.importance == 0.0
    assert stub_decision.metadata["source"] == "stub_policy"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_memory_policy_classifies_and_stores(mock_model_client):
    """A storable LLM verdict is surfaced on the returned decision.

    Also checks the request: ``ModelClient.chat`` is awaited exactly once,
    with the policy's configured role and a single user message that embeds
    the text being classified.
    """
    policy = MemoryPolicy(model_client=mock_model_client, role="memory_policy")

    decision = await policy.classify(
        "User said they prefer short answers in Russian.", "task_456"
    )

    assert decision.should_store is True
    assert decision.memory_type == "preference"
    assert decision.importance == 0.9
    assert decision.summary == "User prefers concise Russian answers."

    # chat is an AsyncMock: assert_called_once() would also pass if the
    # coroutine were created but never awaited, so verify the await itself.
    mock_model_client.chat.assert_awaited_once()
    await_args = mock_model_client.chat.await_args
    # ModelClient.chat(role, messages, ...) — positional args
    assert await_args.args[0] == "memory_policy"
    messages = await_args.args[1]
    assert len(messages) == 1
    assert messages[0]["role"] == "user"
    assert "User said they prefer short answers" in messages[0]["content"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_memory_policy_handles_non_storable(mock_model_client):
    """An LLM verdict of should_store=False is reflected on the decision."""
    routine_verdict = {
        "should_store": False,
        "memory_type": "event",
        "summary": "Routine tool call, nothing to remember.",
        "importance": 0.1,
        "scope": "workspace",
        "metadata": {},
    }
    mock_model_client.chat.return_value = ModelResponse(
        role="critic",
        model="local-main",
        content=json.dumps(routine_verdict),
        reasoning_content=None,
        raw={},
        latency_ms=30.0,
    )

    decision = await MemoryPolicy(model_client=mock_model_client).classify(
        "Ran ls -la in workspace.", "task_789"
    )

    assert decision.should_store is False
    assert decision.importance == 0.1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_memory_policy_uses_response_format(mock_model_client):
    """The policy requests structured output via a ``json_schema`` response_format."""
    policy = MemoryPolicy(model_client=mock_model_client)

    await policy.classify("test summary", "task_1")

    # Check the await first: if chat was never awaited, await_args is None and
    # the lookups below would fail with an unhelpful AttributeError.
    mock_model_client.chat.assert_awaited_once()
    await_args = mock_model_client.chat.await_args
    assert "response_format" in await_args.kwargs, (
        "chat() was awaited without a response_format= keyword"
    )
    assert await_args.kwargs["response_format"]["type"] == "json_schema"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_memory_policy_invalid_json_falls_back(mock_model_client):
    """Unparseable model output yields a non-storing ``llm_policy_fallback`` decision."""
    malformed_content = "not valid json {{{"
    mock_model_client.chat.return_value = ModelResponse(
        role="critic",
        model="local-main",
        content=malformed_content,
        reasoning_content=None,
        raw={},
        latency_ms=10.0,
    )

    decision = await MemoryPolicy(model_client=mock_model_client).classify(
        "some summary", "task_x"
    )

    assert decision.should_store is False
    assert decision.metadata["source"] == "llm_policy_fallback"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_memory_policy_missing_fields_falls_back(mock_model_client):
    """Valid JSON lacking required fields yields the ``llm_policy_fallback`` decision.

    The model claims should_store=True, but the incomplete payload must not
    be trusted, so the resulting decision does not store.
    """
    incomplete_verdict = {"should_store": True}
    mock_model_client.chat.return_value = ModelResponse(
        role="critic",
        model="local-main",
        content=json.dumps(incomplete_verdict),
        reasoning_content=None,
        raw={},
        latency_ms=10.0,
    )

    decision = await MemoryPolicy(model_client=mock_model_client).classify(
        "some summary", "task_y"
    )

    assert decision.should_store is False
    assert decision.metadata["source"] == "llm_policy_fallback"
|