3 Commits

Author SHA1 Message Date
Vera-AI
346f2c26fe feat: semantic search includes raw turns, deduplicate layers, fix recent turn ordering
- Layer 2 semantic search now queries both curated and raw types,
  closing the blind spot for turns past the 50-turn window pre-curation
- Layer 3 skips turns already returned by Layer 2 to avoid duplicate
  context and wasted token budget
- get_recent_turns uses Qdrant OrderBy for server-side timestamp sort
  with payload index; fallback to client-side sort if unavailable
- Bump version to 2.0.4
2026-04-01 17:43:47 -05:00
Claude Code
de7f3a78ab test: expand coverage to 92% with production-realistic fixtures
Add conftest.py with shared fixtures, new test_qdrant_service.py covering
all QdrantService methods, and expand curator/proxy/integration/utils tests
to cover debug_log, forward_to_ollama, curator.run(), load_system_prompt,
and health check edge cases.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-01 16:20:35 -05:00
Claude Code
6154e7974e fix: add debug log file locking, improve error logging, validate cloud API key
- Add portalocker file locking to debug_log() to prevent interleaved entries
- Add exc_info=True to curator _call_llm error logging for stack traces
- Add debug log message on JSON parse fallback in _parse_json_response
- Warn when cloud is enabled but API key env var is not set

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-01 16:15:56 -05:00
12 changed files with 923 additions and 33 deletions

View File

@@ -113,6 +113,13 @@ class Config:
models=cloud_data.get("models", {})
)
if config.cloud.enabled and not config.cloud.api_key:
import logging
logging.getLogger(__name__).warning(
"Cloud is enabled but API key env var '%s' is not set",
config.cloud.api_key_env
)
return config
config = Config.load()

View File

@@ -212,7 +212,7 @@ Remember: Respond with ONLY valid JSON. No markdown, no explanations, just the J
result = response.json()
return result.get("response", "")
except Exception as e:
logger.error(f"Failed to call LLM: {e}")
logger.error(f"LLM call failed: {e}", exc_info=True)
return ""
def _parse_json_response(self, response: str) -> Optional[Dict]:
@@ -223,6 +223,7 @@ Remember: Respond with ONLY valid JSON. No markdown, no explanations, just the J
try:
return json.loads(response)
except json.JSONDecodeError:
logger.debug("Direct JSON parse failed, trying brace extraction")
pass
try:

View File

@@ -68,7 +68,7 @@ async def lifespan(app: FastAPI):
await qdrant_service.close()
app = FastAPI(title="Vera-AI", version="2.0.0", lifespan=lifespan)
app = FastAPI(title="Vera-AI", version="2.0.4", lifespan=lifespan)
@app.get("/")

View File

@@ -6,6 +6,7 @@ import json
import re
import logging
import os
import portalocker
from pathlib import Path
from .config import config
from .singleton import get_qdrant_service
@@ -66,7 +67,9 @@ def debug_log(category: str, message: str, data: dict = None):
entry["data"] = data
with open(log_path, "a") as f:
portalocker.lock(f, portalocker.LOCK_EX)
f.write(json.dumps(entry) + "\n")
portalocker.unlock(f)
async def handle_chat_non_streaming(body: dict):

View File

