- Layer 2 semantic search now queries both curated and raw types, closing the blind spot for turns past the 50-turn window pre-curation
- Layer 3 skips turns already returned by Layer 2 to avoid duplicate context and wasted token budget
- get_recent_turns uses Qdrant OrderBy for server-side timestamp sort with payload index; falls back to client-side sort if unavailable
- Bump version to 2.0.4
257 lines
8.7 KiB
Python
257 lines
8.7 KiB
Python
"""Tests for QdrantService — all Qdrant and Ollama calls are mocked."""
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
|
|
def make_qdrant_service():
    """Build a QdrantService while AsyncQdrantClient is patched with a mock.

    ``app.qdrant_service.AsyncQdrantClient`` is patched only for the duration
    of the service construction; calling the patched class returns a single
    ``AsyncMock`` instance.

    Returns:
        A ``(service, mock_client)`` tuple, where ``mock_client`` is the
        ``AsyncMock`` returned by the patched ``AsyncQdrantClient``.
    """
    fake_client = AsyncMock()
    settings = {
        "host": "http://localhost:6333",
        "collection": "test_memories",
        "embedding_model": "snowflake-arctic-embed2",
        "vector_size": 1024,
        "ollama_host": "http://localhost:11434",
    }

    with patch("app.qdrant_service.AsyncQdrantClient", return_value=fake_client):
        # Imported while the patch is active, same as the construction below.
        from app.qdrant_service import QdrantService

        service = QdrantService(**settings)

    return service, fake_client
|
|
|
|
|
|
class TestEnsureCollection:
    """Tests for _ensure_collection."""

    @pytest.mark.asyncio
    async def test_creates_collection_when_missing(self):
        """Creates collection if it does not exist."""
        svc, mock_client = make_qdrant_service()
        # A failing get_collection is how this test simulates a missing collection.
        mock_client.get_collection = AsyncMock(side_effect=Exception("not found"))
        mock_client.create_collection = AsyncMock()

        await svc._ensure_collection()

        # assert_awaited_once (unlike assert_called_once) also fails when the
        # coroutine was created but never awaited.
        mock_client.create_collection.assert_awaited_once()
        assert svc._collection_ensured is True

    @pytest.mark.asyncio
    async def test_skips_if_collection_exists(self):
        """Does not create collection if it already exists."""
        svc, mock_client = make_qdrant_service()
        mock_client.get_collection = AsyncMock(return_value=MagicMock())
        mock_client.create_collection = AsyncMock()

        await svc._ensure_collection()

        mock_client.create_collection.assert_not_called()
        assert svc._collection_ensured is True

    @pytest.mark.asyncio
    async def test_skips_if_already_ensured(self):
        """Skips entirely if _collection_ensured is True."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True
        mock_client.get_collection = AsyncMock()
        mock_client.create_collection = AsyncMock()

        await svc._ensure_collection()

        # "Skips entirely" means neither the lookup nor the creation is attempted.
        mock_client.get_collection.assert_not_called()
        mock_client.create_collection.assert_not_called()
|
|
|
|
|
|
class TestGetEmbedding:
    """Tests for get_embedding."""

    @pytest.mark.asyncio
    async def test_returns_embedding_vector(self):
        """The vector under "embedding" in the mocked response is returned."""
        service, _ = make_qdrant_service()
        expected_vector = [0.1] * 1024

        # Plain (synchronous) mock: only .json() is configured on the response.
        response = MagicMock()
        response.json.return_value = {"embedding": expected_vector}

        # Stand-in for the object produced by httpx.AsyncClient(...): entering
        # it as an async context manager yields the same mock, and post()
        # resolves to the prepared response.
        http_client = AsyncMock()
        http_client.post = AsyncMock(return_value=response)
        http_client.__aenter__ = AsyncMock(return_value=http_client)
        http_client.__aexit__ = AsyncMock(return_value=False)

        with patch("httpx.AsyncClient", return_value=http_client):
            embedding = await service.get_embedding("test text")

        assert embedding == expected_vector
        assert len(embedding) == 1024
|
|
|
|
|
|
class TestStoreTurn:
    """Tests for store_turn.

    Each test marks the collection as already ensured, stubs ``upsert`` on the
    mocked client, replaces ``get_embedding`` with a fixed vector, and then
    inspects the first point passed to ``upsert`` through the ``points``
    keyword argument.
    """

    @pytest.mark.asyncio
    async def test_stores_raw_user_turn(self):
        """Stores a user turn with proper payload."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True
        mock_client.upsert = AsyncMock()

        fake_embedding = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
            point_id = await svc.store_turn(role="user", content="hello world")

        assert isinstance(point_id, str)
        # assert_awaited_once (unlike assert_called_once) also fails when the
        # upsert coroutine was created but never awaited.
        mock_client.upsert.assert_awaited_once()
        point = mock_client.upsert.call_args.kwargs["points"][0]
        assert point.payload["type"] == "raw"
        assert point.payload["role"] == "user"
        assert "User: hello world" in point.payload["text"]

    @pytest.mark.asyncio
    async def test_stores_curated_turn(self):
        """Stores a curated turn without role prefix in text."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True
        mock_client.upsert = AsyncMock()

        fake_embedding = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
            await svc.store_turn(
                role="curated",
                content="User: q\nAssistant: a",
                entry_type="curated",
            )

        # Checked before reading call_args so a missing upsert fails with a
        # clear assertion instead of a TypeError on call_args being None.
        mock_client.upsert.assert_awaited_once()
        point = mock_client.upsert.call_args.kwargs["points"][0]
        assert point.payload["type"] == "curated"
        # Curated text should be the content directly, not prefixed
        assert point.payload["text"] == "User: q\nAssistant: a"

    @pytest.mark.asyncio
    async def test_stores_with_topic_and_metadata(self):
        """Stores turn with optional topic and metadata."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True
        mock_client.upsert = AsyncMock()

        fake_embedding = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
            await svc.store_turn(
                role="assistant",
                content="some response",
                topic="python",
                metadata={"source": "test"},
            )

        mock_client.upsert.assert_awaited_once()
        point = mock_client.upsert.call_args.kwargs["points"][0]
        assert point.payload["topic"] == "python"
        # The metadata key is expected at the top level of the payload.
        assert point.payload["source"] == "test"
