Files
vera-ai-v2/tests/test_qdrant_service.py
Vera-AI 346f2c26fe feat: semantic search includes raw turns, deduplicate layers, fix recent turn ordering
- Layer 2 semantic search now queries both curated and raw types,
  closing the blind spot for turns past the 50-turn window pre-curation
- Layer 3 skips turns already returned by Layer 2 to avoid duplicate
  context and wasted token budget
- get_recent_turns uses Qdrant OrderBy for server-side timestamp sort
  with payload index; fallback to client-side sort if unavailable
- Bump version to 2.0.4
2026-04-01 17:43:47 -05:00

257 lines
8.7 KiB
Python

"""Tests for QdrantService — all Qdrant and Ollama calls are mocked."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
def make_qdrant_service():
"""Create a QdrantService with mocked AsyncQdrantClient."""
with patch("app.qdrant_service.AsyncQdrantClient") as MockClient:
mock_client = AsyncMock()
MockClient.return_value = mock_client
from app.qdrant_service import QdrantService
svc = QdrantService(
host="http://localhost:6333",
collection="test_memories",
embedding_model="snowflake-arctic-embed2",
vector_size=1024,
ollama_host="http://localhost:11434",
)
return svc, mock_client
class TestEnsureCollection:
"""Tests for _ensure_collection."""
@pytest.mark.asyncio
async def test_creates_collection_when_missing(self):
"""Creates collection if it does not exist."""
svc, mock_client = make_qdrant_service()
mock_client.get_collection = AsyncMock(side_effect=Exception("not found"))
mock_client.create_collection = AsyncMock()
await svc._ensure_collection()
mock_client.create_collection.assert_called_once()
assert svc._collection_ensured is True
@pytest.mark.asyncio
async def test_skips_if_collection_exists(self):
"""Does not create collection if it already exists."""
svc, mock_client = make_qdrant_service()
mock_client.get_collection = AsyncMock(return_value=MagicMock())
mock_client.create_collection = AsyncMock()
await svc._ensure_collection()
mock_client.create_collection.assert_not_called()
assert svc._collection_ensured is True
@pytest.mark.asyncio
async def test_skips_if_already_ensured(self):
"""Skips entirely if _collection_ensured is True."""
svc, mock_client = make_qdrant_service()
svc._collection_ensured = True
mock_client.get_collection = AsyncMock()
await svc._ensure_collection()
mock_client.get_collection.assert_not_called()
class TestGetEmbedding:
"""Tests for get_embedding."""
@pytest.mark.asyncio
async def test_returns_embedding_vector(self):
"""Returns embedding from Ollama response."""
svc, _ = make_qdrant_service()
fake_embedding = [0.1] * 1024
mock_resp = MagicMock()
mock_resp.json.return_value = {"embedding": fake_embedding}
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.post = AsyncMock(return_value=mock_resp)
with patch("httpx.AsyncClient", return_value=mock_client):
result = await svc.get_embedding("test text")
assert result == fake_embedding
assert len(result) == 1024
class TestStoreTurn:
"""Tests for store_turn."""
@pytest.mark.asyncio
async def test_stores_raw_user_turn(self):
"""Stores a user turn with proper payload."""
svc, mock_client = make_qdrant_service()
svc._collection_ensured = True
mock_client.upsert = AsyncMock()
fake_embedding = [0.1] * 1024
with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
point_id = await svc.store_turn(role="user", content="hello world")
assert isinstance(point_id, str)
mock_client.upsert.assert_called_once()
call_args = mock_client.upsert.call_args
point = call_args[1]["points"][0]
assert point.payload["type"] == "raw"
assert point.payload["role"] == "user"
assert "User: hello world" in point.payload["text"]
@pytest.mark.asyncio
async def test_stores_curated_turn(self):
"""Stores a curated turn without role prefix in text."""
svc, mock_client = make_qdrant_service()
svc._collection_ensured = True
mock_client.upsert = AsyncMock()
fake_embedding = [0.1] * 1024
with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
point_id = await svc.store_turn(
role="curated",
content="User: q\nAssistant: a",
entry_type="curated"
)
call_args = mock_client.upsert.call_args
point = call_args[1]["points"][0]
assert point.payload["type"] == "curated"
# Curated text should be the content directly, not prefixed
assert point.payload["text"] == "User: q\nAssistant: a"
@pytest.mark.asyncio
async def test_stores_with_topic_and_metadata(self):
"""Stores turn with optional topic and metadata."""
svc, mock_client = make_qdrant_service()
svc._collection_ensured = True
mock_client.upsert = AsyncMock()
fake_embedding = [0.1] * 1024
with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
await svc.store_turn(
role="assistant",
content="some response",
topic="python",
metadata={"source": "test"}
)
call_args = mock_client.upsert.call_args
point = call_args[1]["points"][0]
assert point.payload["topic"] == "python"
assert point.payload["source"] == "test"
class TestStoreQaTurn:
"""Tests for store_qa_turn."""
@pytest.mark.asyncio
async def test_stores_qa_turn(self):
"""Stores a complete Q&A turn."""
svc, mock_client = make_qdrant_service()
svc._collection_ensured = True
mock_client.upsert = AsyncMock()
fake_embedding = [0.1] * 1024
with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
point_id = await svc.store_qa_turn("What is Python?", "A programming language.")
assert isinstance(point_id, str)
call_args = mock_client.upsert.call_args
point = call_args[1]["points"][0]
assert point.payload["type"] == "raw"
assert point.payload["role"] == "qa"
assert "What is Python?" in point.payload["text"]
assert "A programming language." in point.payload["text"]
class TestSemanticSearch:
"""Tests for semantic_search."""
@pytest.mark.asyncio
async def test_returns_matching_results(self):
"""Returns formatted search results."""
svc, mock_client = make_qdrant_service()
svc._collection_ensured = True
mock_point = MagicMock()
mock_point.id = "result-1"
mock_point.score = 0.85
mock_point.payload = {"text": "User: hello\nAssistant: hi", "type": "curated"}
mock_query_result = MagicMock()
mock_query_result.points = [mock_point]
mock_client.query_points = AsyncMock(return_value=mock_query_result)
fake_embedding = [0.1] * 1024
with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
results = await svc.semantic_search("hello", limit=10, score_threshold=0.6)
assert len(results) == 1
assert results[0]["id"] == "result-1"
assert results[0]["score"] == 0.85
assert results[0]["payload"]["type"] == "curated"
class TestGetRecentTurns:
"""Tests for get_recent_turns."""
@pytest.mark.asyncio
async def test_returns_sorted_recent_turns(self):
"""Returns turns sorted by timestamp descending."""
svc, mock_client = make_qdrant_service()
svc._collection_ensured = True
mock_point1 = MagicMock()
mock_point1.id = "old"
mock_point1.payload = {"timestamp": "2026-01-01T00:00:00Z", "text": "old turn"}
mock_point2 = MagicMock()
mock_point2.id = "new"
mock_point2.payload = {"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"}
# OrderBy returns server-sorted results (newest first)
mock_client.scroll = AsyncMock(return_value=([mock_point2, mock_point1], None))
results = await svc.get_recent_turns(limit=2)
assert len(results) == 2
# Newest first (server-sorted via OrderBy)
assert results[0]["id"] == "new"
assert results[1]["id"] == "old"
class TestDeletePoints:
"""Tests for delete_points."""
@pytest.mark.asyncio
async def test_deletes_by_ids(self):
"""Deletes points by their IDs."""
svc, mock_client = make_qdrant_service()
mock_client.delete = AsyncMock()
await svc.delete_points(["id1", "id2"])
mock_client.delete.assert_called_once()
class TestClose:
"""Tests for close."""
@pytest.mark.asyncio
async def test_closes_client(self):
"""Closes the async Qdrant client."""
svc, mock_client = make_qdrant_service()
mock_client.close = AsyncMock()
await svc.close()
mock_client.close.assert_called_once()