@@ -1,6 +1,6 @@
"""Qdrant service for memory storage - ASYNC VERSION."""
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue, PayloadSchemaType
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone
import uuid
@@ -34,6 +34,15 @@ class QdrantService:
vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
)
logger.info(f"Created collection {self.collection} with vector size {self.vector_size}")
# Ensure payload index on timestamp for ordered scroll
try:
await self.client.create_payload_index(
collection_name=self.collection,
field_name="timestamp",
field_schema=PayloadSchemaType.KEYWORD
)
except Exception:
pass # Index may already exist
self._collection_ensured = True
async def get_embedding(self, text: str) -> List[float]:
@@ -105,20 +114,28 @@ class QdrantService:
)
return point_id
async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated") -> List[Dict]:
"""Semantic search for relevant turns, filtered by type."""
async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated", entry_types: Optional[List[str]] = None) -> List[Dict]:
"""Semantic search for relevant turns, filtered by type(s)."""
await self._ensure_collection()
embedding = await self.get_embedding(query)
if entry_types and len(entry_types) > 1:
type_filter = Filter(
should=[FieldCondition(key="type", match=MatchValue(value=t)) for t in entry_types]
)
else:
filter_type = entry_types[0] if entry_types else entry_type
type_filter = Filter(
must=[FieldCondition(key="type", match=MatchValue(value=filter_type))]
)
results = await self.client.query_points(
collection_name=self.collection,
query=embedding,
limit=limit,
score_threshold=score_threshold,
query_filter=Filter(
must=[FieldCondition(key="type", match=MatchValue(value=entry_type))]
)
query_filter=type_filter
)
return [{"id": str(r.id), "score": r.score, "payload": r.payload} for r in results.points]
@@ -127,20 +144,28 @@ class QdrantService:
"""Get recent turns from Qdrant (both raw and curated)."""
await self._ensure_collection()
try:
from qdrant_client.models import OrderBy
points, _ = await self.client.scroll(
collection_name=self.collection,
limit=limit * 2,
limit=limit,
with_payload=True,
order_by=OrderBy(key="timestamp", direction="desc")
)
except Exception:
# Fallback: fetch extra points and sort client-side
points, _ = await self.client.scroll(
collection_name=self.collection,
limit=limit * 5,
with_payload=True
)
# Sort by timestamp descending
sorted_points = sorted(
points = sorted(
points,
key=lambda p: p.payload.get("timestamp", ""),
reverse=True
)
)[:limit]
return [{"id": str(p.id), "payload": p.payload} for p in sorted_points[:limit]]
return [{"id": str(p.id), "payload": p.payload} for p in points]
async def delete_points(self, point_ids: List[str]) -> None:
"""Delete points by ID."""

View File

@@ -213,23 +213,25 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
messages.append({"role": "system", "content": system_content})
logger.info(f"Layer 1 (system): {count_tokens(system_content)} tokens")
# === LAYER 2: Semantic (curated memories) ===
# === LAYER 2: Semantic (curated + raw memories) ===
qdrant = get_qdrant_service()
semantic_results = await qdrant.semantic_search(
query=search_context if search_context else user_question,
limit=20,
score_threshold=config.semantic_score_threshold,
entry_type="curated"
entry_types=["curated", "raw"]
)
semantic_messages = []
semantic_tokens_used = 0
semantic_ids = set()
for result in semantic_results:
semantic_ids.add(result.get("id"))
payload = result.get("payload", {})
text = payload.get("text", "")
if text:
# Parse curated turn into proper user/assistant messages
# Parse curated/raw turn into proper user/assistant messages
parsed = parse_curated_turn(text)
for msg in parsed:
msg_tokens = count_tokens(msg.get("content", ""))
@@ -254,8 +256,10 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
context_messages = []
context_tokens_used = 0
# Process oldest first for chronological order
# Process oldest first for chronological order, skip duplicates from Layer 2
for turn in reversed(recent_turns):
if turn.get("id") in semantic_ids:
continue
payload = turn.get("payload", {})
text = payload.get("text", "")
entry_type = payload.get("type", "raw")

62
tests/conftest.py Normal file
View File

@@ -0,0 +1,62 @@
"""Shared test fixtures using production-realistic data."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.config import Config
@pytest.fixture
def production_config():
    """Config double mirroring the production deployment on deb8.

    Spec'd to Config so attribute typos in tests raise instead of
    silently returning a child mock.
    """
    cfg = MagicMock(spec=Config)
    cfg.configure_mock(
        ollama_host="http://10.0.0.10:11434",
        qdrant_host="http://10.0.0.22:6333",
        qdrant_collection="memories",
        embedding_model="snowflake-arctic-embed2",
        semantic_token_budget=25000,
        context_token_budget=22000,
        semantic_search_turns=2,
        semantic_score_threshold=0.6,
        run_time="02:00",
        curator_model="gpt-oss:120b",
        debug=False,
        vector_size=1024,
    )
    # Cloud sub-config: disabled, no model mappings.
    cfg.cloud = MagicMock()
    cfg.cloud.enabled = False
    cfg.cloud.models = {}
    cfg.cloud.get_cloud_model.return_value = None
    return cfg
@pytest.fixture
def sample_qdrant_raw_payload():
    """Raw (uncurated) turn payload as stored in production Qdrant."""
    qa_text = (
        "User: only change settings, not models\n"
        "Assistant: Changed semantic_token_budget from 25000 to 30000\n"
        "Timestamp: 2026-03-27T12:50:37.451593Z"
    )
    return {
        "type": "raw",
        "text": qa_text,
        "timestamp": "2026-03-27T12:50:37.451593Z",
        "role": "qa",
        # In production raw entries, "content" mirrors "text" verbatim.
        "content": qa_text,
    }
@pytest.fixture
def sample_ollama_models():
    """Model list matching production Ollama's /api/tags response."""
    def entry(tag, modified_at, size, family, parameter_size, quantization_level):
        # Ollama repeats the tag under both "name" and "model".
        return {
            "name": tag,
            "model": tag,
            "modified_at": modified_at,
            "size": size,
            "details": {
                "family": family,
                "parameter_size": parameter_size,
                "quantization_level": quantization_level,
            },
        }

    return {
        "models": [
            entry("snowflake-arctic-embed2:latest", "2026-02-16T16:43:44Z",
                  1160296718, "bert", "566.70M", "F16"),
            entry("gpt-oss:120b", "2026-03-11T12:45:48Z",
                  65369818941, "gptoss", "116.8B", "MXFP4"),
        ]
    }