|
|
|
|
|
|
class TestStoreQaTurn:
    """Tests for store_qa_turn."""

    @pytest.mark.asyncio
    async def test_stores_qa_turn(self):
        """A question/answer pair is upserted as one raw point with role "qa"."""
        service, client = make_qdrant_service()
        service._collection_ensured = True
        client.upsert = AsyncMock()

        embedding_stub = AsyncMock(return_value=[0.1] * 1024)
        with patch.object(service, "get_embedding", embedding_stub):
            stored_id = await service.store_qa_turn(
                "What is Python?", "A programming language."
            )

        assert isinstance(stored_id, str)

        # First point handed to upsert through the ``points`` keyword.
        payload = client.upsert.call_args.kwargs["points"][0].payload
        assert payload["type"] == "raw"
        assert payload["role"] == "qa"
        # Both halves of the exchange must appear in the stored text.
        assert "What is Python?" in payload["text"]
        assert "A programming language." in payload["text"]
|
|
|
|
|
|
class TestSemanticSearch:
    """Tests for semantic_search."""

    @pytest.mark.asyncio
    async def test_returns_matching_results(self):
        """A point from query_points comes back as an id/score/payload dict."""
        service, client = make_qdrant_service()
        service._collection_ensured = True

        # One scored hit, wrapped in a response object exposing ``.points``.
        hit = MagicMock(
            id="result-1",
            score=0.85,
            payload={"text": "User: hello\nAssistant: hi", "type": "curated"},
        )
        client.query_points = AsyncMock(return_value=MagicMock(points=[hit]))

        embedding_stub = AsyncMock(return_value=[0.1] * 1024)
        with patch.object(service, "get_embedding", embedding_stub):
            results = await service.semantic_search(
                "hello", limit=10, score_threshold=0.6
            )

        assert len(results) == 1
        first = results[0]
        assert first["id"] == "result-1"
        assert first["score"] == 0.85
        assert first["payload"]["type"] == "curated"
|
|
|
|
|
|
class TestGetRecentTurns:
    """Tests for get_recent_turns."""

    @pytest.mark.asyncio
    async def test_returns_sorted_recent_turns(self):
        """Turns come back newest first, in the order the mocked scroll yields."""
        service, client = make_qdrant_service()
        service._collection_ensured = True

        older = MagicMock(
            id="old",
            payload={"timestamp": "2026-01-01T00:00:00Z", "text": "old turn"},
        )
        newer = MagicMock(
            id="new",
            payload={"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"},
        )

        # OrderBy returns server-sorted results (newest first); scroll yields a
        # (points, next_offset) pair.
        client.scroll = AsyncMock(return_value=([newer, older], None))

        turns = await service.get_recent_turns(limit=2)

        assert len(turns) == 2
        # Newest first (server-sorted via OrderBy)
        assert turns[0]["id"] == "new"
        assert turns[1]["id"] == "old"
|
|
|
|
|
|
class TestDeletePoints:
    """Tests for delete_points."""

    @pytest.mark.asyncio
    async def test_deletes_by_ids(self):
        """Deletes points by their IDs."""
        svc, mock_client = make_qdrant_service()
        mock_client.delete = AsyncMock()

        await svc.delete_points(["id1", "id2"])

        # assert_awaited_once (unlike assert_called_once) also fails when the
        # delete coroutine was created but never awaited.
        mock_client.delete.assert_awaited_once()
|
|
|
|
|
|
class TestClose:
    """Tests for close."""

    @pytest.mark.asyncio
    async def test_closes_client(self):
        """Closes the async Qdrant client."""
        svc, mock_client = make_qdrant_service()
        mock_client.close = AsyncMock()

        await svc.close()

        # assert_awaited_once (unlike assert_called_once) also fails when the
        # close coroutine was created but never awaited.
        mock_client.close.assert_awaited_once()
|