diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..2f4c80e --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/requirements.txt b/requirements.txt index 767784d..3510242 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ tiktoken>=0.5.0 apscheduler>=3.10.0 pytest>=7.0.0 pytest-asyncio>=0.21.0 +pytest-cov>=4.0.0 diff --git a/tests/test_config.py b/tests/test_config.py index a633cb0..97df32a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -39,4 +39,136 @@ class TestEmbeddingDims: def test_mxbai_embed_large(self): """mxbai-embed-large should have 1024 dimensions.""" - assert EMBEDDING_DIMS["mxbai-embed-large"] == 1024 \ No newline at end of file + assert EMBEDDING_DIMS["mxbai-embed-large"] == 1024 + + +class TestConfigLoad: + """Tests for Config.load() with real TOML content.""" + + def test_load_from_explicit_path(self, tmp_path): + """Config.load() should parse a TOML file at an explicit path.""" + from app.config import Config + + config_file = tmp_path / "config.toml" + config_file.write_text( + '[general]\n' + 'ollama_host = "http://localhost:11434"\n' + 'qdrant_host = "http://localhost:6333"\n' + 'qdrant_collection = "test_memories"\n' + ) + cfg = Config.load(str(config_file)) + assert cfg.ollama_host == "http://localhost:11434" + assert cfg.qdrant_host == "http://localhost:6333" + assert cfg.qdrant_collection == "test_memories" + + def test_load_layers_section(self, tmp_path): + """Config.load() should parse [layers] section correctly.""" + from app.config import Config + + config_file = tmp_path / "config.toml" + config_file.write_text( + '[layers]\n' + 'semantic_token_budget = 5000\n' + 'context_token_budget = 3000\n' + 'semantic_score_threshold = 0.75\n' + ) + cfg = Config.load(str(config_file)) + assert cfg.semantic_token_budget == 5000 + assert cfg.context_token_budget == 3000 + assert cfg.semantic_score_threshold == 0.75 + + def test_load_curator_section(self, 
tmp_path): + """Config.load() should parse [curator] section correctly.""" + from app.config import Config + + config_file = tmp_path / "config.toml" + config_file.write_text( + '[curator]\n' + 'run_time = "03:30"\n' + 'curator_model = "mixtral:8x22b"\n' + ) + cfg = Config.load(str(config_file)) + assert cfg.run_time == "03:30" + assert cfg.curator_model == "mixtral:8x22b" + + def test_load_cloud_section(self, tmp_path): + """Config.load() should parse [cloud] section correctly.""" + from app.config import Config + + config_file = tmp_path / "config.toml" + config_file.write_text( + '[cloud]\n' + 'enabled = true\n' + 'api_base = "https://openrouter.ai/api/v1"\n' + 'api_key_env = "MY_API_KEY"\n' + '\n' + '[cloud.models]\n' + '"gpt-oss:120b" = "openai/gpt-4o"\n' + ) + cfg = Config.load(str(config_file)) + assert cfg.cloud.enabled is True + assert cfg.cloud.api_base == "https://openrouter.ai/api/v1" + assert cfg.cloud.api_key_env == "MY_API_KEY" + assert "gpt-oss:120b" in cfg.cloud.models + + def test_load_nonexistent_file_returns_defaults(self, tmp_path): + """Config.load() with missing file should fall back to defaults.""" + from app.config import Config + import os + + # Point config dir to a place with no config.toml + os.environ["VERA_CONFIG_DIR"] = str(tmp_path / "noconfig") + try: + cfg = Config.load(str(tmp_path / "nonexistent.toml")) + finally: + del os.environ["VERA_CONFIG_DIR"] + + assert cfg.ollama_host == "http://10.0.0.10:11434" + + +class TestCloudConfig: + """Tests for CloudConfig helper methods.""" + + def test_is_cloud_model_true(self): + """is_cloud_model returns True for registered model name.""" + from app.config import CloudConfig + + cc = CloudConfig(enabled=True, models={"gpt-oss:120b": "openai/gpt-4o"}) + assert cc.is_cloud_model("gpt-oss:120b") is True + + def test_is_cloud_model_false(self): + """is_cloud_model returns False for unknown model name.""" + from app.config import CloudConfig + + cc = CloudConfig(enabled=True, 
models={"gpt-oss:120b": "openai/gpt-4o"}) + assert cc.is_cloud_model("llama3:70b") is False + + def test_get_cloud_model_existing(self): + """get_cloud_model returns mapped cloud model ID.""" + from app.config import CloudConfig + + cc = CloudConfig(enabled=True, models={"gpt-oss:120b": "openai/gpt-4o"}) + assert cc.get_cloud_model("gpt-oss:120b") == "openai/gpt-4o" + + def test_get_cloud_model_missing(self): + """get_cloud_model returns None for unknown name.""" + from app.config import CloudConfig + + cc = CloudConfig(enabled=True, models={}) + assert cc.get_cloud_model("unknown") is None + + def test_api_key_from_env(self, monkeypatch): + """api_key property reads from environment variable.""" + from app.config import CloudConfig + + monkeypatch.setenv("MY_TEST_KEY", "sk-secret") + cc = CloudConfig(api_key_env="MY_TEST_KEY") + assert cc.api_key == "sk-secret" + + def test_api_key_missing_from_env(self, monkeypatch): + """api_key returns None when env var is not set.""" + from app.config import CloudConfig + + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + cc = CloudConfig(api_key_env="OPENROUTER_API_KEY") + assert cc.api_key is None \ No newline at end of file diff --git a/tests/test_curator.py b/tests/test_curator.py new file mode 100644 index 0000000..4eb679e --- /dev/null +++ b/tests/test_curator.py @@ -0,0 +1,201 @@ +"""Tests for Curator class methods — no live LLM or Qdrant required.""" +import pytest +import json +import os +from datetime import datetime, timedelta +from pathlib import Path +from unittest.mock import MagicMock, patch + + +def make_curator(tmp_path): + """Return a Curator instance with a dummy prompt file and mock QdrantService.""" + from app.curator import Curator + + # Create a minimal curator_prompt.md so Curator.__init__ can load it + prompts_dir = tmp_path / "prompts" + prompts_dir.mkdir() + (prompts_dir / "curator_prompt.md").write_text("Curate memories. 
Date: {CURRENT_DATE}") + + mock_qdrant = MagicMock() + + with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}): + curator = Curator( + qdrant_service=mock_qdrant, + model="test-model", + ollama_host="http://localhost:11434", + ) + + return curator, mock_qdrant + + +class TestParseJsonResponse: + """Tests for Curator._parse_json_response.""" + + def test_direct_valid_json(self, tmp_path): + """Valid JSON string parsed directly.""" + curator, _ = make_curator(tmp_path) + payload = {"new_curated_turns": [], "deletions": []} + result = curator._parse_json_response(json.dumps(payload)) + assert result == payload + + def test_json_in_code_block(self, tmp_path): + """JSON wrapped in ```json ... ``` code fence is extracted.""" + curator, _ = make_curator(tmp_path) + payload = {"summary": "done"} + response = f"```json\n{json.dumps(payload)}\n```" + result = curator._parse_json_response(response) + assert result == payload + + def test_json_embedded_in_text(self, tmp_path): + """JSON embedded after prose text is extracted via brace scan.""" + curator, _ = make_curator(tmp_path) + payload = {"new_curated_turns": [{"content": "Q: hi\nA: there"}]} + response = f"Here is the result:\n{json.dumps(payload)}\nThat's all." 
+ result = curator._parse_json_response(response) + assert result is not None + assert "new_curated_turns" in result + + def test_empty_string_returns_none(self, tmp_path): + """Empty response returns None.""" + curator, _ = make_curator(tmp_path) + result = curator._parse_json_response("") + assert result is None + + def test_malformed_json_returns_none(self, tmp_path): + """Completely invalid text returns None.""" + curator, _ = make_curator(tmp_path) + result = curator._parse_json_response("this is not json at all !!!") + assert result is None + + def test_json_in_plain_code_block(self, tmp_path): + """JSON in ``` (no language tag) code fence is extracted.""" + curator, _ = make_curator(tmp_path) + payload = {"permanent_rules": []} + response = f"```\n{json.dumps(payload)}\n```" + result = curator._parse_json_response(response) + assert result == payload + + +class TestIsRecent: + """Tests for Curator._is_recent.""" + + def test_memory_within_window(self, tmp_path): + """Memory timestamped 1 hour ago is recent (within 24h).""" + curator, _ = make_curator(tmp_path) + ts = (datetime.utcnow() - timedelta(hours=1)).isoformat() + "Z" + memory = {"timestamp": ts} + assert curator._is_recent(memory, hours=24) is True + + def test_memory_outside_window(self, tmp_path): + """Memory timestamped 48 hours ago is not recent.""" + curator, _ = make_curator(tmp_path) + ts = (datetime.utcnow() - timedelta(hours=48)).isoformat() + "Z" + memory = {"timestamp": ts} + assert curator._is_recent(memory, hours=24) is False + + def test_no_timestamp_returns_true(self, tmp_path): + """Memory without timestamp is treated as recent (safe default).""" + curator, _ = make_curator(tmp_path) + memory = {} + assert curator._is_recent(memory, hours=24) is True + + def test_empty_timestamp_returns_true(self, tmp_path): + """Memory with empty timestamp string is treated as recent.""" + curator, _ = make_curator(tmp_path) + memory = {"timestamp": ""} + assert curator._is_recent(memory, hours=24) 
is True + + def test_unparseable_timestamp_returns_true(self, tmp_path): + """Memory with garbage timestamp is treated as recent (safe default).""" + curator, _ = make_curator(tmp_path) + memory = {"timestamp": "not-a-date"} + assert curator._is_recent(memory, hours=24) is True + + def test_boundary_edge_just_inside(self, tmp_path): + """Memory at exactly hours-1 minutes ago should be recent.""" + curator, _ = make_curator(tmp_path) + ts = (datetime.utcnow() - timedelta(hours=23, minutes=59)).isoformat() + "Z" + memory = {"timestamp": ts} + assert curator._is_recent(memory, hours=24) is True + + +class TestFormatRawTurns: + """Tests for Curator._format_raw_turns.""" + + def test_empty_list(self, tmp_path): + """Empty input produces empty string.""" + curator, _ = make_curator(tmp_path) + result = curator._format_raw_turns([]) + assert result == "" + + def test_single_turn_header(self, tmp_path): + """Single turn has RAW TURN 1 header and turn ID.""" + curator, _ = make_curator(tmp_path) + turns = [{"id": "abc123", "text": "User: hello\nAssistant: hi"}] + result = curator._format_raw_turns(turns) + assert "RAW TURN 1" in result + assert "abc123" in result + assert "hello" in result + + def test_multiple_turns_numbered(self, tmp_path): + """Multiple turns are numbered sequentially.""" + curator, _ = make_curator(tmp_path) + turns = [ + {"id": "id1", "text": "turn one"}, + {"id": "id2", "text": "turn two"}, + {"id": "id3", "text": "turn three"}, + ] + result = curator._format_raw_turns(turns) + assert "RAW TURN 1" in result + assert "RAW TURN 2" in result + assert "RAW TURN 3" in result + + def test_missing_id_uses_unknown(self, tmp_path): + """Turn without id field shows 'unknown' placeholder.""" + curator, _ = make_curator(tmp_path) + turns = [{"text": "some text"}] + result = curator._format_raw_turns(turns) + assert "unknown" in result + + +class TestAppendRuleToFile: + """Tests for Curator._append_rule_to_file (filesystem I/O mocked via tmp_path).""" + + 
@pytest.mark.asyncio + async def test_appends_to_existing_file(self, tmp_path): + """Rule is appended to existing file.""" + prompts_dir = tmp_path / "prompts" + prompts_dir.mkdir() + target = prompts_dir / "systemprompt.md" + target.write_text("# Existing content\n") + + (prompts_dir / "curator_prompt.md").write_text("prompt {CURRENT_DATE}") + + from app.curator import Curator + + mock_qdrant = MagicMock() + with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}): + curator = Curator(mock_qdrant, model="m", ollama_host="http://x") + await curator._append_rule_to_file("systemprompt.md", "Always be concise.") + + content = target.read_text() + assert "Always be concise." in content + assert "# Existing content" in content + + @pytest.mark.asyncio + async def test_creates_file_if_missing(self, tmp_path): + """Rule is written to a new file if none existed.""" + prompts_dir = tmp_path / "prompts" + prompts_dir.mkdir() + (prompts_dir / "curator_prompt.md").write_text("prompt {CURRENT_DATE}") + + from app.curator import Curator + + mock_qdrant = MagicMock() + with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}): + curator = Curator(mock_qdrant, model="m", ollama_host="http://x") + await curator._append_rule_to_file("newfile.md", "New rule here.") + + target = prompts_dir / "newfile.md" + assert target.exists() + assert "New rule here." in target.read_text() diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..3b93ffc --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,351 @@ +"""Integration tests — FastAPI app via httpx.AsyncClient test transport. + +All external I/O (Ollama, Qdrant) is mocked. No live services required. 
+""" +import pytest +import json +import os +from unittest.mock import AsyncMock, MagicMock, patch +from pathlib import Path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_mock_qdrant(): + """Return a fully-mocked QdrantService.""" + mock = MagicMock() + mock._ensure_collection = AsyncMock() + mock.semantic_search = AsyncMock(return_value=[]) + mock.get_recent_turns = AsyncMock(return_value=[]) + mock.store_qa_turn = AsyncMock(return_value="fake-uuid") + mock.close = AsyncMock() + return mock + + +def _ollama_tags_response(): + return {"models": [{"name": "llama3", "size": 0}]} + + +def _ollama_chat_response(content: str = "Hello from Ollama"): + return { + "message": {"role": "assistant", "content": content}, + "done": True, + "model": "llama3", + } + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture() +def mock_qdrant(): + return _make_mock_qdrant() + + +@pytest.fixture() +def app_with_mocks(mock_qdrant, tmp_path): + """Return the FastAPI app with lifespan mocked (no real Qdrant/scheduler).""" + from contextlib import asynccontextmanager + + # Minimal curator prompt + prompts_dir = tmp_path / "prompts" + prompts_dir.mkdir() + (prompts_dir / "curator_prompt.md").write_text("Curate. 
Date: {CURRENT_DATE}") + (prompts_dir / "systemprompt.md").write_text("You are Vera.") + + @asynccontextmanager + async def fake_lifespan(app): + yield + + import app.main as main_module + + with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}), \ + patch("app.main.get_qdrant_service", return_value=mock_qdrant), \ + patch("app.singleton.get_qdrant_service", return_value=mock_qdrant), \ + patch("app.main.Curator") as MockCurator, \ + patch("app.main.scheduler") as mock_scheduler: + + mock_scheduler.add_job = MagicMock() + mock_scheduler.start = MagicMock() + mock_scheduler.shutdown = MagicMock() + + mock_curator_instance = MagicMock() + mock_curator_instance.run = AsyncMock() + MockCurator.return_value = mock_curator_instance + + from fastapi import FastAPI + from fastapi.testclient import TestClient + + # Import fresh — use the real routes but swap lifespan + from app.main import app as vera_app + vera_app.router.lifespan_context = fake_lifespan + + yield vera_app, mock_qdrant + + +# --------------------------------------------------------------------------- +# Health check +# --------------------------------------------------------------------------- + +class TestHealthCheck: + def test_health_ollama_reachable(self, app_with_mocks): + """GET / returns status ok and ollama=reachable when Ollama is up.""" + from fastapi.testclient import TestClient + + vera_app, mock_qdrant = app_with_mocks + + mock_resp = MagicMock() + mock_resp.status_code = 200 + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_client_instance.get = AsyncMock(return_value=mock_resp) + + with patch("httpx.AsyncClient", return_value=mock_client_instance): + with TestClient(vera_app, raise_server_exceptions=True) as client: + resp = client.get("/") + + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert 
body["ollama"] == "reachable" + + def test_health_ollama_unreachable(self, app_with_mocks): + """GET / returns ollama=unreachable when Ollama is down.""" + import httpx + from fastapi.testclient import TestClient + + vera_app, _ = app_with_mocks + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_client_instance.get = AsyncMock(side_effect=httpx.ConnectError("refused")) + + with patch("httpx.AsyncClient", return_value=mock_client_instance): + with TestClient(vera_app, raise_server_exceptions=True) as client: + resp = client.get("/") + + assert resp.status_code == 200 + assert resp.json()["ollama"] == "unreachable" + + +# --------------------------------------------------------------------------- +# /api/tags +# --------------------------------------------------------------------------- + +class TestApiTags: + def test_returns_model_list(self, app_with_mocks): + """GET /api/tags proxies Ollama tags.""" + from fastapi.testclient import TestClient + + vera_app, _ = app_with_mocks + + mock_resp = MagicMock() + mock_resp.json.return_value = _ollama_tags_response() + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_client_instance.get = AsyncMock(return_value=mock_resp) + + with patch("httpx.AsyncClient", return_value=mock_client_instance): + with TestClient(vera_app) as client: + resp = client.get("/api/tags") + + assert resp.status_code == 200 + data = resp.json() + assert "models" in data + assert any(m["name"] == "llama3" for m in data["models"]) + + def test_cloud_models_injected(self, tmp_path): + """Cloud models appear in /api/tags when cloud is enabled.""" + from fastapi.testclient import TestClient + from contextlib import asynccontextmanager + + prompts_dir = tmp_path / "prompts" + 
prompts_dir.mkdir() + (prompts_dir / "curator_prompt.md").write_text("Curate.") + (prompts_dir / "systemprompt.md").write_text("") + + mock_qdrant = _make_mock_qdrant() + + @asynccontextmanager + async def fake_lifespan(app): + yield + + from app.config import Config, CloudConfig + patched_config = Config() + patched_config.cloud = CloudConfig( + enabled=True, + models={"gpt-oss:120b": "openai/gpt-4o"}, + ) + + mock_resp = MagicMock() + mock_resp.json.return_value = {"models": []} + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_client_instance.get = AsyncMock(return_value=mock_resp) + + import app.main as main_module + + with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}), \ + patch("app.main.config", patched_config), \ + patch("app.main.get_qdrant_service", return_value=mock_qdrant), \ + patch("app.main.scheduler") as mock_scheduler, \ + patch("app.main.Curator") as MockCurator: + + mock_scheduler.add_job = MagicMock() + mock_scheduler.start = MagicMock() + mock_scheduler.shutdown = MagicMock() + mock_curator_instance = MagicMock() + mock_curator_instance.run = AsyncMock() + MockCurator.return_value = mock_curator_instance + + from app.main import app as vera_app + vera_app.router.lifespan_context = fake_lifespan + + with patch("httpx.AsyncClient", return_value=mock_client_instance): + with TestClient(vera_app) as client: + resp = client.get("/api/tags") + + data = resp.json() + names = [m["name"] for m in data["models"]] + assert "gpt-oss:120b" in names + + +# --------------------------------------------------------------------------- +# POST /api/chat (non-streaming) +# --------------------------------------------------------------------------- + +class TestApiChatNonStreaming: + def test_non_streaming_round_trip(self, app_with_mocks): + """POST /api/chat with stream=False returns Ollama response.""" + 
from fastapi.testclient import TestClient + import app.utils as utils_module + import app.proxy_handler as ph_module + + vera_app, mock_qdrant = app_with_mocks + + ollama_data = _ollama_chat_response("The answer is 42.") + + mock_post_resp = MagicMock() + mock_post_resp.json.return_value = ollama_data + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_client_instance.post = AsyncMock(return_value=mock_post_resp) + + with patch.object(utils_module, "load_system_prompt", return_value=""), \ + patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant), \ + patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \ + patch("httpx.AsyncClient", return_value=mock_client_instance): + + with TestClient(vera_app) as client: + resp = client.post( + "/api/chat", + json={ + "model": "llama3", + "messages": [{"role": "user", "content": "What is the answer?"}], + "stream": False, + }, + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["message"]["content"] == "The answer is 42." 
+ + def test_non_streaming_stores_qa(self, app_with_mocks): + """POST /api/chat non-streaming stores the Q&A turn in Qdrant.""" + from fastapi.testclient import TestClient + import app.utils as utils_module + + vera_app, mock_qdrant = app_with_mocks + + ollama_data = _ollama_chat_response("42.") + + mock_post_resp = MagicMock() + mock_post_resp.json.return_value = ollama_data + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_client_instance.post = AsyncMock(return_value=mock_post_resp) + + with patch.object(utils_module, "load_system_prompt", return_value=""), \ + patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant), \ + patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \ + patch("httpx.AsyncClient", return_value=mock_client_instance): + + with TestClient(vera_app) as client: + client.post( + "/api/chat", + json={ + "model": "llama3", + "messages": [{"role": "user", "content": "What is 6*7?"}], + "stream": False, + }, + ) + + mock_qdrant.store_qa_turn.assert_called_once() + args = mock_qdrant.store_qa_turn.call_args[0] + assert "6*7" in args[0] + assert "42." 
in args[1] + + +# --------------------------------------------------------------------------- +# POST /api/chat (streaming) +# --------------------------------------------------------------------------- + +class TestApiChatStreaming: + def test_streaming_response_passthrough(self, app_with_mocks): + """POST /api/chat with stream=True streams Ollama chunks.""" + from fastapi.testclient import TestClient + import app.utils as utils_module + import app.proxy_handler as ph_module + + vera_app, mock_qdrant = app_with_mocks + + chunk1 = json.dumps({"message": {"content": "Hello"}, "done": False}).encode() + chunk2 = json.dumps({"message": {"content": " world"}, "done": True}).encode() + + async def fake_aiter_bytes(): + yield chunk1 + yield chunk2 + + mock_stream_resp = MagicMock() + mock_stream_resp.aiter_bytes = fake_aiter_bytes + mock_stream_resp.status_code = 200 + mock_stream_resp.headers = {"content-type": "application/x-ndjson"} + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_client_instance.post = AsyncMock(return_value=mock_stream_resp) + + with patch.object(utils_module, "load_system_prompt", return_value=""), \ + patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant), \ + patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \ + patch("httpx.AsyncClient", return_value=mock_client_instance): + + with TestClient(vera_app) as client: + resp = client.post( + "/api/chat", + json={ + "model": "llama3", + "messages": [{"role": "user", "content": "Say hello"}], + "stream": True, + }, + ) + + assert resp.status_code == 200 + # Response body should contain both chunks concatenated + body_text = resp.text + assert "Hello" in body_text or len(body_text) > 0 diff --git a/tests/test_proxy_handler.py b/tests/test_proxy_handler.py new file mode 100644 index 0000000..337648e --- /dev/null +++ 
b/tests/test_proxy_handler.py @@ -0,0 +1,221 @@ +"""Tests for proxy_handler — no live Ollama or Qdrant required.""" +import pytest +import json +from unittest.mock import AsyncMock, MagicMock, patch + + +class TestCleanMessageContent: + """Tests for clean_message_content.""" + + def test_passthrough_plain_message(self): + """Plain text without wrapper is returned unchanged.""" + from app.proxy_handler import clean_message_content + + content = "What is the capital of France?" + assert clean_message_content(content) == content + + def test_strips_memory_context_wrapper(self): + """[Memory context] wrapper is stripped, actual user_msg returned.""" + from app.proxy_handler import clean_message_content + + content = ( + "[Memory context]\n" + "some context here\n" + "- user_msg: What is the capital of France?\n\n" + ) + result = clean_message_content(content) + assert result == "What is the capital of France?" + + def test_strips_timestamp_prefix(self): + """ISO timestamp prefix like [2024-01-01T00:00:00] is removed.""" + from app.proxy_handler import clean_message_content + + content = "[2024-01-01T12:34:56] Tell me a joke" + result = clean_message_content(content) + assert result == "Tell me a joke" + + def test_empty_string_returned_as_is(self): + """Empty string input returns empty string.""" + from app.proxy_handler import clean_message_content + + assert clean_message_content("") == "" + + def test_none_input_returned_as_is(self): + """None/falsy input is returned unchanged.""" + from app.proxy_handler import clean_message_content + + assert clean_message_content(None) is None + + def test_list_content_not_processed(self): + """Non-string content (list) is returned as-is.""" + from app.proxy_handler import clean_message_content + + # content can be a list of parts in some Ollama payloads; + # the function guards with `if not content` + # A non-empty list is truthy but the regex won't match → passthrough + content = [{"type": "text", "text": "hello"}] + result = 
clean_message_content(content) + assert result == content + + +class TestHandleChatNonStreaming: + """Tests for handle_chat_non_streaming — fully mocked external I/O.""" + + @pytest.mark.asyncio + async def test_returns_json_response(self): + """Should return a JSONResponse with Ollama result merged with model field.""" + from app.proxy_handler import handle_chat_non_streaming + + ollama_resp_data = { + "message": {"role": "assistant", "content": "Paris."}, + "done": True, + } + + mock_httpx_resp = MagicMock() + mock_httpx_resp.json.return_value = ollama_resp_data + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.post = AsyncMock(return_value=mock_httpx_resp) + + mock_qdrant = MagicMock() + mock_qdrant.store_qa_turn = AsyncMock(return_value="fake-uuid") + + augmented = [{"role": "user", "content": "What is the capital of France?"}] + + with patch("app.proxy_handler.build_augmented_messages", AsyncMock(return_value=augmented)), \ + patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \ + patch("httpx.AsyncClient", return_value=mock_client): + + body = { + "model": "llama3", + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "stream": False, + } + response = await handle_chat_non_streaming(body) + + # FastAPI JSONResponse + from fastapi.responses import JSONResponse + assert isinstance(response, JSONResponse) + response_body = json.loads(response.body) + assert response_body["message"]["content"] == "Paris." 
+ assert response_body["model"] == "llama3" + + @pytest.mark.asyncio + async def test_stores_qa_turn_when_answer_present(self): + """store_qa_turn should be called with user question and assistant answer.""" + from app.proxy_handler import handle_chat_non_streaming + + ollama_resp_data = { + "message": {"role": "assistant", "content": "Berlin."}, + "done": True, + } + + mock_httpx_resp = MagicMock() + mock_httpx_resp.json.return_value = ollama_resp_data + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.post = AsyncMock(return_value=mock_httpx_resp) + + mock_qdrant = MagicMock() + mock_qdrant.store_qa_turn = AsyncMock(return_value="fake-uuid") + + augmented = [{"role": "user", "content": "Capital of Germany?"}] + + with patch("app.proxy_handler.build_augmented_messages", AsyncMock(return_value=augmented)), \ + patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \ + patch("httpx.AsyncClient", return_value=mock_client): + + body = { + "model": "llama3", + "messages": [{"role": "user", "content": "Capital of Germany?"}], + "stream": False, + } + await handle_chat_non_streaming(body) + + mock_qdrant.store_qa_turn.assert_called_once() + call_args = mock_qdrant.store_qa_turn.call_args + assert "Capital of Germany?" in call_args[0][0] + assert "Berlin." 
in call_args[0][1] + + @pytest.mark.asyncio + async def test_no_store_when_empty_answer(self): + """store_qa_turn should NOT be called when the assistant answer is empty.""" + from app.proxy_handler import handle_chat_non_streaming + + ollama_resp_data = { + "message": {"role": "assistant", "content": ""}, + "done": True, + } + + mock_httpx_resp = MagicMock() + mock_httpx_resp.json.return_value = ollama_resp_data + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.post = AsyncMock(return_value=mock_httpx_resp) + + mock_qdrant = MagicMock() + mock_qdrant.store_qa_turn = AsyncMock(return_value="fake-uuid") + + augmented = [{"role": "user", "content": "Hello?"}] + + with patch("app.proxy_handler.build_augmented_messages", AsyncMock(return_value=augmented)), \ + patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \ + patch("httpx.AsyncClient", return_value=mock_client): + + body = { + "model": "llama3", + "messages": [{"role": "user", "content": "Hello?"}], + "stream": False, + } + await handle_chat_non_streaming(body) + + mock_qdrant.store_qa_turn.assert_not_called() + + @pytest.mark.asyncio + async def test_cleans_memory_context_from_user_message(self): + """User message with [Memory context] wrapper should be cleaned before storing.""" + from app.proxy_handler import handle_chat_non_streaming + + ollama_resp_data = { + "message": {"role": "assistant", "content": "42."}, + "done": True, + } + + mock_httpx_resp = MagicMock() + mock_httpx_resp.json.return_value = ollama_resp_data + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.post = AsyncMock(return_value=mock_httpx_resp) + + mock_qdrant = MagicMock() + mock_qdrant.store_qa_turn = AsyncMock(return_value="fake-uuid") + + raw_content = ( + "[Memory context]\nsome ctx\n- 
user_msg: What is the answer?\n\n" + ) + augmented = [{"role": "user", "content": "What is the answer?"}] + + with patch("app.proxy_handler.build_augmented_messages", AsyncMock(return_value=augmented)), \ + patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \ + patch("httpx.AsyncClient", return_value=mock_client): + + body = { + "model": "llama3", + "messages": [{"role": "user", "content": raw_content}], + "stream": False, + } + await handle_chat_non_streaming(body) + + call_args = mock_qdrant.store_qa_turn.call_args + stored_question = call_args[0][0] + # The wrapper should be stripped + assert "Memory context" not in stored_question + assert "What is the answer?" in stored_question diff --git a/tests/test_utils.py b/tests/test_utils.py index 129951f..63833e0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -82,4 +82,238 @@ Assistant: Yes, very popular.""" result = parse_curated_turn(text) assert "Line 1" in result[0]["content"] assert "Line 2" in result[0]["content"] - assert "Line 3" in result[0]["content"] \ No newline at end of file + assert "Line 3" in result[0]["content"] + + +class TestFilterMemoriesByTime: + """Tests for filter_memories_by_time function.""" + + def test_includes_recent_memory(self): + """Memory with timestamp in the last 24h should be included.""" + from datetime import datetime, timedelta + from app.utils import filter_memories_by_time + + ts = (datetime.utcnow() - timedelta(hours=1)).isoformat() + memories = [{"timestamp": ts, "text": "recent"}] + result = filter_memories_by_time(memories, hours=24) + assert len(result) == 1 + + def test_excludes_old_memory(self): + """Memory older than cutoff should be excluded.""" + from datetime import datetime, timedelta + from app.utils import filter_memories_by_time + + ts = (datetime.utcnow() - timedelta(hours=48)).isoformat() + memories = [{"timestamp": ts, "text": "old"}] + result = filter_memories_by_time(memories, hours=24) + assert len(result) == 0 + + def 
class TestMergeMemories:
    """Tests for merge_memories function."""

    def test_empty_list(self):
        """No memories in: empty text and an empty id list out."""
        from app.utils import merge_memories

        assert merge_memories([]) == {"text": "", "ids": []}

    def test_single_memory_with_text(self):
        """A lone memory's text and id both show up in the merged result."""
        from app.utils import merge_memories

        merged = merge_memories([{"id": "abc", "text": "hello world", "role": ""}])
        assert "hello world" in merged["text"]
        assert "abc" in merged["ids"]

    def test_memory_with_content_field(self):
        """A memory carrying ``content`` instead of ``text`` is still merged."""
        from app.utils import merge_memories

        entry = {"id": "xyz", "content": "from content field"}
        merged = merge_memories([entry])
        assert "from content field" in merged["text"]

    def test_role_included_in_output(self):
        """A non-empty role is rendered as a ``[role]:`` prefix in the text."""
        from app.utils import merge_memories

        entry = {"id": "1", "text": "question", "role": "user"}
        merged = merge_memories([entry])
        assert "[user]:" in merged["text"]

    def test_multiple_memories_joined(self):
        """Every memory's text appears in the output and every id is collected."""
        from app.utils import merge_memories

        entries = [
            {"id": "1", "text": "first"},
            {"id": "2", "text": "second"},
        ]
        merged = merge_memories(entries)
        for fragment in ("first", "second"):
            assert fragment in merged["text"]
        assert len(merged["ids"]) == 2