View File

@@ -4,7 +4,7 @@ import json
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, AsyncMock, patch
def make_curator():
@@ -197,3 +197,294 @@ class TestAppendRuleToFile:
target = prompts_dir / "newfile.md"
assert target.exists()
assert "New rule here." in target.read_text()
class TestFormatExistingMemories:
    """Tests for Curator._format_existing_memories."""

    def test_empty_list_returns_no_memories_message(self):
        """An empty list yields the 'no memories' placeholder text."""
        curator, _ = make_curator()
        assert "No existing curated memories" in curator._format_existing_memories([])

    def test_single_memory_formatted(self):
        """The text of a lone memory appears in the formatted output."""
        curator, _ = make_curator()
        formatted = curator._format_existing_memories(
            [{"text": "User: hello\nAssistant: hi there"}]
        )
        assert "hello" in formatted
        assert "hi there" in formatted

    def test_limits_to_last_20(self):
        """Formatting keeps only the 20 most recent memories."""
        curator, _ = make_curator()
        batch = [{"text": f"memory {i}"} for i in range(30)]
        formatted = curator._format_existing_memories(batch)
        # Entries 10-29 (the last 20) survive the cut.
        assert "memory 29" in formatted
        assert "memory 10" in formatted
class TestCallLlm:
    """Tests for Curator._call_llm."""

    @staticmethod
    def _fake_http_client(post_mock):
        """Async-context-manager mock standing in for httpx.AsyncClient."""
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        client.post = post_mock
        return client

    @pytest.mark.asyncio
    async def test_call_llm_returns_response(self):
        """_call_llm returns the response text from Ollama."""
        curator, _ = make_curator()
        response = MagicMock()
        response.json.return_value = {"response": "some LLM output"}
        client = self._fake_http_client(AsyncMock(return_value=response))
        with patch("httpx.AsyncClient", return_value=client):
            assert await curator._call_llm("test prompt") == "some LLM output"
        # The configured curator model must be sent in the request body.
        posted = client.post.call_args
        assert "test-model" in posted[1]["json"]["model"]

    @pytest.mark.asyncio
    async def test_call_llm_returns_empty_on_error(self):
        """_call_llm degrades to an empty string when Ollama errors."""
        curator, _ = make_curator()
        client = self._fake_http_client(
            AsyncMock(side_effect=Exception("connection refused"))
        )
        with patch("httpx.AsyncClient", return_value=client):
            assert await curator._call_llm("test prompt") == ""
