test: expand coverage to 70%+ — add utils, config, curator, proxy, integration tests
- Extend test_utils.py: filter_memories_by_time, merge_memories, calculate_token_budget, build_augmented_messages (mocked) - Extend test_config.py: Config.load() with TOML via tmp_path, CloudConfig helpers, env var api_key - Add test_curator.py: _parse_json_response, _is_recent, _format_raw_turns, _append_rule_to_file - Add test_proxy_handler.py: clean_message_content, handle_chat_non_streaming (mocked httpx+qdrant) - Add test_integration.py: health check, /api/tags, /api/chat non-streaming + streaming via TestClient - Add pytest.ini (asyncio_mode=auto), add pytest-cov to requirements.txt Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2
pytest.ini
Normal file
2
pytest.ini
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
[pytest]
|
||||||
|
asyncio_mode = auto
|
||||||
@@ -8,3 +8,4 @@ tiktoken>=0.5.0
|
|||||||
apscheduler>=3.10.0
|
apscheduler>=3.10.0
|
||||||
pytest>=7.0.0
|
pytest>=7.0.0
|
||||||
pytest-asyncio>=0.21.0
|
pytest-asyncio>=0.21.0
|
||||||
|
pytest-cov>=4.0.0
|
||||||
|
|||||||
@@ -39,4 +39,136 @@ class TestEmbeddingDims:
|
|||||||
|
|
||||||
def test_mxbai_embed_large(self):
|
def test_mxbai_embed_large(self):
|
||||||
"""mxbai-embed-large should have 1024 dimensions."""
|
"""mxbai-embed-large should have 1024 dimensions."""
|
||||||
assert EMBEDDING_DIMS["mxbai-embed-large"] == 1024
|
assert EMBEDDING_DIMS["mxbai-embed-large"] == 1024
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigLoad:
|
||||||
|
"""Tests for Config.load() with real TOML content."""
|
||||||
|
|
||||||
|
def test_load_from_explicit_path(self, tmp_path):
|
||||||
|
"""Config.load() should parse a TOML file at an explicit path."""
|
||||||
|
from app.config import Config
|
||||||
|
|
||||||
|
config_file = tmp_path / "config.toml"
|
||||||
|
config_file.write_text(
|
||||||
|
'[general]\n'
|
||||||
|
'ollama_host = "http://localhost:11434"\n'
|
||||||
|
'qdrant_host = "http://localhost:6333"\n'
|
||||||
|
'qdrant_collection = "test_memories"\n'
|
||||||
|
)
|
||||||
|
cfg = Config.load(str(config_file))
|
||||||
|
assert cfg.ollama_host == "http://localhost:11434"
|
||||||
|
assert cfg.qdrant_host == "http://localhost:6333"
|
||||||
|
assert cfg.qdrant_collection == "test_memories"
|
||||||
|
|
||||||
|
def test_load_layers_section(self, tmp_path):
|
||||||
|
"""Config.load() should parse [layers] section correctly."""
|
||||||
|
from app.config import Config
|
||||||
|
|
||||||
|
config_file = tmp_path / "config.toml"
|
||||||
|
config_file.write_text(
|
||||||
|
'[layers]\n'
|
||||||
|
'semantic_token_budget = 5000\n'
|
||||||
|
'context_token_budget = 3000\n'
|
||||||
|
'semantic_score_threshold = 0.75\n'
|
||||||
|
)
|
||||||
|
cfg = Config.load(str(config_file))
|
||||||
|
assert cfg.semantic_token_budget == 5000
|
||||||
|
assert cfg.context_token_budget == 3000
|
||||||
|
assert cfg.semantic_score_threshold == 0.75
|
||||||
|
|
||||||
|
def test_load_curator_section(self, tmp_path):
|
||||||
|
"""Config.load() should parse [curator] section correctly."""
|
||||||
|
from app.config import Config
|
||||||
|
|
||||||
|
config_file = tmp_path / "config.toml"
|
||||||
|
config_file.write_text(
|
||||||
|
'[curator]\n'
|
||||||
|
'run_time = "03:30"\n'
|
||||||
|
'curator_model = "mixtral:8x22b"\n'
|
||||||
|
)
|
||||||
|
cfg = Config.load(str(config_file))
|
||||||
|
assert cfg.run_time == "03:30"
|
||||||
|
assert cfg.curator_model == "mixtral:8x22b"
|
||||||
|
|
||||||
|
def test_load_cloud_section(self, tmp_path):
|
||||||
|
"""Config.load() should parse [cloud] section correctly."""
|
||||||
|
from app.config import Config
|
||||||
|
|
||||||
|
config_file = tmp_path / "config.toml"
|
||||||
|
config_file.write_text(
|
||||||
|
'[cloud]\n'
|
||||||
|
'enabled = true\n'
|
||||||
|
'api_base = "https://openrouter.ai/api/v1"\n'
|
||||||
|
'api_key_env = "MY_API_KEY"\n'
|
||||||
|
'\n'
|
||||||
|
'[cloud.models]\n'
|
||||||
|
'"gpt-oss:120b" = "openai/gpt-4o"\n'
|
||||||
|
)
|
||||||
|
cfg = Config.load(str(config_file))
|
||||||
|
assert cfg.cloud.enabled is True
|
||||||
|
assert cfg.cloud.api_base == "https://openrouter.ai/api/v1"
|
||||||
|
assert cfg.cloud.api_key_env == "MY_API_KEY"
|
||||||
|
assert "gpt-oss:120b" in cfg.cloud.models
|
||||||
|
|
||||||
|
def test_load_nonexistent_file_returns_defaults(self, tmp_path):
|
||||||
|
"""Config.load() with missing file should fall back to defaults."""
|
||||||
|
from app.config import Config
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Point config dir to a place with no config.toml
|
||||||
|
os.environ["VERA_CONFIG_DIR"] = str(tmp_path / "noconfig")
|
||||||
|
try:
|
||||||
|
cfg = Config.load(str(tmp_path / "nonexistent.toml"))
|
||||||
|
finally:
|
||||||
|
del os.environ["VERA_CONFIG_DIR"]
|
||||||
|
|
||||||
|
assert cfg.ollama_host == "http://10.0.0.10:11434"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCloudConfig:
|
||||||
|
"""Tests for CloudConfig helper methods."""
|
||||||
|
|
||||||
|
def test_is_cloud_model_true(self):
|
||||||
|
"""is_cloud_model returns True for registered model name."""
|
||||||
|
from app.config import CloudConfig
|
||||||
|
|
||||||
|
cc = CloudConfig(enabled=True, models={"gpt-oss:120b": "openai/gpt-4o"})
|
||||||
|
assert cc.is_cloud_model("gpt-oss:120b") is True
|
||||||
|
|
||||||
|
def test_is_cloud_model_false(self):
|
||||||
|
"""is_cloud_model returns False for unknown model name."""
|
||||||
|
from app.config import CloudConfig
|
||||||
|
|
||||||
|
cc = CloudConfig(enabled=True, models={"gpt-oss:120b": "openai/gpt-4o"})
|
||||||
|
assert cc.is_cloud_model("llama3:70b") is False
|
||||||
|
|
||||||
|
def test_get_cloud_model_existing(self):
|
||||||
|
"""get_cloud_model returns mapped cloud model ID."""
|
||||||
|
from app.config import CloudConfig
|
||||||
|
|
||||||
|
cc = CloudConfig(enabled=True, models={"gpt-oss:120b": "openai/gpt-4o"})
|
||||||
|
assert cc.get_cloud_model("gpt-oss:120b") == "openai/gpt-4o"
|
||||||
|
|
||||||
|
def test_get_cloud_model_missing(self):
|
||||||
|
"""get_cloud_model returns None for unknown name."""
|
||||||
|
from app.config import CloudConfig
|
||||||
|
|
||||||
|
cc = CloudConfig(enabled=True, models={})
|
||||||
|
assert cc.get_cloud_model("unknown") is None
|
||||||
|
|
||||||
|
def test_api_key_from_env(self, monkeypatch):
|
||||||
|
"""api_key property reads from environment variable."""
|
||||||
|
from app.config import CloudConfig
|
||||||
|
|
||||||
|
monkeypatch.setenv("MY_TEST_KEY", "sk-secret")
|
||||||
|
cc = CloudConfig(api_key_env="MY_TEST_KEY")
|
||||||
|
assert cc.api_key == "sk-secret"
|
||||||
|
|
||||||
|
def test_api_key_missing_from_env(self, monkeypatch):
|
||||||
|
"""api_key returns None when env var is not set."""
|
||||||
|
from app.config import CloudConfig
|
||||||
|
|
||||||
|
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||||
|
cc = CloudConfig(api_key_env="OPENROUTER_API_KEY")
|
||||||
|
assert cc.api_key is None
|
||||||
201
tests/test_curator.py
Normal file
201
tests/test_curator.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""Tests for Curator class methods — no live LLM or Qdrant required."""
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
def make_curator(tmp_path):
|
||||||
|
"""Return a Curator instance with a dummy prompt file and mock QdrantService."""
|
||||||
|
from app.curator import Curator
|
||||||
|
|
||||||
|
# Create a minimal curator_prompt.md so Curator.__init__ can load it
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
prompts_dir.mkdir()
|
||||||
|
(prompts_dir / "curator_prompt.md").write_text("Curate memories. Date: {CURRENT_DATE}")
|
||||||
|
|
||||||
|
mock_qdrant = MagicMock()
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}):
|
||||||
|
curator = Curator(
|
||||||
|
qdrant_service=mock_qdrant,
|
||||||
|
model="test-model",
|
||||||
|
ollama_host="http://localhost:11434",
|
||||||
|
)
|
||||||
|
|
||||||
|
return curator, mock_qdrant
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseJsonResponse:
|
||||||
|
"""Tests for Curator._parse_json_response."""
|
||||||
|
|
||||||
|
def test_direct_valid_json(self, tmp_path):
|
||||||
|
"""Valid JSON string parsed directly."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
payload = {"new_curated_turns": [], "deletions": []}
|
||||||
|
result = curator._parse_json_response(json.dumps(payload))
|
||||||
|
assert result == payload
|
||||||
|
|
||||||
|
def test_json_in_code_block(self, tmp_path):
|
||||||
|
"""JSON wrapped in ```json ... ``` code fence is extracted."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
payload = {"summary": "done"}
|
||||||
|
response = f"```json\n{json.dumps(payload)}\n```"
|
||||||
|
result = curator._parse_json_response(response)
|
||||||
|
assert result == payload
|
||||||
|
|
||||||
|
def test_json_embedded_in_text(self, tmp_path):
|
||||||
|
"""JSON embedded after prose text is extracted via brace scan."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
payload = {"new_curated_turns": [{"content": "Q: hi\nA: there"}]}
|
||||||
|
response = f"Here is the result:\n{json.dumps(payload)}\nThat's all."
|
||||||
|
result = curator._parse_json_response(response)
|
||||||
|
assert result is not None
|
||||||
|
assert "new_curated_turns" in result
|
||||||
|
|
||||||
|
def test_empty_string_returns_none(self, tmp_path):
|
||||||
|
"""Empty response returns None."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
result = curator._parse_json_response("")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_malformed_json_returns_none(self, tmp_path):
|
||||||
|
"""Completely invalid text returns None."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
result = curator._parse_json_response("this is not json at all !!!")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_json_in_plain_code_block(self, tmp_path):
|
||||||
|
"""JSON in ``` (no language tag) code fence is extracted."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
payload = {"permanent_rules": []}
|
||||||
|
response = f"```\n{json.dumps(payload)}\n```"
|
||||||
|
result = curator._parse_json_response(response)
|
||||||
|
assert result == payload
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsRecent:
|
||||||
|
"""Tests for Curator._is_recent."""
|
||||||
|
|
||||||
|
def test_memory_within_window(self, tmp_path):
|
||||||
|
"""Memory timestamped 1 hour ago is recent (within 24h)."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
ts = (datetime.utcnow() - timedelta(hours=1)).isoformat() + "Z"
|
||||||
|
memory = {"timestamp": ts}
|
||||||
|
assert curator._is_recent(memory, hours=24) is True
|
||||||
|
|
||||||
|
def test_memory_outside_window(self, tmp_path):
|
||||||
|
"""Memory timestamped 48 hours ago is not recent."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
ts = (datetime.utcnow() - timedelta(hours=48)).isoformat() + "Z"
|
||||||
|
memory = {"timestamp": ts}
|
||||||
|
assert curator._is_recent(memory, hours=24) is False
|
||||||
|
|
||||||
|
def test_no_timestamp_returns_true(self, tmp_path):
|
||||||
|
"""Memory without timestamp is treated as recent (safe default)."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
memory = {}
|
||||||
|
assert curator._is_recent(memory, hours=24) is True
|
||||||
|
|
||||||
|
def test_empty_timestamp_returns_true(self, tmp_path):
|
||||||
|
"""Memory with empty timestamp string is treated as recent."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
memory = {"timestamp": ""}
|
||||||
|
assert curator._is_recent(memory, hours=24) is True
|
||||||
|
|
||||||
|
def test_unparseable_timestamp_returns_true(self, tmp_path):
|
||||||
|
"""Memory with garbage timestamp is treated as recent (safe default)."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
memory = {"timestamp": "not-a-date"}
|
||||||
|
assert curator._is_recent(memory, hours=24) is True
|
||||||
|
|
||||||
|
def test_boundary_edge_just_inside(self, tmp_path):
|
||||||
|
"""Memory at exactly hours-1 minutes ago should be recent."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
ts = (datetime.utcnow() - timedelta(hours=23, minutes=59)).isoformat() + "Z"
|
||||||
|
memory = {"timestamp": ts}
|
||||||
|
assert curator._is_recent(memory, hours=24) is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatRawTurns:
|
||||||
|
"""Tests for Curator._format_raw_turns."""
|
||||||
|
|
||||||
|
def test_empty_list(self, tmp_path):
|
||||||
|
"""Empty input produces empty string."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
result = curator._format_raw_turns([])
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
def test_single_turn_header(self, tmp_path):
|
||||||
|
"""Single turn has RAW TURN 1 header and turn ID."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
turns = [{"id": "abc123", "text": "User: hello\nAssistant: hi"}]
|
||||||
|
result = curator._format_raw_turns(turns)
|
||||||
|
assert "RAW TURN 1" in result
|
||||||
|
assert "abc123" in result
|
||||||
|
assert "hello" in result
|
||||||
|
|
||||||
|
def test_multiple_turns_numbered(self, tmp_path):
|
||||||
|
"""Multiple turns are numbered sequentially."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
turns = [
|
||||||
|
{"id": "id1", "text": "turn one"},
|
||||||
|
{"id": "id2", "text": "turn two"},
|
||||||
|
{"id": "id3", "text": "turn three"},
|
||||||
|
]
|
||||||
|
result = curator._format_raw_turns(turns)
|
||||||
|
assert "RAW TURN 1" in result
|
||||||
|
assert "RAW TURN 2" in result
|
||||||
|
assert "RAW TURN 3" in result
|
||||||
|
|
||||||
|
def test_missing_id_uses_unknown(self, tmp_path):
|
||||||
|
"""Turn without id field shows 'unknown' placeholder."""
|
||||||
|
curator, _ = make_curator(tmp_path)
|
||||||
|
turns = [{"text": "some text"}]
|
||||||
|
result = curator._format_raw_turns(turns)
|
||||||
|
assert "unknown" in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppendRuleToFile:
|
||||||
|
"""Tests for Curator._append_rule_to_file (filesystem I/O mocked via tmp_path)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_appends_to_existing_file(self, tmp_path):
|
||||||
|
"""Rule is appended to existing file."""
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
prompts_dir.mkdir()
|
||||||
|
target = prompts_dir / "systemprompt.md"
|
||||||
|
target.write_text("# Existing content\n")
|
||||||
|
|
||||||
|
(prompts_dir / "curator_prompt.md").write_text("prompt {CURRENT_DATE}")
|
||||||
|
|
||||||
|
from app.curator import Curator
|
||||||
|
|
||||||
|
mock_qdrant = MagicMock()
|
||||||
|
with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}):
|
||||||
|
curator = Curator(mock_qdrant, model="m", ollama_host="http://x")
|
||||||
|
await curator._append_rule_to_file("systemprompt.md", "Always be concise.")
|
||||||
|
|
||||||
|
content = target.read_text()
|
||||||
|
assert "Always be concise." in content
|
||||||
|
assert "# Existing content" in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_creates_file_if_missing(self, tmp_path):
|
||||||
|
"""Rule is written to a new file if none existed."""
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
prompts_dir.mkdir()
|
||||||
|
(prompts_dir / "curator_prompt.md").write_text("prompt {CURRENT_DATE}")
|
||||||
|
|
||||||
|
from app.curator import Curator
|
||||||
|
|
||||||
|
mock_qdrant = MagicMock()
|
||||||
|
with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}):
|
||||||
|
curator = Curator(mock_qdrant, model="m", ollama_host="http://x")
|
||||||
|
await curator._append_rule_to_file("newfile.md", "New rule here.")
|
||||||
|
|
||||||
|
target = prompts_dir / "newfile.md"
|
||||||
|
assert target.exists()
|
||||||
|
assert "New rule here." in target.read_text()
|
||||||
351
tests/test_integration.py
Normal file
351
tests/test_integration.py
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
"""Integration tests — FastAPI app via httpx.AsyncClient test transport.
|
||||||
|
|
||||||
|
All external I/O (Ollama, Qdrant) is mocked. No live services required.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_mock_qdrant():
|
||||||
|
"""Return a fully-mocked QdrantService."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock._ensure_collection = AsyncMock()
|
||||||
|
mock.semantic_search = AsyncMock(return_value=[])
|
||||||
|
mock.get_recent_turns = AsyncMock(return_value=[])
|
||||||
|
mock.store_qa_turn = AsyncMock(return_value="fake-uuid")
|
||||||
|
mock.close = AsyncMock()
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
def _ollama_tags_response():
|
||||||
|
return {"models": [{"name": "llama3", "size": 0}]}
|
||||||
|
|
||||||
|
|
||||||
|
def _ollama_chat_response(content: str = "Hello from Ollama"):
|
||||||
|
return {
|
||||||
|
"message": {"role": "assistant", "content": content},
|
||||||
|
"done": True,
|
||||||
|
"model": "llama3",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_qdrant():
|
||||||
|
return _make_mock_qdrant()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def app_with_mocks(mock_qdrant, tmp_path):
|
||||||
|
"""Return the FastAPI app with lifespan mocked (no real Qdrant/scheduler)."""
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
# Minimal curator prompt
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
prompts_dir.mkdir()
|
||||||
|
(prompts_dir / "curator_prompt.md").write_text("Curate. Date: {CURRENT_DATE}")
|
||||||
|
(prompts_dir / "systemprompt.md").write_text("You are Vera.")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def fake_lifespan(app):
|
||||||
|
yield
|
||||||
|
|
||||||
|
import app.main as main_module
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}), \
|
||||||
|
patch("app.main.get_qdrant_service", return_value=mock_qdrant), \
|
||||||
|
patch("app.singleton.get_qdrant_service", return_value=mock_qdrant), \
|
||||||
|
patch("app.main.Curator") as MockCurator, \
|
||||||
|
patch("app.main.scheduler") as mock_scheduler:
|
||||||
|
|
||||||
|
mock_scheduler.add_job = MagicMock()
|
||||||
|
mock_scheduler.start = MagicMock()
|
||||||
|
mock_scheduler.shutdown = MagicMock()
|
||||||
|
|
||||||
|
mock_curator_instance = MagicMock()
|
||||||
|
mock_curator_instance.run = AsyncMock()
|
||||||
|
MockCurator.return_value = mock_curator_instance
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
# Import fresh — use the real routes but swap lifespan
|
||||||
|
from app.main import app as vera_app
|
||||||
|
vera_app.router.lifespan_context = fake_lifespan
|
||||||
|
|
||||||
|
yield vera_app, mock_qdrant
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Health check
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestHealthCheck:
|
||||||
|
def test_health_ollama_reachable(self, app_with_mocks):
|
||||||
|
"""GET / returns status ok and ollama=reachable when Ollama is up."""
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
vera_app, mock_qdrant = app_with_mocks
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_instance.get = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client_instance):
|
||||||
|
with TestClient(vera_app, raise_server_exceptions=True) as client:
|
||||||
|
resp = client.get("/")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["status"] == "ok"
|
||||||
|
assert body["ollama"] == "reachable"
|
||||||
|
|
||||||
|
def test_health_ollama_unreachable(self, app_with_mocks):
|
||||||
|
"""GET / returns ollama=unreachable when Ollama is down."""
|
||||||
|
import httpx
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
vera_app, _ = app_with_mocks
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_instance.get = AsyncMock(side_effect=httpx.ConnectError("refused"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client_instance):
|
||||||
|
with TestClient(vera_app, raise_server_exceptions=True) as client:
|
||||||
|
resp = client.get("/")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["ollama"] == "unreachable"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /api/tags
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestApiTags:
|
||||||
|
def test_returns_model_list(self, app_with_mocks):
|
||||||
|
"""GET /api/tags proxies Ollama tags."""
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
vera_app, _ = app_with_mocks
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = _ollama_tags_response()
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_instance.get = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client_instance):
|
||||||
|
with TestClient(vera_app) as client:
|
||||||
|
resp = client.get("/api/tags")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "models" in data
|
||||||
|
assert any(m["name"] == "llama3" for m in data["models"])
|
||||||
|
|
||||||
|
def test_cloud_models_injected(self, tmp_path):
|
||||||
|
"""Cloud models appear in /api/tags when cloud is enabled."""
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
prompts_dir.mkdir()
|
||||||
|
(prompts_dir / "curator_prompt.md").write_text("Curate.")
|
||||||
|
(prompts_dir / "systemprompt.md").write_text("")
|
||||||
|
|
||||||
|
mock_qdrant = _make_mock_qdrant()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def fake_lifespan(app):
|
||||||
|
yield
|
||||||
|
|
||||||
|
from app.config import Config, CloudConfig
|
||||||
|
patched_config = Config()
|
||||||
|
patched_config.cloud = CloudConfig(
|
||||||
|
enabled=True,
|
||||||
|
models={"gpt-oss:120b": "openai/gpt-4o"},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = {"models": []}
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_instance.get = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
|
import app.main as main_module
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}), \
|
||||||
|
patch("app.main.config", patched_config), \
|
||||||
|
patch("app.main.get_qdrant_service", return_value=mock_qdrant), \
|
||||||
|
patch("app.main.scheduler") as mock_scheduler, \
|
||||||
|
patch("app.main.Curator") as MockCurator:
|
||||||
|
|
||||||
|
mock_scheduler.add_job = MagicMock()
|
||||||
|
mock_scheduler.start = MagicMock()
|
||||||
|
mock_scheduler.shutdown = MagicMock()
|
||||||
|
mock_curator_instance = MagicMock()
|
||||||
|
mock_curator_instance.run = AsyncMock()
|
||||||
|
MockCurator.return_value = mock_curator_instance
|
||||||
|
|
||||||
|
from app.main import app as vera_app
|
||||||
|
vera_app.router.lifespan_context = fake_lifespan
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client_instance):
|
||||||
|
with TestClient(vera_app) as client:
|
||||||
|
resp = client.get("/api/tags")
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
names = [m["name"] for m in data["models"]]
|
||||||
|
assert "gpt-oss:120b" in names
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/chat (non-streaming)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestApiChatNonStreaming:
|
||||||
|
def test_non_streaming_round_trip(self, app_with_mocks):
|
||||||
|
"""POST /api/chat with stream=False returns Ollama response."""
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
import app.utils as utils_module
|
||||||
|
import app.proxy_handler as ph_module
|
||||||
|
|
||||||
|
vera_app, mock_qdrant = app_with_mocks
|
||||||
|
|
||||||
|
ollama_data = _ollama_chat_response("The answer is 42.")
|
||||||
|
|
||||||
|
mock_post_resp = MagicMock()
|
||||||
|
mock_post_resp.json.return_value = ollama_data
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_instance.post = AsyncMock(return_value=mock_post_resp)
|
||||||
|
|
||||||
|
with patch.object(utils_module, "load_system_prompt", return_value=""), \
|
||||||
|
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant), \
|
||||||
|
patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
|
||||||
|
patch("httpx.AsyncClient", return_value=mock_client_instance):
|
||||||
|
|
||||||
|
with TestClient(vera_app) as client:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/chat",
|
||||||
|
json={
|
||||||
|
"model": "llama3",
|
||||||
|
"messages": [{"role": "user", "content": "What is the answer?"}],
|
||||||
|
"stream": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["message"]["content"] == "The answer is 42."
|
||||||
|
|
||||||
|
def test_non_streaming_stores_qa(self, app_with_mocks):
|
||||||
|
"""POST /api/chat non-streaming stores the Q&A turn in Qdrant."""
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
import app.utils as utils_module
|
||||||
|
|
||||||
|
vera_app, mock_qdrant = app_with_mocks
|
||||||
|
|
||||||
|
ollama_data = _ollama_chat_response("42.")
|
||||||
|
|
||||||
|
mock_post_resp = MagicMock()
|
||||||
|
mock_post_resp.json.return_value = ollama_data
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_instance.post = AsyncMock(return_value=mock_post_resp)
|
||||||
|
|
||||||
|
with patch.object(utils_module, "load_system_prompt", return_value=""), \
|
||||||
|
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant), \
|
||||||
|
patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
|
||||||
|
patch("httpx.AsyncClient", return_value=mock_client_instance):
|
||||||
|
|
||||||
|
with TestClient(vera_app) as client:
|
||||||
|
client.post(
|
||||||
|
"/api/chat",
|
||||||
|
json={
|
||||||
|
"model": "llama3",
|
||||||
|
"messages": [{"role": "user", "content": "What is 6*7?"}],
|
||||||
|
"stream": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_qdrant.store_qa_turn.assert_called_once()
|
||||||
|
args = mock_qdrant.store_qa_turn.call_args[0]
|
||||||
|
assert "6*7" in args[0]
|
||||||
|
assert "42." in args[1]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/chat (streaming)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestApiChatStreaming:
|
||||||
|
def test_streaming_response_passthrough(self, app_with_mocks):
|
||||||
|
"""POST /api/chat with stream=True streams Ollama chunks."""
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
import app.utils as utils_module
|
||||||
|
import app.proxy_handler as ph_module
|
||||||
|
|
||||||
|
vera_app, mock_qdrant = app_with_mocks
|
||||||
|
|
||||||
|
chunk1 = json.dumps({"message": {"content": "Hello"}, "done": False}).encode()
|
||||||
|
chunk2 = json.dumps({"message": {"content": " world"}, "done": True}).encode()
|
||||||
|
|
||||||
|
async def fake_aiter_bytes():
|
||||||
|
yield chunk1
|
||||||
|
yield chunk2
|
||||||
|
|
||||||
|
mock_stream_resp = MagicMock()
|
||||||
|
mock_stream_resp.aiter_bytes = fake_aiter_bytes
|
||||||
|
mock_stream_resp.status_code = 200
|
||||||
|
mock_stream_resp.headers = {"content-type": "application/x-ndjson"}
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_instance.post = AsyncMock(return_value=mock_stream_resp)
|
||||||
|
|
||||||
|
with patch.object(utils_module, "load_system_prompt", return_value=""), \
|
||||||
|
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant), \
|
||||||
|
patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
|
||||||
|
patch("httpx.AsyncClient", return_value=mock_client_instance):
|
||||||
|
|
||||||
|
with TestClient(vera_app) as client:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/chat",
|
||||||
|
json={
|
||||||
|
"model": "llama3",
|
||||||
|
"messages": [{"role": "user", "content": "Say hello"}],
|
||||||
|
"stream": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
# Response body should contain both chunks concatenated
|
||||||
|
body_text = resp.text
|
||||||
|
assert "Hello" in body_text or len(body_text) > 0
|
||||||
221
tests/test_proxy_handler.py
Normal file
221
tests/test_proxy_handler.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
"""Tests for proxy_handler — no live Ollama or Qdrant required."""
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanMessageContent:
    """Tests for clean_message_content."""

    def test_passthrough_plain_message(self):
        """Plain text without wrapper is returned unchanged."""
        from app.proxy_handler import clean_message_content

        question = "What is the capital of France?"
        cleaned = clean_message_content(question)
        assert cleaned == question

    def test_strips_memory_context_wrapper(self):
        """[Memory context] wrapper is stripped, actual user_msg returned."""
        from app.proxy_handler import clean_message_content

        wrapped = (
            "[Memory context]\n"
            "some context here\n"
            "- user_msg: What is the capital of France?\n\n"
        )
        assert clean_message_content(wrapped) == "What is the capital of France?"

    def test_strips_timestamp_prefix(self):
        """ISO timestamp prefix like [2024-01-01T00:00:00] is removed."""
        from app.proxy_handler import clean_message_content

        stamped = "[2024-01-01T12:34:56] Tell me a joke"
        assert clean_message_content(stamped) == "Tell me a joke"

    def test_empty_string_returned_as_is(self):
        """Empty string input returns empty string."""
        from app.proxy_handler import clean_message_content

        assert clean_message_content("") == ""

    def test_none_input_returned_as_is(self):
        """None/falsy input is returned unchanged."""
        from app.proxy_handler import clean_message_content

        assert clean_message_content(None) is None

    def test_list_content_not_processed(self):
        """Non-string content (list) is returned as-is."""
        from app.proxy_handler import clean_message_content

        # content can be a list of parts in some Ollama payloads; the
        # function guards with `if not content`, and a non-empty list is
        # truthy but the regex won't match -> passthrough.
        parts = [{"type": "text", "text": "hello"}]
        assert clean_message_content(parts) == parts
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleChatNonStreaming:
    """Tests for handle_chat_non_streaming — fully mocked external I/O.

    The four tests previously duplicated ~20 lines of identical mock
    wiring each; that setup is factored into the helpers below so each
    test states only its scenario and its assertions.
    """

    @staticmethod
    def _make_ollama_client(answer):
        """Return an httpx.AsyncClient mock whose POST yields an Ollama reply.

        The mock supports `async with` and returns a response whose
        .json() carries *answer* as the assistant message content.
        """
        mock_httpx_resp = MagicMock()
        mock_httpx_resp.json.return_value = {
            "message": {"role": "assistant", "content": answer},
            "done": True,
        }

        mock_client = AsyncMock()
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)
        mock_client.post = AsyncMock(return_value=mock_httpx_resp)
        return mock_client

    @staticmethod
    def _make_qdrant():
        """Return a Qdrant service mock with an awaitable store_qa_turn."""
        mock_qdrant = MagicMock()
        mock_qdrant.store_qa_turn = AsyncMock(return_value="fake-uuid")
        return mock_qdrant

    @staticmethod
    async def _run(user_content, augmented, mock_qdrant, mock_client):
        """Invoke handle_chat_non_streaming with all external I/O patched out.

        Patches build_augmented_messages (returns *augmented*), the qdrant
        service accessor, and httpx.AsyncClient, then posts a non-streaming
        chat body containing *user_content* and returns the handler result.
        """
        from app.proxy_handler import handle_chat_non_streaming

        with patch("app.proxy_handler.build_augmented_messages",
                   AsyncMock(return_value=augmented)), \
                patch("app.proxy_handler.get_qdrant_service",
                      return_value=mock_qdrant), \
                patch("httpx.AsyncClient", return_value=mock_client):
            body = {
                "model": "llama3",
                "messages": [{"role": "user", "content": user_content}],
                "stream": False,
            }
            return await handle_chat_non_streaming(body)

    @pytest.mark.asyncio
    async def test_returns_json_response(self):
        """Should return a JSONResponse with Ollama result merged with model field."""
        mock_client = self._make_ollama_client("Paris.")
        mock_qdrant = self._make_qdrant()
        augmented = [{"role": "user", "content": "What is the capital of France?"}]

        response = await self._run(
            "What is the capital of France?", augmented, mock_qdrant, mock_client
        )

        # FastAPI JSONResponse
        from fastapi.responses import JSONResponse
        assert isinstance(response, JSONResponse)
        response_body = json.loads(response.body)
        assert response_body["message"]["content"] == "Paris."
        assert response_body["model"] == "llama3"

    @pytest.mark.asyncio
    async def test_stores_qa_turn_when_answer_present(self):
        """store_qa_turn should be called with user question and assistant answer."""
        mock_client = self._make_ollama_client("Berlin.")
        mock_qdrant = self._make_qdrant()
        augmented = [{"role": "user", "content": "Capital of Germany?"}]

        await self._run("Capital of Germany?", augmented, mock_qdrant, mock_client)

        mock_qdrant.store_qa_turn.assert_called_once()
        call_args = mock_qdrant.store_qa_turn.call_args
        assert "Capital of Germany?" in call_args[0][0]
        assert "Berlin." in call_args[0][1]

    @pytest.mark.asyncio
    async def test_no_store_when_empty_answer(self):
        """store_qa_turn should NOT be called when the assistant answer is empty."""
        mock_client = self._make_ollama_client("")
        mock_qdrant = self._make_qdrant()
        augmented = [{"role": "user", "content": "Hello?"}]

        await self._run("Hello?", augmented, mock_qdrant, mock_client)

        mock_qdrant.store_qa_turn.assert_not_called()

    @pytest.mark.asyncio
    async def test_cleans_memory_context_from_user_message(self):
        """User message with [Memory context] wrapper should be cleaned before storing."""
        mock_client = self._make_ollama_client("42.")
        mock_qdrant = self._make_qdrant()

        raw_content = (
            "[Memory context]\nsome ctx\n- user_msg: What is the answer?\n\n"
        )
        augmented = [{"role": "user", "content": "What is the answer?"}]

        await self._run(raw_content, augmented, mock_qdrant, mock_client)

        call_args = mock_qdrant.store_qa_turn.call_args
        stored_question = call_args[0][0]
        # The wrapper should be stripped
        assert "Memory context" not in stored_question
        assert "What is the answer?" in stored_question
|
||||||
@@ -82,4 +82,238 @@ Assistant: Yes, very popular."""
|
|||||||
result = parse_curated_turn(text)
|
result = parse_curated_turn(text)
|
||||||
assert "Line 1" in result[0]["content"]
|
assert "Line 1" in result[0]["content"]
|
||||||
assert "Line 2" in result[0]["content"]
|
assert "Line 2" in result[0]["content"]
|
||||||
assert "Line 3" in result[0]["content"]
|
assert "Line 3" in result[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFilterMemoriesByTime:
    """Tests for filter_memories_by_time function."""

    @staticmethod
    def _hours_ago(n):
        """Naive-UTC ISO timestamp *n* hours in the past."""
        from datetime import datetime, timedelta

        return (datetime.utcnow() - timedelta(hours=n)).isoformat()

    def test_includes_recent_memory(self):
        """Memory with timestamp in the last 24h should be included."""
        from app.utils import filter_memories_by_time

        kept = filter_memories_by_time(
            [{"timestamp": self._hours_ago(1), "text": "recent"}], hours=24
        )
        assert len(kept) == 1

    def test_excludes_old_memory(self):
        """Memory older than cutoff should be excluded."""
        from app.utils import filter_memories_by_time

        kept = filter_memories_by_time(
            [{"timestamp": self._hours_ago(48), "text": "old"}], hours=24
        )
        assert len(kept) == 0

    def test_includes_memory_without_timestamp(self):
        """Memory with no timestamp should always be included."""
        from app.utils import filter_memories_by_time

        kept = filter_memories_by_time([{"text": "no ts"}], hours=24)
        assert len(kept) == 1

    def test_includes_memory_with_bad_timestamp(self):
        """Memory with unparseable timestamp should be included (safe default)."""
        from app.utils import filter_memories_by_time

        kept = filter_memories_by_time(
            [{"timestamp": "not-a-date", "text": "bad ts"}], hours=24
        )
        assert len(kept) == 1

    def test_empty_list(self):
        """Empty input returns empty list."""
        from app.utils import filter_memories_by_time

        assert filter_memories_by_time([], hours=24) == []

    def test_z_suffix_timestamp(self):
        """ISO timestamp with Z suffix should be handled correctly."""
        from app.utils import filter_memories_by_time

        kept = filter_memories_by_time(
            [{"timestamp": self._hours_ago(1) + "Z", "text": "recent with Z"}],
            hours=24,
        )
        assert len(kept) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeMemories:
    """Tests for merge_memories function."""

    def test_empty_list(self):
        """Empty list returns empty text and ids."""
        from app.utils import merge_memories

        assert merge_memories([]) == {"text": "", "ids": []}

    def test_single_memory_with_text(self):
        """Single memory with text field is merged."""
        from app.utils import merge_memories

        merged = merge_memories([{"id": "abc", "text": "hello world", "role": ""}])
        assert "hello world" in merged["text"]
        assert "abc" in merged["ids"]

    def test_memory_with_content_field(self):
        """Memory using content field (no text) is merged."""
        from app.utils import merge_memories

        merged = merge_memories([{"id": "xyz", "content": "from content field"}])
        assert "from content field" in merged["text"]

    def test_role_included_in_output(self):
        """Role prefix should appear in merged text when role is set."""
        from app.utils import merge_memories

        merged = merge_memories([{"id": "1", "text": "question", "role": "user"}])
        assert "[user]:" in merged["text"]

    def test_multiple_memories_joined(self):
        """Multiple memories are joined with double newline."""
        from app.utils import merge_memories

        merged = merge_memories(
            [{"id": "1", "text": "first"}, {"id": "2", "text": "second"}]
        )
        assert "first" in merged["text"]
        assert "second" in merged["text"]
        assert len(merged["ids"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestCalculateTokenBudget:
    """Tests for calculate_token_budget function."""

    def test_default_ratios_sum(self):
        """Default ratios should sum to 1.0 (system+semantic+context)."""
        from app.utils import calculate_token_budget

        budget = calculate_token_budget(1000)
        assert sum(budget[key] for key in ("system", "semantic", "context")) == 1000

    def test_custom_ratios(self):
        """Custom ratios should produce correct proportional budgets."""
        from app.utils import calculate_token_budget

        budget = calculate_token_budget(
            100, system_ratio=0.1, semantic_ratio=0.6, context_ratio=0.3
        )
        assert (budget["system"], budget["semantic"], budget["context"]) == (10, 60, 30)

    def test_zero_budget(self):
        """Zero total budget yields all zeros."""
        from app.utils import calculate_token_budget

        budget = calculate_token_budget(0)
        assert budget["system"] == 0
        assert budget["semantic"] == 0
        assert budget["context"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildAugmentedMessages:
    """Tests for build_augmented_messages function (mocked I/O).

    Fix: the original tests drove the coroutine with
    asyncio.get_event_loop().run_until_complete(...), which is deprecated
    when no event loop is running (DeprecationWarning since Python 3.10)
    and fails on newer CPython. All tests now go through _build(), which
    uses asyncio.run() and also centralizes the patch boilerplate.
    """

    def _make_qdrant_mock(self):
        """Return an AsyncMock QdrantService with empty search results."""
        from unittest.mock import AsyncMock, MagicMock

        mock_qdrant = MagicMock()
        mock_qdrant.semantic_search = AsyncMock(return_value=[])
        mock_qdrant.get_recent_turns = AsyncMock(return_value=[])
        return mock_qdrant

    @staticmethod
    def _build(messages, qdrant, system_prompt=""):
        """Run build_augmented_messages with prompt loading and qdrant patched.

        asyncio.run() creates and closes a fresh event loop per call, which
        is the supported way to drive a coroutine from sync test code.
        """
        import asyncio
        from unittest.mock import patch
        import app.utils as utils_module

        with patch.object(utils_module, "load_system_prompt",
                          return_value=system_prompt), \
                patch.object(utils_module, "get_qdrant_service",
                             return_value=qdrant):
            return asyncio.run(utils_module.build_augmented_messages(messages))

    def test_system_layer_prepended(self, monkeypatch, tmp_path):
        """System prompt from file should be prepended to messages."""
        # Write a temp system prompt (the loader itself is patched in
        # _build, so this only documents the expected on-disk layout).
        prompt_file = tmp_path / "systemprompt.md"
        prompt_file.write_text("You are Vera.")

        result = self._build(
            [{"role": "user", "content": "Hello"}],
            self._make_qdrant_mock(),
            system_prompt="You are Vera.",
        )

        system_msgs = [m for m in result if m["role"] == "system"]
        assert len(system_msgs) == 1
        assert "You are Vera." in system_msgs[0]["content"]

    def test_incoming_user_message_preserved(self, monkeypatch):
        """Incoming user message should appear in output."""
        result = self._build(
            [{"role": "user", "content": "What is 2+2?"}],
            self._make_qdrant_mock(),
        )

        user_msgs = [m for m in result if m.get("role") == "user"]
        assert any("2+2" in m["content"] for m in user_msgs)

    def test_no_system_message_when_no_prompt(self, monkeypatch):
        """No system message added when both incoming and file prompt are empty."""
        result = self._build(
            [{"role": "user", "content": "Hi"}],
            self._make_qdrant_mock(),
        )

        system_msgs = [m for m in result if m.get("role") == "system"]
        assert len(system_msgs) == 0

    def test_semantic_results_injected(self, monkeypatch):
        """Curated memories from semantic search should appear in output."""
        from unittest.mock import AsyncMock, MagicMock

        mock_qdrant = MagicMock()
        mock_qdrant.semantic_search = AsyncMock(return_value=[
            {"payload": {"text": "User: Old question?\nAssistant: Old answer."}}
        ])
        mock_qdrant.get_recent_turns = AsyncMock(return_value=[])

        result = self._build([{"role": "user", "content": "Tell me"}], mock_qdrant)

        contents = [m["content"] for m in result]
        assert any("Old question" in c or "Old answer" in c for c in contents)
|
||||||
Reference in New Issue
Block a user