class TestCalculateTokenBudget:
    """Tests for calculate_token_budget function."""

    def test_default_ratios_sum(self):
        """With default ratios the three layer budgets add up to the total."""
        from app.utils import calculate_token_budget

        budget = calculate_token_budget(1000)
        layers = ("system", "semantic", "context")
        assert sum(budget[layer] for layer in layers) == 1000

    def test_custom_ratios(self):
        """Explicit ratios split the total proportionally."""
        from app.utils import calculate_token_budget

        budget = calculate_token_budget(
            100, system_ratio=0.1, semantic_ratio=0.6, context_ratio=0.3
        )
        expected = {"system": 10, "semantic": 60, "context": 30}
        for layer, tokens in expected.items():
            assert budget[layer] == tokens

    def test_zero_budget(self):
        """A total of zero leaves every layer with zero tokens."""
        from app.utils import calculate_token_budget

        budget = calculate_token_budget(0)
        for layer in ("system", "semantic", "context"):
            assert budget[layer] == 0
class TestBuildAugmentedMessages:
    """Tests for build_augmented_messages function (mocked I/O).

    The tests are coroutines and are awaited by pytest-asyncio; this relies
    on ``asyncio_mode = auto`` in pytest.ini. Awaiting directly replaces
    ``asyncio.get_event_loop().run_until_complete(...)``, which is
    deprecated when no event loop is current.
    """

    def _make_qdrant_mock(self, semantic_results=None):
        """Return a mocked Qdrant service with async search methods.

        Args:
            semantic_results: Value ``semantic_search`` resolves to;
                defaults to an empty list.

        Returns:
            A ``MagicMock`` whose ``semantic_search`` and
            ``get_recent_turns`` attributes are ``AsyncMock`` objects
            (the latter always resolves to an empty list).
        """
        from unittest.mock import AsyncMock, MagicMock

        mock_qdrant = MagicMock()
        mock_qdrant.semantic_search = AsyncMock(
            return_value=[] if semantic_results is None else semantic_results
        )
        mock_qdrant.get_recent_turns = AsyncMock(return_value=[])
        return mock_qdrant

    async def _augment(self, messages, system_prompt="", qdrant=None):
        """Run build_augmented_messages with its collaborators patched.

        Args:
            messages: Incoming chat messages passed to the function.
            system_prompt: Value the patched ``load_system_prompt`` returns.
            qdrant: Mocked service returned by the patched
                ``get_qdrant_service``; a default empty mock is built when
                omitted.

        Returns:
            Whatever ``build_augmented_messages`` returns.
        """
        from unittest.mock import patch
        import app.utils as utils_module

        if qdrant is None:
            qdrant = self._make_qdrant_mock()

        with patch.object(utils_module, "load_system_prompt", return_value=system_prompt), \
             patch.object(utils_module, "get_qdrant_service", return_value=qdrant):
            return await utils_module.build_augmented_messages(messages)

    async def test_system_layer_prepended(self):
        """The loaded system prompt appears as the single system message."""
        # load_system_prompt is patched, so no prompt file is needed on disk.
        result = await self._augment(
            [{"role": "user", "content": "Hello"}],
            system_prompt="You are Vera.",
        )

        system_msgs = [m for m in result if m["role"] == "system"]
        assert len(system_msgs) == 1
        assert "You are Vera." in system_msgs[0]["content"]

    async def test_incoming_user_message_preserved(self):
        """Incoming user message should appear in output."""
        result = await self._augment(
            [{"role": "user", "content": "What is 2+2?"}]
        )

        user_msgs = [m for m in result if m.get("role") == "user"]
        assert any("2+2" in m["content"] for m in user_msgs)

    async def test_no_system_message_when_no_prompt(self):
        """No system message added when both incoming and file prompt are empty."""
        result = await self._augment([{"role": "user", "content": "Hi"}])

        system_msgs = [m for m in result if m.get("role") == "system"]
        assert len(system_msgs) == 0

    async def test_semantic_results_injected(self):
        """Curated memories from semantic search should appear in output."""
        mock_qdrant = self._make_qdrant_mock(
            semantic_results=[
                {"payload": {"text": "User: Old question?\nAssistant: Old answer."}}
            ]
        )

        result = await self._augment(
            [{"role": "user", "content": "Tell me"}],
            qdrant=mock_qdrant,
        )

        contents = [m["content"] for m in result]
        assert any("Old question" in c or "Old answer" in c for c in contents)