class TestCuratorRun:
    """Tests for Curator.run() method."""

    @pytest.mark.asyncio
    async def test_run_no_raw_memories_exits_early(self):
        """run() exits early when no raw memories found."""
        curator, mock_qdrant = make_curator()
        # Scroll yields nothing, so run() must return before any LLM call.
        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([], None))
        mock_qdrant.collection = "memories"
        await curator.run()
        # Completing without raising is the success criterion for this case.

    @pytest.mark.asyncio
    async def test_run_processes_raw_memories(self):
        """run() processes raw memories and stores curated results."""
        curator, mock_qdrant = make_curator()
        mock_point = MagicMock()
        mock_point.id = "point-1"
        mock_point.payload = {
            "type": "raw",
            "text": "User: hello\nAssistant: hi",
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }
        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        mock_qdrant.store_turn = AsyncMock(return_value="new-id")
        mock_qdrant.delete_points = AsyncMock()
        llm_response = json.dumps({
            "new_curated_turns": [{"content": "User: hello\nAssistant: hi"}],
            "permanent_rules": [],
            "deletions": [],
            "summary": "Curated one turn"
        })
        with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)):
            await curator.run()
        # One curated turn stored; the processed raw point is removed.
        mock_qdrant.store_turn.assert_called_once()
        mock_qdrant.delete_points.assert_called()

    @pytest.mark.asyncio
    async def test_run_monthly_mode_on_day_01(self):
        """run() uses monthly mode on day 01, processing all raw memories."""
        curator, mock_qdrant = make_curator()
        # Timestamp well outside the 24h daily window: only monthly mode
        # should pick this point up.
        old_ts = (datetime.now(timezone.utc) - timedelta(hours=72)).isoformat().replace("+00:00", "Z")
        mock_point = MagicMock()
        mock_point.id = "old-point"
        mock_point.payload = {
            "type": "raw",
            "text": "User: old question\nAssistant: old answer",
            "timestamp": old_ts,
        }
        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        mock_qdrant.store_turn = AsyncMock(return_value="new-id")
        mock_qdrant.delete_points = AsyncMock()
        llm_response = json.dumps({
            "new_curated_turns": [],
            "permanent_rules": [],
            "deletions": [],
            "summary": "Nothing to curate"
        })
        # Freeze "now" to the 1st of the month to trigger monthly mode,
        # while keeping fromisoformat and the constructor functional.
        mock_now = datetime(2026, 4, 1, 2, 0, 0, tzinfo=timezone.utc)
        with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)), \
             patch("app.curator.datetime") as mock_dt:
            mock_dt.now.return_value = mock_now
            mock_dt.fromisoformat = datetime.fromisoformat
            mock_dt.side_effect = lambda *args, **kw: datetime(*args, **kw)
            await curator.run()
        # In monthly mode even old memories are processed, so the raw
        # point must be deleted after curation.
        mock_qdrant.delete_points.assert_called()

    @pytest.mark.asyncio
    async def test_run_handles_permanent_rules(self):
        """run() appends permanent rules to prompt files."""
        curator, mock_qdrant = make_curator()
        mock_point = MagicMock()
        mock_point.id = "point-1"
        mock_point.payload = {
            "type": "raw",
            "text": "User: remember this\nAssistant: ok",
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }
        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        mock_qdrant.store_turn = AsyncMock(return_value="new-id")
        mock_qdrant.delete_points = AsyncMock()
        llm_response = json.dumps({
            "new_curated_turns": [],
            "permanent_rules": [{"rule": "Always be concise.", "target_file": "systemprompt.md"}],
            "deletions": [],
            "summary": "Added a rule"
        })
        with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)), \
             patch.object(curator, "_append_rule_to_file", AsyncMock()) as mock_append:
            await curator.run()
        mock_append.assert_called_once_with("systemprompt.md", "Always be concise.")

    @pytest.mark.asyncio
    async def test_run_handles_deletions(self):
        """run() deletes specified point IDs when they exist in the database."""
        curator, mock_qdrant = make_curator()
        mock_point = MagicMock()
        mock_point.id = "point-1"
        mock_point.payload = {
            "type": "raw",
            "text": "User: delete me\nAssistant: ok",
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }
        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        mock_qdrant.store_turn = AsyncMock(return_value="new-id")
        mock_qdrant.delete_points = AsyncMock()
        llm_response = json.dumps({
            "new_curated_turns": [],
            "permanent_rules": [],
            "deletions": ["point-1"],
            "summary": "Deleted one"
        })
        with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)):
            await curator.run()
        # delete_points covers both the explicit deletions and the processed
        # raw points; at least one call must have happened.
        assert mock_qdrant.delete_points.call_count >= 1

    @pytest.mark.asyncio
    async def test_run_handles_llm_parse_failure(self):
        """run() handles LLM returning unparseable response gracefully."""
        curator, mock_qdrant = make_curator()
        mock_point = MagicMock()
        mock_point.id = "point-1"
        mock_point.payload = {
            "type": "raw",
            "text": "User: test\nAssistant: ok",
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }
        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        # BUGFIX: install the store_turn spy BEFORE run(). Previously a
        # fresh AsyncMock was created after run() returned, which made the
        # assert_not_called() below pass vacuously.
        mock_qdrant.store_turn = AsyncMock()
        with patch.object(curator, "_call_llm", AsyncMock(return_value="not json at all!!!")):
            # Should not raise - just return early
            await curator.run()
        # Parsing failed, so nothing may have been stored.
        mock_qdrant.store_turn.assert_not_called()
