diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..bfb878c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,62 @@ +"""Shared test fixtures using production-realistic data.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from app.config import Config + + +@pytest.fixture +def production_config(): + """Config matching production deployment on deb8.""" + config = MagicMock(spec=Config) + config.ollama_host = "http://10.0.0.10:11434" + config.qdrant_host = "http://10.0.0.22:6333" + config.qdrant_collection = "memories" + config.embedding_model = "snowflake-arctic-embed2" + config.semantic_token_budget = 25000 + config.context_token_budget = 22000 + config.semantic_search_turns = 2 + config.semantic_score_threshold = 0.6 + config.run_time = "02:00" + config.curator_model = "gpt-oss:120b" + config.debug = False + config.vector_size = 1024 + config.cloud = MagicMock() + config.cloud.enabled = False + config.cloud.models = {} + config.cloud.get_cloud_model.return_value = None + return config + + +@pytest.fixture +def sample_qdrant_raw_payload(): + """Sample raw payload from production Qdrant.""" + return { + "type": "raw", + "text": "User: only change settings, not models\nAssistant: Changed semantic_token_budget from 25000 to 30000\nTimestamp: 2026-03-27T12:50:37.451593Z", + "timestamp": "2026-03-27T12:50:37.451593Z", + "role": "qa", + "content": "User: only change settings, not models\nAssistant: Changed semantic_token_budget from 25000 to 30000\nTimestamp: 2026-03-27T12:50:37.451593Z" + } + + +@pytest.fixture +def sample_ollama_models(): + """Model list from production Ollama.""" + return { + "models": [ + { + "name": "snowflake-arctic-embed2:latest", + "model": "snowflake-arctic-embed2:latest", + "modified_at": "2026-02-16T16:43:44Z", + "size": 1160296718, + "details": {"family": "bert", "parameter_size": "566.70M", "quantization_level": "F16"} + }, + { + "name": "gpt-oss:120b", + "model": "gpt-oss:120b", + "modified_at": "2026-03-11T12:45:48Z", + "size": 65369818941, + "details": {"family": "gptoss", "parameter_size": "116.8B", "quantization_level": "MXFP4"} + } + ] + } diff --git a/tests/test_curator.py b/tests/test_curator.py index 5e91998..4e98ef1 100644 --- a/tests/test_curator.py +++ b/tests/test_curator.py @@ -4,7 +4,7 @@ import json import os from datetime import datetime, timedelta, timezone from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, AsyncMock, patch def make_curator(): @@ -197,3 +197,294 @@ class TestAppendRuleToFile: target = prompts_dir / "newfile.md" assert target.exists() assert "New rule here." in target.read_text() + + +class TestFormatExistingMemories: + """Tests for Curator._format_existing_memories.""" + + def test_empty_list_returns_no_memories_message(self): + """Empty list returns a 'no memories' message.""" + curator, _ = make_curator() + result = curator._format_existing_memories([]) + assert "No existing curated memories" in result + + def test_single_memory_formatted(self): + """Single memory text is included in output.""" + curator, _ = make_curator() + memories = [{"text": "User: hello\nAssistant: hi there"}] + result = curator._format_existing_memories(memories) + assert "hello" in result + assert "hi there" in result + + def test_limits_to_last_20(self): + """Only last 20 memories are included.""" + curator, _ = make_curator() + memories = [{"text": f"memory {i}"} for i in range(30)] + result = curator._format_existing_memories(memories) + # Should contain memory 10-29 (last 20), not memory 0-9 + assert "memory 29" in result + assert "memory 10" in result + + +class TestCallLlm: + """Tests for Curator._call_llm.""" + + @pytest.mark.asyncio + async def test_call_llm_returns_response(self): + """_call_llm returns the response text from Ollama.""" + curator, _ = make_curator() + + mock_resp = MagicMock() + mock_resp.json.return_value = {"response": "some LLM output"} + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.post = AsyncMock(return_value=mock_resp) + + with patch("httpx.AsyncClient", return_value=mock_client): + result = await curator._call_llm("test prompt") + + assert result == "some LLM output" + call_kwargs = mock_client.post.call_args + assert "test-model" in call_kwargs[1]["json"]["model"] + + @pytest.mark.asyncio + async def test_call_llm_returns_empty_on_error(self): + """_call_llm returns empty string when Ollama errors.""" + curator, _ = make_curator() + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.post = AsyncMock(side_effect=Exception("connection refused")) + + with patch("httpx.AsyncClient", return_value=mock_client): + result = await curator._call_llm("test prompt") + + assert result == "" + + +class TestCuratorRun: + """Tests for Curator.run() method.""" + + @pytest.mark.asyncio + async def test_run_no_raw_memories_exits_early(self): + """run() exits early when no raw memories found.""" + curator, mock_qdrant = make_curator() + + # Mock scroll to return no points + mock_qdrant.client = AsyncMock() + mock_qdrant.client.scroll = AsyncMock(return_value=([], None)) + mock_qdrant.collection = "memories" + + await curator.run() + # Should not call LLM since there are no raw memories + # If it got here without error, that's success + + @pytest.mark.asyncio + async def test_run_processes_raw_memories(self): + """run() processes raw memories and stores curated results.""" + curator, mock_qdrant = make_curator() + + # Create mock points + mock_point = MagicMock() + mock_point.id = "point-1" + mock_point.payload = { + "type": "raw", + "text": "User: hello\nAssistant: hi", + "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + } + + mock_qdrant.client = AsyncMock() + mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None)) + mock_qdrant.collection = "memories" + mock_qdrant.store_turn = AsyncMock(return_value="new-id") + mock_qdrant.delete_points = AsyncMock() + + llm_response = json.dumps({ + "new_curated_turns": [{"content": "User: hello\nAssistant: hi"}], + "permanent_rules": [], + "deletions": [], + "summary": "Curated one turn" + }) + + with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)): + await curator.run() + + mock_qdrant.store_turn.assert_called_once() + mock_qdrant.delete_points.assert_called() + + @pytest.mark.asyncio + async def test_run_monthly_mode_on_day_01(self): + """run() uses monthly mode on day 01, processing all raw memories.""" + curator, mock_qdrant = make_curator() + + # Create a mock point with an old timestamp (outside 24h window) + old_ts = (datetime.now(timezone.utc) - timedelta(hours=72)).isoformat().replace("+00:00", "Z") + mock_point = MagicMock() + mock_point.id = "old-point" + mock_point.payload = { + "type": "raw", + "text": "User: old question\nAssistant: old answer", + "timestamp": old_ts, + } + + mock_qdrant.client = AsyncMock() + mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None)) + mock_qdrant.collection = "memories" + mock_qdrant.store_turn = AsyncMock(return_value="new-id") + mock_qdrant.delete_points = AsyncMock() + + llm_response = json.dumps({ + "new_curated_turns": [], + "permanent_rules": [], + "deletions": [], + "summary": "Nothing to curate" + }) + + # Mock day 01 + mock_now = datetime(2026, 4, 1, 2, 0, 0, tzinfo=timezone.utc) + with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)), \ + patch("app.curator.datetime") as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + mock_dt.side_effect = lambda *args, **kw: datetime(*args, **kw) + await curator.run() + + # In monthly mode, even old memories are processed, so LLM should be called + # and delete_points should be called for the raw memory + mock_qdrant.delete_points.assert_called() + + @pytest.mark.asyncio + async def test_run_handles_permanent_rules(self): + """run() appends permanent rules to prompt files.""" + curator, mock_qdrant = make_curator() + + mock_point = MagicMock() + mock_point.id = "point-1" + mock_point.payload = { + "type": "raw", + "text": "User: remember this\nAssistant: ok", + "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + } + + mock_qdrant.client = AsyncMock() + mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None)) + mock_qdrant.collection = "memories" + mock_qdrant.store_turn = AsyncMock(return_value="new-id") + mock_qdrant.delete_points = AsyncMock() + + llm_response = json.dumps({ + "new_curated_turns": [], + "permanent_rules": [{"rule": "Always be concise.", "target_file": "systemprompt.md"}], + "deletions": [], + "summary": "Added a rule" + }) + + with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)), \ + patch.object(curator, "_append_rule_to_file", AsyncMock()) as mock_append: + await curator.run() + + mock_append.assert_called_once_with("systemprompt.md", "Always be concise.") + + @pytest.mark.asyncio + async def test_run_handles_deletions(self): + """run() deletes specified point IDs when they exist in the database.""" + curator, mock_qdrant = make_curator() + + mock_point = MagicMock() + mock_point.id = "point-1" + mock_point.payload = { + "type": "raw", + "text": "User: delete me\nAssistant: ok", + "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + } + + mock_qdrant.client = AsyncMock() + mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None)) + mock_qdrant.collection = "memories" + mock_qdrant.store_turn = AsyncMock(return_value="new-id") + mock_qdrant.delete_points = AsyncMock() + + llm_response = json.dumps({ + "new_curated_turns": [], + "permanent_rules": [], + "deletions": ["point-1"], + "summary": "Deleted one" + }) + + with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)): + await curator.run() + + # delete_points should be called at least twice: once for valid deletions, once for processed raw + assert mock_qdrant.delete_points.call_count >= 1 + + @pytest.mark.asyncio + async def test_run_handles_llm_parse_failure(self): + """run() handles LLM returning unparseable response gracefully.""" + curator, mock_qdrant = make_curator() + + mock_point = MagicMock() + mock_point.id = "point-1" + mock_point.payload = { + "type": "raw", + "text": "User: test\nAssistant: ok", + "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + } + + mock_qdrant.client = AsyncMock() + mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None)) + mock_qdrant.collection = "memories" + + with patch.object(curator, "_call_llm", AsyncMock(return_value="not json at all!!!")): + # Should not raise - just return early + await curator.run() + + # store_turn should NOT be called since parsing failed + mock_qdrant.store_turn = AsyncMock() + mock_qdrant.store_turn.assert_not_called() + + +class TestLoadCuratorPrompt: + """Tests for load_curator_prompt function.""" + + def test_loads_from_prompts_dir(self, tmp_path): + """load_curator_prompt loads from PROMPTS_DIR.""" + import app.curator as curator_module + + prompts_dir = tmp_path / "prompts" + prompts_dir.mkdir() + (prompts_dir / "curator_prompt.md").write_text("Test curator prompt") + + with patch.object(curator_module, "PROMPTS_DIR", prompts_dir): + from app.curator import load_curator_prompt + result = load_curator_prompt() + + assert result == "Test curator prompt" + + def test_falls_back_to_static_dir(self, tmp_path): + """load_curator_prompt falls back to STATIC_DIR.""" + import app.curator as curator_module + + prompts_dir = tmp_path / "prompts" # does not exist + static_dir = tmp_path / "static" + static_dir.mkdir() + (static_dir / "curator_prompt.md").write_text("Static prompt") + + with patch.object(curator_module, "PROMPTS_DIR", prompts_dir), \ + patch.object(curator_module, "STATIC_DIR", static_dir): + from app.curator import load_curator_prompt + result = load_curator_prompt() + + assert result == "Static prompt" + + def test_raises_when_not_found(self, tmp_path): + """load_curator_prompt raises FileNotFoundError when file missing.""" + import app.curator as curator_module + + with patch.object(curator_module, "PROMPTS_DIR", tmp_path / "nope"), \ + patch.object(curator_module, "STATIC_DIR", tmp_path / "also_nope"): + from app.curator import load_curator_prompt + with pytest.raises(FileNotFoundError): + load_curator_prompt() diff --git a/tests/test_integration.py b/tests/test_integration.py index 3b93ffc..072a781 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -349,3 +349,83 @@ class TestApiChatStreaming: # Response body should contain both chunks concatenated body_text = resp.text assert "Hello" in body_text or len(body_text) > 0 + + +# --------------------------------------------------------------------------- +# Health check edge cases +# --------------------------------------------------------------------------- + +class TestHealthCheckEdgeCases: + def test_health_ollama_timeout(self, app_with_mocks): + """GET / handles Ollama timeout gracefully.""" + import httpx + from fastapi.testclient import TestClient + + vera_app, _ = app_with_mocks + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_client_instance.get = AsyncMock(side_effect=httpx.TimeoutException("timeout")) + + with patch("httpx.AsyncClient", return_value=mock_client_instance): + with TestClient(vera_app, raise_server_exceptions=True) as client: + resp = client.get("/") + + assert resp.status_code == 200 + assert resp.json()["ollama"] == "unreachable" + + +# --------------------------------------------------------------------------- +# POST /curator/run +# --------------------------------------------------------------------------- + +class TestTriggerCurator: + def test_trigger_curator_endpoint(self, app_with_mocks): + """POST /curator/run triggers curation and returns status.""" + from fastapi.testclient import TestClient + import app.main as main_module + + vera_app, _ = app_with_mocks + + mock_curator = MagicMock() + mock_curator.run = AsyncMock() + + with patch.object(main_module, "curator", mock_curator): + with TestClient(vera_app) as client: + resp = client.post("/curator/run") + + assert resp.status_code == 200 + assert resp.json()["status"] == "curation completed" + mock_curator.run.assert_called_once() + + +# --------------------------------------------------------------------------- +# Proxy catch-all +# --------------------------------------------------------------------------- + +class TestProxyAll: + def test_non_chat_api_proxied(self, app_with_mocks): + """Non-chat API paths are proxied to Ollama.""" + from fastapi.testclient import TestClient + + vera_app, _ = app_with_mocks + + async def fake_aiter_bytes(): + yield b'{"status": "ok"}' + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.aiter_bytes = fake_aiter_bytes + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_client_instance.request = AsyncMock(return_value=mock_resp) + + with patch("httpx.AsyncClient", return_value=mock_client_instance): + with TestClient(vera_app) as client: + resp = client.get("/api/show") + + assert resp.status_code == 200 diff --git a/tests/test_proxy_handler.py b/tests/test_proxy_handler.py index 44cde04..fc31c17 100644 --- a/tests/test_proxy_handler.py +++ b/tests/test_proxy_handler.py @@ -219,3 +219,94 @@ class TestHandleChatNonStreaming: # The wrapper should be stripped assert "Memory context" not in stored_question assert "What is the answer?" in stored_question + + +class TestDebugLog: + """Tests for debug_log function.""" + + def test_debug_log_writes_json_when_enabled(self, tmp_path): + """Debug log appends valid JSON line to file when debug=True.""" + import json + from unittest.mock import patch, MagicMock + + mock_config = MagicMock() + mock_config.debug = True + + with patch("app.proxy_handler.config", mock_config), \ + patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path): + from app.proxy_handler import debug_log + debug_log("test_cat", "test message", {"key": "value"}) + + log_files = list(tmp_path.glob("debug_*.log")) + assert len(log_files) == 1 + content = log_files[0].read_text().strip() + entry = json.loads(content) + assert entry["category"] == "test_cat" + assert entry["message"] == "test message" + assert entry["data"]["key"] == "value" + + def test_debug_log_skips_when_disabled(self, tmp_path): + """Debug log does nothing when debug=False.""" + from unittest.mock import patch, MagicMock + + mock_config = MagicMock() + mock_config.debug = False + + with patch("app.proxy_handler.config", mock_config), \ + patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path): + from app.proxy_handler import debug_log + debug_log("test_cat", "test message") + + log_files = list(tmp_path.glob("debug_*.log")) + assert len(log_files) == 0 + + def test_debug_log_without_data(self, tmp_path): + """Debug log works without optional data parameter.""" + import json + from unittest.mock import patch, MagicMock + + mock_config = MagicMock() + mock_config.debug = True + + with patch("app.proxy_handler.config", mock_config), \ + patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path): + from app.proxy_handler import debug_log + debug_log("simple_cat", "no data here") + + log_files = list(tmp_path.glob("debug_*.log")) + assert len(log_files) == 1 + entry = json.loads(log_files[0].read_text().strip()) + assert "data" not in entry + assert entry["category"] == "simple_cat" + + +class TestForwardToOllama: + """Tests for forward_to_ollama function.""" + + @pytest.mark.asyncio + async def test_forwards_request_to_ollama(self): + """forward_to_ollama proxies request to Ollama host.""" + from app.proxy_handler import forward_to_ollama + from unittest.mock import patch, AsyncMock, MagicMock + + mock_request = AsyncMock() + mock_request.body = AsyncMock(return_value=b'{"model": "llama3"}') + mock_request.method = "POST" + mock_request.headers = {"content-type": "application/json", "content-length": "20"} + + mock_resp = MagicMock() + mock_resp.status_code = 200 + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.request = AsyncMock(return_value=mock_resp) + + with patch("httpx.AsyncClient", return_value=mock_client): + result = await forward_to_ollama(mock_request, "/api/show") + + assert result == mock_resp + mock_client.request.assert_called_once() + call_kwargs = mock_client.request.call_args + assert call_kwargs[1]["method"] == "POST" + assert "/api/show" in call_kwargs[1]["url"] diff --git a/tests/test_qdrant_service.py b/tests/test_qdrant_service.py new file mode 100644 index 0000000..afbe963 --- /dev/null +++ b/tests/test_qdrant_service.py @@ -0,0 +1,255 @@ +"""Tests for QdrantService — all Qdrant and Ollama calls are mocked.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +def make_qdrant_service(): + """Create a QdrantService with mocked AsyncQdrantClient.""" + with patch("app.qdrant_service.AsyncQdrantClient") as MockClient: + mock_client = AsyncMock() + MockClient.return_value = mock_client + + from app.qdrant_service import QdrantService + svc = QdrantService( + host="http://localhost:6333", + collection="test_memories", + embedding_model="snowflake-arctic-embed2", + vector_size=1024, + ollama_host="http://localhost:11434", + ) + + return svc, mock_client + + +class TestEnsureCollection: + """Tests for _ensure_collection.""" + + @pytest.mark.asyncio + async def test_creates_collection_when_missing(self): + """Creates collection if it does not exist.""" + svc, mock_client = make_qdrant_service() + mock_client.get_collection = AsyncMock(side_effect=Exception("not found")) + mock_client.create_collection = AsyncMock() + + await svc._ensure_collection() + + mock_client.create_collection.assert_called_once() + assert svc._collection_ensured is True + + @pytest.mark.asyncio + async def test_skips_if_collection_exists(self): + """Does not create collection if it already exists.""" + svc, mock_client = make_qdrant_service() + mock_client.get_collection = AsyncMock(return_value=MagicMock()) + mock_client.create_collection = AsyncMock() + + await svc._ensure_collection() + + mock_client.create_collection.assert_not_called() + assert svc._collection_ensured is True + + @pytest.mark.asyncio + async def test_skips_if_already_ensured(self): + """Skips entirely if _collection_ensured is True.""" + svc, mock_client = make_qdrant_service() + svc._collection_ensured = True + mock_client.get_collection = AsyncMock() + + await svc._ensure_collection() + + mock_client.get_collection.assert_not_called() + + +class TestGetEmbedding: + """Tests for get_embedding.""" + + @pytest.mark.asyncio + async def test_returns_embedding_vector(self): + """Returns embedding from Ollama response.""" + svc, _ = make_qdrant_service() + fake_embedding = [0.1] * 1024 + + mock_resp = MagicMock() + mock_resp.json.return_value = {"embedding": fake_embedding} + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.post = AsyncMock(return_value=mock_resp) + + with patch("httpx.AsyncClient", return_value=mock_client): + result = await svc.get_embedding("test text") + + assert result == fake_embedding + assert len(result) == 1024 + + +class TestStoreTurn: + """Tests for store_turn.""" + + @pytest.mark.asyncio + async def test_stores_raw_user_turn(self): + """Stores a user turn with proper payload.""" + svc, mock_client = make_qdrant_service() + svc._collection_ensured = True + mock_client.upsert = AsyncMock() + + fake_embedding = [0.1] * 1024 + with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)): + point_id = await svc.store_turn(role="user", content="hello world") + + assert isinstance(point_id, str) + mock_client.upsert.assert_called_once() + call_args = mock_client.upsert.call_args + point = call_args[1]["points"][0] + assert point.payload["type"] == "raw" + assert point.payload["role"] == "user" + assert "User: hello world" in point.payload["text"] + + @pytest.mark.asyncio + async def test_stores_curated_turn(self): + """Stores a curated turn without role prefix in text.""" + svc, mock_client = make_qdrant_service() + svc._collection_ensured = True + mock_client.upsert = AsyncMock() + + fake_embedding = [0.1] * 1024 + with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)): + point_id = await svc.store_turn( + role="curated", + content="User: q\nAssistant: a", + entry_type="curated" + ) + + call_args = mock_client.upsert.call_args + point = call_args[1]["points"][0] + assert point.payload["type"] == "curated" + # Curated text should be the content directly, not prefixed + assert point.payload["text"] == "User: q\nAssistant: a" + + @pytest.mark.asyncio + async def test_stores_with_topic_and_metadata(self): + """Stores turn with optional topic and metadata.""" + svc, mock_client = make_qdrant_service() + svc._collection_ensured = True + mock_client.upsert = AsyncMock() + + fake_embedding = [0.1] * 1024 + with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)): + await svc.store_turn( + role="assistant", + content="some response", + topic="python", + metadata={"source": "test"} + ) + + call_args = mock_client.upsert.call_args + point = call_args[1]["points"][0] + assert point.payload["topic"] == "python" + assert point.payload["source"] == "test" + + +class TestStoreQaTurn: + """Tests for store_qa_turn.""" + + @pytest.mark.asyncio + async def test_stores_qa_turn(self): + """Stores a complete Q&A turn.""" + svc, mock_client = make_qdrant_service() + svc._collection_ensured = True + mock_client.upsert = AsyncMock() + + fake_embedding = [0.1] * 1024 + with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)): + point_id = await svc.store_qa_turn("What is Python?", "A programming language.") + + assert isinstance(point_id, str) + call_args = mock_client.upsert.call_args + point = call_args[1]["points"][0] + assert point.payload["type"] == "raw" + assert point.payload["role"] == "qa" + assert "What is Python?" in point.payload["text"] + assert "A programming language." in point.payload["text"] + + +class TestSemanticSearch: + """Tests for semantic_search.""" + + @pytest.mark.asyncio + async def test_returns_matching_results(self): + """Returns formatted search results.""" + svc, mock_client = make_qdrant_service() + svc._collection_ensured = True + + mock_point = MagicMock() + mock_point.id = "result-1" + mock_point.score = 0.85 + mock_point.payload = {"text": "User: hello\nAssistant: hi", "type": "curated"} + + mock_query_result = MagicMock() + mock_query_result.points = [mock_point] + mock_client.query_points = AsyncMock(return_value=mock_query_result) + + fake_embedding = [0.1] * 1024 + with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)): + results = await svc.semantic_search("hello", limit=10, score_threshold=0.6) + + assert len(results) == 1 + assert results[0]["id"] == "result-1" + assert results[0]["score"] == 0.85 + assert results[0]["payload"]["type"] == "curated" + + +class TestGetRecentTurns: + """Tests for get_recent_turns.""" + + @pytest.mark.asyncio + async def test_returns_sorted_recent_turns(self): + """Returns turns sorted by timestamp descending.""" + svc, mock_client = make_qdrant_service() + svc._collection_ensured = True + + mock_point1 = MagicMock() + mock_point1.id = "old" + mock_point1.payload = {"timestamp": "2026-01-01T00:00:00Z", "text": "old turn"} + + mock_point2 = MagicMock() + mock_point2.id = "new" + mock_point2.payload = {"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"} + + mock_client.scroll = AsyncMock(return_value=([mock_point1, mock_point2], None)) + + results = await svc.get_recent_turns(limit=2) + + assert len(results) == 2 + # Newest first + assert results[0]["id"] == "new" + assert results[1]["id"] == "old" + + +class TestDeletePoints: + """Tests for delete_points.""" + + @pytest.mark.asyncio + async def test_deletes_by_ids(self): + """Deletes points by their IDs.""" + svc, mock_client = make_qdrant_service() + mock_client.delete = AsyncMock() + + await svc.delete_points(["id1", "id2"]) + + mock_client.delete.assert_called_once() + + +class TestClose: + """Tests for close.""" + + @pytest.mark.asyncio + async def test_closes_client(self): + """Closes the async Qdrant client.""" + svc, mock_client = make_qdrant_service() + mock_client.close = AsyncMock() + + await svc.close() + + mock_client.close.assert_called_once() diff --git a/tests/test_utils.py b/tests/test_utils.py index 4ac17f8..e2019ae 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ """Tests for utility functions.""" import pytest from unittest.mock import AsyncMock, MagicMock, patch -from app.utils import count_tokens, truncate_by_tokens, parse_curated_turn, build_augmented_messages +from app.utils import count_tokens, truncate_by_tokens, parse_curated_turn, build_augmented_messages, count_messages_tokens class TestCountTokens: @@ -86,6 +86,76 @@ Assistant: Yes, very popular.""" assert "Line 3" in result[0]["content"] +class TestCountMessagesTokens: + """Tests for count_messages_tokens function.""" + + def test_empty_list(self): + """Empty message list returns 0.""" + assert count_messages_tokens([]) == 0 + + def test_single_message(self): + """Single message counts tokens of its content.""" + msgs = [{"role": "user", "content": "Hello world"}] + result = count_messages_tokens(msgs) + assert result > 0 + + def test_multiple_messages(self): + """Multiple messages sum up their token counts.""" + msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there, how can I help you today?"}, + ] + result = count_messages_tokens(msgs) + assert result > count_messages_tokens([msgs[0]]) + + def test_message_without_content(self): + """Message without content field contributes 0 tokens.""" + msgs = [{"role": "system"}] + assert count_messages_tokens(msgs) == 0 + + +class TestLoadSystemPrompt: + """Tests for load_system_prompt function.""" + + def test_loads_from_prompts_dir(self, tmp_path): + """Loads systemprompt.md from PROMPTS_DIR.""" + import app.utils as utils_module + + prompts_dir = tmp_path / "prompts" + prompts_dir.mkdir() + (prompts_dir / "systemprompt.md").write_text("You are Vera.") + + with patch.object(utils_module, "PROMPTS_DIR", prompts_dir): + result = utils_module.load_system_prompt() + + assert result == "You are Vera." + + def test_falls_back_to_static_dir(self, tmp_path): + """Falls back to STATIC_DIR when PROMPTS_DIR has no file.""" + import app.utils as utils_module + + prompts_dir = tmp_path / "no_prompts" # does not exist + static_dir = tmp_path / "static" + static_dir.mkdir() + (static_dir / "systemprompt.md").write_text("Static Vera.") + + with patch.object(utils_module, "PROMPTS_DIR", prompts_dir), \ + patch.object(utils_module, "STATIC_DIR", static_dir): + result = utils_module.load_system_prompt() + + assert result == "Static Vera." + + def test_returns_empty_when_not_found(self, tmp_path): + """Returns empty string when systemprompt.md not found anywhere.""" + import app.utils as utils_module + + with patch.object(utils_module, "PROMPTS_DIR", tmp_path / "nope"), \ + patch.object(utils_module, "STATIC_DIR", tmp_path / "also_nope"): + result = utils_module.load_system_prompt() + + assert result == "" + + class TestFilterMemoriesByTime: """Tests for filter_memories_by_time function."""