98 lines
3.7 KiB
Python
98 lines
3.7 KiB
Python
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from duck_core.memory.vector_memory import VectorMemory, EmbeddingsUnavailableError
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_vector_memory_uses_local_model():
    """Store and search a memory using a local sentence-transformers model.

    Both the model loader (``_load_local_model``) and ``httpx.AsyncClient`` are
    mocked, so no model files or network access are needed.
    """
    vm = VectorMemory(
        qdrant_url="http://localhost:6333",
        local_embedding_model="./models/all-MiniLM-L6-v2",
    )

    # Stand-in for the sentence-transformers model. encode() returns a plain
    # list of floats sized like all-MiniLM-L6-v2 output (384 dimensions).
    mock_model = MagicMock()
    mock_model.encode.return_value = [0.1] * 384

    with patch.object(vm, "_load_local_model", return_value=mock_model):
        with patch("httpx.AsyncClient") as mock_client_class:
            # Wire the patched class so `async with httpx.AsyncClient()` yields mock_client.
            mock_client = AsyncMock()
            mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)

            # httpx.Response methods are synchronous, so responses are MagicMocks.
            put_response = MagicMock(status_code=200)
            put_response.raise_for_status = MagicMock()

            search_response = MagicMock(status_code=200)
            search_response.raise_for_status = MagicMock()
            search_response.json.return_value = {
                "result": [{"id": "test-id", "payload": {"text": "test"}}]
            }

            mock_client.put.return_value = put_response
            mock_client.post.return_value = search_response

            point_id = await vm.add_memory("test memory", {"scope": "global"})
            assert point_id is not None

            results = await vm.search_memory("test query")
            assert isinstance(results, list)

    # The test's purpose: embeddings must actually come from the local model.
    # Without this, the assertions above pass even if the model is never used.
    mock_model.encode.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_vector_memory_no_embedding_source():
    """add_memory must raise EmbeddingsUnavailableError when there is no embedding source.

    Neither a local model path nor a remote embeddings base URL is configured.
    """
    settings = {
        "qdrant_url": "http://localhost:6333",
        "local_embedding_model": None,
        "embeddings_base_url": None,
    }
    memory = VectorMemory(**settings)

    with pytest.raises(EmbeddingsUnavailableError):
        await memory.add_memory("test")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_vector_memory_remote_fallback():
    """Store a memory using embeddings fetched from a remote endpoint.

    No local model is configured, so the embedding must come from
    ``embeddings_base_url``. The mocked reply uses the
    ``{"data": [{"embedding": [...]}]}`` response shape.
    """
    vm = VectorMemory(
        qdrant_url="http://localhost:6333",
        local_embedding_model=None,
        embeddings_base_url="http://localhost:8081/v1",
    )

    mock_embedding = [0.1] * 384

    with patch("httpx.AsyncClient") as mock_client_class:
        # Wire the patched class so `async with httpx.AsyncClient()` yields mock_client.
        mock_client = AsyncMock()
        mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)

        # Reply returned for every POST (the embeddings request).
        embed_response = MagicMock(status_code=200)
        embed_response.raise_for_status = MagicMock()
        embed_response.json.return_value = {"data": [{"embedding": mock_embedding}]}

        # Reply returned for every PUT (presumably the Qdrant upsert).
        put_response = MagicMock(status_code=200)
        put_response.raise_for_status = MagicMock()

        mock_client.post.return_value = embed_response
        mock_client.put.return_value = put_response

        point_id = await vm.add_memory("test")
        assert point_id is not None

    # A non-None id alone says nothing about the remote path; make sure the
    # embeddings endpoint was actually requested.
    mock_client.post.assert_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_vector_memory_remote_503():
    """Remote embeddings returning 503 should raise EmbeddingsUnavailableError."""
    vm = VectorMemory(
        qdrant_url="http://localhost:6333",
        local_embedding_model=None,
        embeddings_base_url="http://localhost:8081/v1",
    )

    with patch("httpx.AsyncClient") as mock_client_class:
        # Wire the patched class so `async with httpx.AsyncClient()` yields mock_client.
        mock_client = AsyncMock()
        mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)

        # httpx.Response is synchronous. An AsyncMock here would make json() and
        # raise_for_status() return never-awaited coroutines, and status properties
        # would be truthy mocks. Mirror the flags a real 503 response carries instead.
        unavailable_response = MagicMock(
            status_code=503,
            is_success=False,
            is_error=True,
            is_server_error=True,
        )
        mock_client.post.return_value = unavailable_response

        with pytest.raises(EmbeddingsUnavailableError):
            await vm.add_memory("test")

    # Guard against a vacuous pass: the error must follow a real request to the
    # embeddings endpoint, not some unrelated earlier failure.
    mock_client.post.assert_awaited()
|