class TestLoadCuratorPrompt:
    """Tests for load_curator_prompt function."""

    def test_loads_from_prompts_dir(self, tmp_path):
        """The prompt is read from PROMPTS_DIR when the file exists there."""
        import app.curator as curator_module
        from app.curator import load_curator_prompt
        prompt_dir = tmp_path / "prompts"
        prompt_dir.mkdir()
        (prompt_dir / "curator_prompt.md").write_text("Test curator prompt")
        with patch.object(curator_module, "PROMPTS_DIR", prompt_dir):
            assert load_curator_prompt() == "Test curator prompt"

    def test_falls_back_to_static_dir(self, tmp_path):
        """STATIC_DIR serves the prompt when PROMPTS_DIR is missing."""
        import app.curator as curator_module
        from app.curator import load_curator_prompt
        missing_prompts = tmp_path / "prompts"  # deliberately never created
        static_dir = tmp_path / "static"
        static_dir.mkdir()
        (static_dir / "curator_prompt.md").write_text("Static prompt")
        with patch.object(curator_module, "PROMPTS_DIR", missing_prompts), \
             patch.object(curator_module, "STATIC_DIR", static_dir):
            assert load_curator_prompt() == "Static prompt"

    def test_raises_when_not_found(self, tmp_path):
        """FileNotFoundError is raised when no prompt file exists anywhere."""
        import app.curator as curator_module
        from app.curator import load_curator_prompt
        with patch.object(curator_module, "PROMPTS_DIR", tmp_path / "nope"), \
             patch.object(curator_module, "STATIC_DIR", tmp_path / "also_nope"):
            with pytest.raises(FileNotFoundError):
                load_curator_prompt()

View File

@@ -349,3 +349,83 @@ class TestApiChatStreaming:
# Response body should contain both chunks concatenated
body_text = resp.text
assert "Hello" in body_text or len(body_text) > 0
# ---------------------------------------------------------------------------
# Health check edge cases
# ---------------------------------------------------------------------------
class TestHealthCheckEdgeCases:
    """Edge cases for the GET / health check."""

    def test_health_ollama_timeout(self, app_with_mocks):
        """GET / reports Ollama as unreachable when the probe times out."""
        import httpx
        from fastapi.testclient import TestClient
        vera_app, _ = app_with_mocks
        timed_out = AsyncMock()
        timed_out.__aenter__ = AsyncMock(return_value=timed_out)
        timed_out.__aexit__ = AsyncMock(return_value=False)
        timed_out.get = AsyncMock(side_effect=httpx.TimeoutException("timeout"))
        with patch("httpx.AsyncClient", return_value=timed_out):
            with TestClient(vera_app, raise_server_exceptions=True) as client:
                reply = client.get("/")
                # The endpoint still answers 200, flagging Ollama as down.
                assert reply.status_code == 200
                assert reply.json()["ollama"] == "unreachable"
# ---------------------------------------------------------------------------
# POST /curator/run
# ---------------------------------------------------------------------------
class TestTriggerCurator:
    """Tests for the POST /curator/run endpoint."""

    def test_trigger_curator_endpoint(self, app_with_mocks):
        """POST /curator/run triggers curation and reports completion."""
        from fastapi.testclient import TestClient
        import app.main as main_module
        vera_app, _ = app_with_mocks
        fake_curator = MagicMock()
        fake_curator.run = AsyncMock()
        with patch.object(main_module, "curator", fake_curator):
            with TestClient(vera_app) as client:
                reply = client.post("/curator/run")
                assert reply.status_code == 200
                assert reply.json()["status"] == "curation completed"
                fake_curator.run.assert_called_once()
# ---------------------------------------------------------------------------
# Proxy catch-all
# ---------------------------------------------------------------------------
class TestProxyAll:
    """Tests for the catch-all proxy route."""

    def test_non_chat_api_proxied(self, app_with_mocks):
        """Non-chat API paths are forwarded to Ollama and streamed back."""
        from fastapi.testclient import TestClient
        vera_app, _ = app_with_mocks

        async def stream_body():
            # Upstream response body delivered as a single chunk.
            yield b'{"status": "ok"}'

        upstream = MagicMock()
        upstream.status_code = 200
        upstream.headers = {"content-type": "application/json"}
        upstream.aiter_bytes = stream_body
        proxy_client = AsyncMock()
        proxy_client.__aenter__ = AsyncMock(return_value=proxy_client)
        proxy_client.__aexit__ = AsyncMock(return_value=False)
        proxy_client.request = AsyncMock(return_value=upstream)
        with patch("httpx.AsyncClient", return_value=proxy_client):
            with TestClient(vera_app) as client:
                reply = client.get("/api/show")
                assert reply.status_code == 200

View File

@@ -219,3 +219,94 @@ class TestHandleChatNonStreaming:
# The wrapper should be stripped
assert "Memory context" not in stored_question
assert "What is the answer?" in stored_question
class TestDebugLog:
    """Tests for debug_log function."""

    @staticmethod
    def _run_debug_log(tmp_path, debug_enabled, *log_args):
        """Call debug_log with config.debug patched; return log files found."""
        from unittest.mock import patch, MagicMock
        cfg = MagicMock()
        cfg.debug = debug_enabled
        with patch("app.proxy_handler.config", cfg), \
             patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path):
            from app.proxy_handler import debug_log
            debug_log(*log_args)
        return list(tmp_path.glob("debug_*.log"))

    def test_debug_log_writes_json_when_enabled(self, tmp_path):
        """A valid JSON line is appended to the file when debug=True."""
        import json
        files = self._run_debug_log(
            tmp_path, True, "test_cat", "test message", {"key": "value"}
        )
        assert len(files) == 1
        entry = json.loads(files[0].read_text().strip())
        assert entry["category"] == "test_cat"
        assert entry["message"] == "test message"
        assert entry["data"]["key"] == "value"

    def test_debug_log_skips_when_disabled(self, tmp_path):
        """Nothing is written when debug=False."""
        files = self._run_debug_log(tmp_path, False, "test_cat", "test message")
        assert len(files) == 0

    def test_debug_log_without_data(self, tmp_path):
        """The optional data parameter may be omitted entirely."""
        import json
        files = self._run_debug_log(tmp_path, True, "simple_cat", "no data here")
        assert len(files) == 1
        entry = json.loads(files[0].read_text().strip())
        assert "data" not in entry
        assert entry["category"] == "simple_cat"
class TestForwardToOllama:
    """Tests for forward_to_ollama function."""

    @pytest.mark.asyncio
    async def test_forwards_request_to_ollama(self):
        """forward_to_ollama relays method, body, and path to the Ollama host."""
        from app.proxy_handler import forward_to_ollama
        from unittest.mock import patch, AsyncMock, MagicMock
        incoming = AsyncMock()
        incoming.body = AsyncMock(return_value=b'{"model": "llama3"}')
        incoming.method = "POST"
        incoming.headers = {"content-type": "application/json", "content-length": "20"}
        upstream = MagicMock()
        upstream.status_code = 200
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        client.request = AsyncMock(return_value=upstream)
        with patch("httpx.AsyncClient", return_value=client):
            assert await forward_to_ollama(incoming, "/api/show") == upstream
        client.request.assert_called_once()
        sent = client.request.call_args
        assert sent[1]["method"] == "POST"
        assert "/api/show" in sent[1]["url"]

View File

@@ -0,0 +1,256 @@
"""Tests for QdrantService — all Qdrant and Ollama calls are mocked."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
def make_qdrant_service():
    """Build a QdrantService whose AsyncQdrantClient is a mock.

    Returns (service, mock_client). The client class is patched for the
    duration of construction, so no network connection is attempted.
    """
    with patch("app.qdrant_service.AsyncQdrantClient") as client_cls:
        fake_client = AsyncMock()
        client_cls.return_value = fake_client
        from app.qdrant_service import QdrantService
        service = QdrantService(
            host="http://localhost:6333",
            collection="test_memories",
            embedding_model="snowflake-arctic-embed2",
            vector_size=1024,
            ollama_host="http://localhost:11434",
        )
        return service, fake_client
class TestEnsureCollection:
    """Tests for _ensure_collection."""

    @pytest.mark.asyncio
    async def test_creates_collection_when_missing(self):
        """The collection is created when get_collection fails."""
        svc, client = make_qdrant_service()
        client.get_collection = AsyncMock(side_effect=Exception("not found"))
        client.create_collection = AsyncMock()
        await svc._ensure_collection()
        client.create_collection.assert_called_once()
        assert svc._collection_ensured is True

    @pytest.mark.asyncio
    async def test_skips_if_collection_exists(self):
        """No creation happens when the collection is already present."""
        svc, client = make_qdrant_service()
        client.get_collection = AsyncMock(return_value=MagicMock())
        client.create_collection = AsyncMock()
        await svc._ensure_collection()
        client.create_collection.assert_not_called()
        assert svc._collection_ensured is True

    @pytest.mark.asyncio
    async def test_skips_if_already_ensured(self):
        """The ensured flag short-circuits any client call."""
        svc, client = make_qdrant_service()
        svc._collection_ensured = True
        client.get_collection = AsyncMock()
        await svc._ensure_collection()
        client.get_collection.assert_not_called()
class TestGetEmbedding:
    """Tests for get_embedding."""

    @pytest.mark.asyncio
    async def test_returns_embedding_vector(self):
        """The embedding vector from Ollama's response is returned as-is."""
        svc, _ = make_qdrant_service()
        vector = [0.1] * 1024
        response = MagicMock()
        response.json.return_value = {"embedding": vector}
        http_client = AsyncMock()
        http_client.__aenter__ = AsyncMock(return_value=http_client)
        http_client.__aexit__ = AsyncMock(return_value=False)
        http_client.post = AsyncMock(return_value=response)
        with patch("httpx.AsyncClient", return_value=http_client):
            embedding = await svc.get_embedding("test text")
        assert embedding == vector
        assert len(embedding) == 1024
class TestStoreTurn:
    """Tests for store_turn."""

    @staticmethod
    def _prepare(svc, client):
        """Skip collection setup, stub upsert, and patch the embedding call."""
        svc._collection_ensured = True
        client.upsert = AsyncMock()
        return patch.object(svc, "get_embedding", AsyncMock(return_value=[0.1] * 1024))

    @staticmethod
    def _stored_point(client):
        """Extract the single point that was passed to client.upsert."""
        return client.upsert.call_args[1]["points"][0]

    @pytest.mark.asyncio
    async def test_stores_raw_user_turn(self):
        """A user turn is stored as a raw payload with a role prefix."""
        svc, client = make_qdrant_service()
        with self._prepare(svc, client):
            point_id = await svc.store_turn(role="user", content="hello world")
        assert isinstance(point_id, str)
        client.upsert.assert_called_once()
        point = self._stored_point(client)
        assert point.payload["type"] == "raw"
        assert point.payload["role"] == "user"
        assert "User: hello world" in point.payload["text"]

    @pytest.mark.asyncio
    async def test_stores_curated_turn(self):
        """A curated turn keeps its content verbatim, with no role prefix."""
        svc, client = make_qdrant_service()
        with self._prepare(svc, client):
            await svc.store_turn(
                role="curated",
                content="User: q\nAssistant: a",
                entry_type="curated"
            )
        point = self._stored_point(client)
        assert point.payload["type"] == "curated"
        assert point.payload["text"] == "User: q\nAssistant: a"

    @pytest.mark.asyncio
    async def test_stores_with_topic_and_metadata(self):
        """Optional topic and metadata are merged into the payload."""
        svc, client = make_qdrant_service()
        with self._prepare(svc, client):
            await svc.store_turn(
                role="assistant",
                content="some response",
                topic="python",
                metadata={"source": "test"}
            )
        point = self._stored_point(client)
        assert point.payload["topic"] == "python"
        assert point.payload["source"] == "test"
class TestStoreQaTurn:
    """Tests for store_qa_turn."""

    @pytest.mark.asyncio
    async def test_stores_qa_turn(self):
        """A question/answer pair is stored as a single raw 'qa' point."""
        svc, client = make_qdrant_service()
        svc._collection_ensured = True
        client.upsert = AsyncMock()
        with patch.object(svc, "get_embedding", AsyncMock(return_value=[0.1] * 1024)):
            point_id = await svc.store_qa_turn("What is Python?", "A programming language.")
        assert isinstance(point_id, str)
        point = client.upsert.call_args[1]["points"][0]
        assert point.payload["type"] == "raw"
        assert point.payload["role"] == "qa"
        # Both halves of the exchange land in the embedded text.
        assert "What is Python?" in point.payload["text"]
        assert "A programming language." in point.payload["text"]
class TestSemanticSearch:
    """Tests for semantic_search."""

    @pytest.mark.asyncio
    async def test_returns_matching_results(self):
        """Hits are reformatted into {id, score, payload} dicts."""
        svc, client = make_qdrant_service()
        svc._collection_ensured = True
        hit = MagicMock()
        hit.id = "result-1"
        hit.score = 0.85
        hit.payload = {"text": "User: hello\nAssistant: hi", "type": "curated"}
        query_result = MagicMock()
        query_result.points = [hit]
        client.query_points = AsyncMock(return_value=query_result)
        with patch.object(svc, "get_embedding", AsyncMock(return_value=[0.1] * 1024)):
            hits = await svc.semantic_search("hello", limit=10, score_threshold=0.6)
        assert len(hits) == 1
        assert hits[0]["id"] == "result-1"
        assert hits[0]["score"] == 0.85
        assert hits[0]["payload"]["type"] == "curated"
class TestGetRecentTurns:
    """Tests for get_recent_turns."""

    @pytest.mark.asyncio
    async def test_returns_sorted_recent_turns(self):
        """Turns come back newest-first when the server sorts via OrderBy."""
        svc, client = make_qdrant_service()
        svc._collection_ensured = True
        older = MagicMock()
        older.id = "old"
        older.payload = {"timestamp": "2026-01-01T00:00:00Z", "text": "old turn"}
        newer = MagicMock()
        newer.id = "new"
        newer.payload = {"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"}
        # Simulate the server already ordering by timestamp descending.
        client.scroll = AsyncMock(return_value=([newer, older], None))
        turns = await svc.get_recent_turns(limit=2)
        assert len(turns) == 2
        assert turns[0]["id"] == "new"
        assert turns[1]["id"] == "old"
class TestDeletePoints:
    """Tests for delete_points."""

    @pytest.mark.asyncio
    async def test_deletes_by_ids(self):
        """A single delete call is issued for the given point IDs."""
        svc, client = make_qdrant_service()
        client.delete = AsyncMock()
        await svc.delete_points(["id1", "id2"])
        client.delete.assert_called_once()
class TestClose:
    """Tests for close."""

    @pytest.mark.asyncio
    async def test_closes_client(self):
        """close() shuts down the underlying async Qdrant client."""
        svc, client = make_qdrant_service()
        client.close = AsyncMock()
        await svc.close()
        client.close.assert_called_once()

View File

@@ -1,7 +1,7 @@
"""Tests for utility functions."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.utils import count_tokens, truncate_by_tokens, parse_curated_turn, build_augmented_messages
from app.utils import count_tokens, truncate_by_tokens, parse_curated_turn, build_augmented_messages, count_messages_tokens
class TestCountTokens:
@@ -86,6 +86,76 @@ Assistant: Yes, very popular."""
assert "Line 3" in result[0]["content"]
class TestCountMessagesTokens:
    """Tests for count_messages_tokens function."""

    def test_empty_list(self):
        """An empty message list counts as 0 tokens."""
        assert count_messages_tokens([]) == 0

    def test_single_message(self):
        """A single message yields a positive token count."""
        assert count_messages_tokens([{"role": "user", "content": "Hello world"}]) > 0

    def test_multiple_messages(self):
        """Token counts accumulate across messages."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there, how can I help you today?"},
        ]
        # The full conversation must count more than its first message alone.
        assert count_messages_tokens(conversation) > count_messages_tokens(conversation[:1])

    def test_message_without_content(self):
        """A message missing the content field contributes 0 tokens."""
        assert count_messages_tokens([{"role": "system"}]) == 0
class TestLoadSystemPrompt:
    """Tests for load_system_prompt function."""

    def test_loads_from_prompts_dir(self, tmp_path):
        """systemprompt.md is read from PROMPTS_DIR when available."""
        import app.utils as utils_module
        prompt_dir = tmp_path / "prompts"
        prompt_dir.mkdir()
        (prompt_dir / "systemprompt.md").write_text("You are Vera.")
        with patch.object(utils_module, "PROMPTS_DIR", prompt_dir):
            assert utils_module.load_system_prompt() == "You are Vera."

    def test_falls_back_to_static_dir(self, tmp_path):
        """STATIC_DIR is consulted when PROMPTS_DIR has no file."""
        import app.utils as utils_module
        missing_dir = tmp_path / "no_prompts"  # deliberately never created
        static_dir = tmp_path / "static"
        static_dir.mkdir()
        (static_dir / "systemprompt.md").write_text("Static Vera.")
        with patch.object(utils_module, "PROMPTS_DIR", missing_dir), \
             patch.object(utils_module, "STATIC_DIR", static_dir):
            assert utils_module.load_system_prompt() == "Static Vera."

    def test_returns_empty_when_not_found(self, tmp_path):
        """An empty string is returned when no systemprompt.md exists."""
        import app.utils as utils_module
        with patch.object(utils_module, "PROMPTS_DIR", tmp_path / "nope"), \
             patch.object(utils_module, "STATIC_DIR", tmp_path / "also_nope"):
            assert utils_module.load_system_prompt() == ""
class TestFilterMemoriesByTime:
"""Tests for filter_memories_by_time function."""