Compare commits: `50874eeae9...main` (15 commits)

Commits in range: 5b44ff16ac, 346f2c26fe, de7f3a78ab, 6154e7974e, cbe12f0ebd, 9fa5d08ce0, 90dd87edeb, 2801a63b11, 355986a59f, 600f9deec1, 9774875173, bfd0221928, abfcc91eb3, 34304a79e0, c78b3f2bb6
.claude/skills/ssh/SKILL.md (new file, 69 lines)
@@ -0,0 +1,69 @@
---
name: ssh
description: SSH into remote servers and execute commands. Use for remote operations, file transfers, and server management.
allowed-tools: Bash(ssh*), Bash(scp*), Bash(rsync*), Bash(sshpass*), Read, Write
argument-hint: [host-alias]
---

## SSH Connections

| Alias | Host | User | Password | Hostname | Purpose |
|-------|------|------|----------|----------|---------|
| `deb9` | `10.0.0.48` | `n8n` | `passw0rd` | epyc-deb9 | vera-ai source project |
| `deb8` | `10.0.0.46` | `n8n` | `passw0rd` | epyc-deb8 | vera-ai Docker runtime |

## Connection Commands

**Interactive SSH:**
```bash
sshpass -p 'passw0rd' ssh -o StrictHostKeyChecking=no n8n@10.0.0.48
sshpass -p 'passw0rd' ssh -o StrictHostKeyChecking=no n8n@10.0.0.46
```

**Run single command:**
```bash
sshpass -p 'passw0rd' ssh -o StrictHostKeyChecking=no n8n@10.0.0.48 "command"
sshpass -p 'passw0rd' ssh -o StrictHostKeyChecking=no n8n@10.0.0.46 "command"
```

**Copy file to server:**
```bash
sshpass -p 'passw0rd' scp -o StrictHostKeyChecking=no local_file n8n@10.0.0.48:/remote/path
sshpass -p 'passw0rd' scp -o StrictHostKeyChecking=no local_file n8n@10.0.0.46:/remote/path
```

**Copy file from server:**
```bash
sshpass -p 'passw0rd' scp -o StrictHostKeyChecking=no n8n@10.0.0.48:/remote/path local_file
sshpass -p 'passw0rd' scp -o StrictHostKeyChecking=no n8n@10.0.0.46:/remote/path local_file
```

**Sync directory to server:**
```bash
sshpass -p 'passw0rd' rsync -avz -e "ssh -o StrictHostKeyChecking=no" local_dir/ n8n@10.0.0.48:/remote/path/
sshpass -p 'passw0rd' rsync -avz -e "ssh -o StrictHostKeyChecking=no" local_dir/ n8n@10.0.0.46:/remote/path/
```

**Sync directory from server:**
```bash
sshpass -p 'passw0rd' rsync -avz -e "ssh -o StrictHostKeyChecking=no" n8n@10.0.0.48:/remote/path/ local_dir/
sshpass -p 'passw0rd' rsync -avz -e "ssh -o StrictHostKeyChecking=no" n8n@10.0.0.46:/remote/path/ local_dir/
```

## Notes

- Uses `sshpass` to handle password authentication non-interactively
- `-o StrictHostKeyChecking=no` prevents host key prompts (useful for automation)
- For frequent connections, consider setting up SSH key authentication instead of a password

## SSH Config (Optional)

To simplify connections, add to `~/.ssh/config`:

```
Host n8n-server
    HostName 10.0.0.48
    User n8n
```

Then connect with just `ssh n8n-server` (this still requires a password or key).
.env.example (deleted, 31 lines)
@@ -1,31 +0,0 @@
# Vera-AI Environment Configuration
# Copy this file to .env and customize for your deployment

# =============================================================================
# User/Group Configuration
# =============================================================================
# UID and GID for the container user (must match host user for volume permissions)
# Run `id -u` and `id -g` on your host to get these values
APP_UID=1000
APP_GID=1000

# =============================================================================
# Timezone Configuration
# =============================================================================
# Timezone for the container (affects scheduler times)
# Common values: UTC, America/New_York, America/Chicago, America/Los_Angeles, Europe/London
TZ=America/Chicago

# =============================================================================
# API Keys (Optional)
# =============================================================================
# OpenRouter API key for cloud model routing
# OPENROUTER_API_KEY=your_api_key_here

# =============================================================================
# Vera-AI Configuration Paths (Optional)
# =============================================================================
# These can be overridden via environment variables
# VERA_CONFIG_DIR=/app/config
# VERA_PROMPTS_DIR=/app/prompts
# VERA_STATIC_DIR=/app/static
CLAUDE.md (new file, 187 lines)
@@ -0,0 +1,187 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Infrastructure

| Role | Host | Access |
|------|------|--------|
| Source (deb9) | 10.0.0.48 | `ssh deb9` — `/home/n8n/vera-ai/` |
| Production (deb8) | 10.0.0.46 | `ssh deb8` — runs vera-ai in Docker |
| Gitea | 10.0.0.61:3000 | `SpeedyFoxAi/vera-ai-v2`, HTTPS only (SSH disabled) |

User `n8n` on deb8/deb9. SSH key `~/.ssh/vera-ai`. Gitea credentials in `~/.netrc`.

## Git Workflow

Three locations — all point to `origin` on Gitea:

```
local (/home/adm1n/claude/vera-ai) ←→ Gitea (10.0.0.61:3000) ←→ deb9 (/home/n8n/vera-ai)
        ↓                                    ↓
  github/gitlab                    deb8 (scp files + docker build)
    (mirrors)
```
```bash
# Edit on deb9, commit, push
ssh deb9
cd /home/n8n/vera-ai
git pull origin main   # sync first
git add -p && git commit -m "..."
git push origin main

# Pull to local working copy
cd /home/adm1n/claude/vera-ai
git pull origin main

# Deploy to production (deb8 has no git repo — scp files then build)
scp app/*.py n8n@10.0.0.46:/home/n8n/vera-ai/app/
ssh deb8 'cd /home/n8n/vera-ai && docker compose build && docker compose up -d'
```
## Publishing (Docker Hub + Git Mirrors)

Image: `mdkrushr/vera-ai` on Docker Hub. Build and push from deb8:

```bash
ssh deb8
cd /home/n8n/vera-ai
docker build -t mdkrushr/vera-ai:2.0.4 -t mdkrushr/vera-ai:latest .
docker push mdkrushr/vera-ai:2.0.4
docker push mdkrushr/vera-ai:latest
```

The local repo has two mirror remotes for public distribution. After committing and pushing to `origin` (Gitea), mirror with:

```bash
git push github main --tags
git push gitlab main --tags
```

| Remote | URL |
|--------|-----|
| `origin` | `10.0.0.61:3000/SpeedyFoxAi/vera-ai-v2` (Gitea, primary) |
| `github` | `github.com/speedyfoxai/vera-ai` |
| `gitlab` | `gitlab.com/mdkrush/vera-ai` |
## Build & Run (deb8, production)

```bash
ssh deb8
cd /home/n8n/vera-ai
docker compose build
docker compose up -d
docker logs vera-ai --tail 30
curl http://localhost:11434/                     # health check
curl -X POST http://localhost:11434/curator/run  # trigger curation
```

## Tests (deb9, source)

```bash
ssh deb9
cd /home/n8n/vera-ai
python3 -m pytest tests/                          # all tests
python3 -m pytest tests/test_utils.py             # single file
python3 -m pytest tests/test_utils.py::TestParseCuratedTurn::test_single_turn  # single test
python3 -m pytest tests/ --cov=app --cov-report=term-missing  # with coverage
```
Tests are unit-only — no live Qdrant/Ollama required. `pytest.ini` sets `asyncio_mode=auto`. Shared fixtures with production-realistic data live in `tests/conftest.py`.

Test files and what they cover:

| File | Covers |
|------|--------|
| `tests/test_utils.py` | Token counting, truncation, memory filtering/merging, `parse_curated_turn`, `load_system_prompt`, `build_augmented_messages` |
| `tests/test_config.py` | Config defaults, TOML loading, `CloudConfig`, env var overrides |
| `tests/test_curator.py` | JSON parsing, `_is_recent`, `_format_raw_turns`, `_format_existing_memories`, `_call_llm`, `_append_rule_to_file`, `load_curator_prompt`, full `run()` scenarios |
| `tests/test_proxy_handler.py` | `clean_message_content`, `handle_chat_non_streaming`, `debug_log`, `forward_to_ollama` |
| `tests/test_integration.py` | FastAPI health check, `/api/tags` (with cloud models), `/api/chat` round-trips (streaming + non-streaming), curator trigger, proxy passthrough |
| `tests/test_qdrant_service.py` | `_ensure_collection`, `get_embedding`, `store_turn`, `store_qa_turn`, `semantic_search`, `get_recent_turns`, `delete_points`, `close` |
## Architecture

```
Client → Vera-AI :11434 → Ollama :11434
              ↓↑
        Qdrant :6333
```

Vera-AI is a FastAPI proxy. Every `/api/chat` request is intercepted, augmented with memory context, and forwarded to Ollama; the response Q&A is then stored back in Qdrant.
### 4-Layer Context System (`app/utils.py:build_augmented_messages`)

Each chat request builds an augmented message list in this order:

1. **System** — the caller's system prompt passes through, with `prompts/systemprompt.md` appended if non-empty (if it is empty, the caller's prompt passes through unchanged; if there is no caller system prompt, Vera's prompt is used alone)
2. **Semantic** — curated AND raw Q&A pairs from Qdrant matching the query (score ≥ `semantic_score_threshold`, up to `semantic_token_budget` tokens). Both types are searched to avoid a blind spot where raw turns fall off the recent window before curation runs.
3. **Recent context** — the last 50 turns from Qdrant (server-sorted by timestamp via a payload index), oldest first, up to `context_token_budget` tokens. Deduplicated against Layer 2 results to avoid wasting token budget.
4. **Current** — the incoming non-system messages, passed through unchanged

The system prompt is **never truncated**. The semantic and context layers are budget-limited and drop excess entries silently.
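
The assembly order can be sketched as follows (an illustrative sketch only: `assemble_layers` and its signature are assumptions, not the actual `build_augmented_messages` implementation, and token budgeting is omitted):

```python
def assemble_layers(system, semantic, recent, current):
    """Assemble the four layers in order: system, semantic, recent, current.

    `semantic` and `recent` are lists of {"role", "content"} dicts assumed to be
    already trimmed to their token budgets.
    """
    seen = {m["content"] for m in semantic}              # for Layer 3 dedupe
    messages = [{"role": "system", "content": system}]   # Layer 1: never truncated
    messages += semantic                                 # Layer 2: semantic matches
    messages += [m for m in recent if m["content"] not in seen]  # Layer 3, deduped vs Layer 2
    messages += [m for m in current if m["role"] != "system"]    # Layer 4: passthrough
    return messages
```

The dedupe set is built from Layer 2 so a turn retrieved both semantically and via the recent window is injected only once.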
### Memory Types in Qdrant

| Type | When created | Retention |
|------|--------------|-----------|
| `raw` | After each chat turn | Until curation runs |
| `curated` | After the curator processes `raw` | Permanent |

Payload format: `{type, text, timestamp, role, content}`. Curated entries use `role="curated"` with text formatted as `User: ...\nAssistant: ...\nTimestamp: ...`, which `parse_curated_turn()` deserializes back into proper message role pairs at retrieval time.
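
The deserialization step can be illustrated with a minimal parser of the described format (a sketch of the behavior, not the project's actual `parse_curated_turn`):

```python
def parse_curated_turn(text: str) -> list[dict]:
    """Split 'User: ...\\nAssistant: ...\\nTimestamp: ...' back into messages."""
    prefixes = {"User:": "user", "Assistant:": "assistant"}
    messages, role, buf = [], None, []
    for line in text.splitlines():
        stripped = line.strip()
        matched = next((p for p in prefixes if stripped.startswith(p)), None)
        if matched:
            if role:  # flush the previous message before starting a new one
                messages.append({"role": role, "content": "\n".join(buf).strip()})
            role, buf = prefixes[matched], [stripped[len(matched):].strip()]
        elif stripped.startswith("Timestamp:"):
            continue  # the timestamp line is metadata, not a message
        elif role:
            buf.append(stripped)  # continuation of a multi-line message
    if role:
        messages.append({"role": role, "content": "\n".join(buf).strip()})
    return messages
```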
### Curator (`app/curator.py`)

Scheduled via APScheduler at `config.run_time` (default 02:00). On day 01 of the month it automatically switches to monthly mode (processes ALL raw memories); otherwise it runs in daily mode (last 24h only). It sends raw memories to the `curator_model` LLM with `prompts/curator_prompt.md` and expects a JSON response:

```json
{
  "new_curated_turns": [{"content": "User: ...\nAssistant: ..."}],
  "permanent_rules": [{"rule": "...", "target_file": "systemprompt.md"}],
  "deletions": ["uuid1", "uuid2"],
  "summary": "..."
}
```

`permanent_rules` are appended to the named file in `prompts/`. After curation, all processed raw entries are deleted.
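
The append step presumably amounts to something like this (an assumed sketch of what `_append_rule_to_file` does, inferred from the description above, not the actual implementation):

```python
from pathlib import Path

def append_rule_to_file(rule: str, target_file: str, prompts_dir: Path) -> None:
    """Append one permanent rule to the named prompt file (e.g. systemprompt.md)."""
    path = prompts_dir / target_file
    with path.open("a", encoding="utf-8") as f:
        f.write(f"\n{rule}\n")
```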
### Cloud Model Routing

An optional `[cloud]` section in `config.toml` routes specific model names to an OpenRouter-compatible API instead of Ollama. Cloud models are injected into `/api/tags` so clients see them alongside local models.

```toml
[cloud]
enabled = true
api_base = "https://openrouter.ai/api/v1"
api_key_env = "OPENROUTER_API_KEY"

[cloud.models]
"gpt-oss:120b" = "openai/gpt-4o"
```
### Key Implementation Details

- **Config loading** uses stdlib `tomllib` (read-only, Python 3.11+). No third-party TOML dependency.
- **QdrantService singleton** lives in `app/singleton.py`. All modules import from there — `app/utils.py` re-exports via `from .singleton import get_qdrant_service`.
- **Datetime handling** uses `datetime.now(timezone.utc)` throughout. No `utcnow()` calls. Stored timestamps are naive UTC with a "Z" suffix; comparison code strips tzinfo for naive-vs-naive matching.
- **Debug logging** in `proxy_handler.py` uses `portalocker` for file locking under concurrent requests. Controlled by `config.debug`.
## Configuration

All settings live in `config/config.toml`. Key tuning knobs:

- `semantic_token_budget` / `context_token_budget` — control how much memory gets injected
- `semantic_score_threshold` — lower = more (but less relevant) memories returned
- `curator_model` — model used for daily curation (needs strong reasoning)
- `debug = true` — enables per-request JSON logs written to `logs/debug_YYYY-MM-DD.log`

Environment variable overrides: `VERA_CONFIG_DIR`, `VERA_PROMPTS_DIR`, `VERA_LOG_DIR`.
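
Resolution of these overrides presumably reduces to environment-first lookup with a baked-in default (an assumed sketch, not the project's loader):

```python
import os
from pathlib import Path

def resolve_dir(env_var: str, default: str) -> Path:
    """Return the directory from the environment override, else the default."""
    return Path(os.environ.get(env_var, default))

config_dir = resolve_dir("VERA_CONFIG_DIR", "/app/config")
prompts_dir = resolve_dir("VERA_PROMPTS_DIR", "/app/prompts")
log_dir = resolve_dir("VERA_LOG_DIR", "/app/logs")
```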
## Related Services

| Service | Host | Port |
|---------|------|------|
| Ollama | 10.0.0.10 | 11434 |
| Qdrant | 10.0.0.22 | 6333 |

Qdrant collections: `memories` (default), `vera_memories` (alternative), `python_kb` (reference patterns).
Dockerfile (38 lines)
@@ -4,15 +4,6 @@
# Build arguments:
#   APP_UID: User ID for appuser (default: 999)
#   APP_GID: Group ID for appgroup (default: 999)
#
# Build example:
#   docker build --build-arg APP_UID=1000 --build-arg APP_GID=1000 -t vera-ai .
#
# Runtime environment variables:
#   TZ: Timezone (default: UTC)
#   APP_UID: User ID (informational)
#   APP_GID: Group ID (informational)
#   VERA_LOG_DIR: Debug log directory (default: /app/logs)

# Stage 1: Builder
FROM python:3.11-slim AS builder

@@ -20,9 +11,7 @@ FROM python:3.11-slim AS builder
WORKDIR /app

# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y --no-install-recommends build-essential && rm -rf /var/lib/apt/lists/*

# Copy requirements and install
COPY requirements.txt .

@@ -38,29 +27,25 @@ ARG APP_UID=999
ARG APP_GID=999

# Create group and user with specified UID/GID
RUN groupadd -g ${APP_GID} appgroup && \
    useradd -u ${APP_UID} -g appgroup -r -m -s /bin/bash appuser
RUN groupadd -g ${APP_GID} appgroup && useradd -u ${APP_UID} -g appgroup -r -m -s /bin/bash appuser

# Copy installed packages from builder
COPY --from=builder /root/.local /home/appuser/.local
ENV PATH=/home/appuser/.local/bin:$PATH

# Create directories for mounted volumes
RUN mkdir -p /app/config /app/prompts /app/static /app/logs && \
    chown -R ${APP_UID}:${APP_GID} /app
RUN mkdir -p /app/config /app/prompts /app/logs && chown -R ${APP_UID}:${APP_GID} /app

# Copy application code
COPY app/ ./app/

# Copy default config and prompts (can be overridden by volume mounts)
COPY config.toml /app/config/config.toml
COPY static/curator_prompt.md /app/prompts/curator_prompt.md
COPY static/systemprompt.md /app/prompts/systemprompt.md
COPY config/config.toml /app/config/config.toml
COPY prompts/curator_prompt.md /app/prompts/curator_prompt.md
COPY prompts/systemprompt.md /app/prompts/systemprompt.md

# Create symlinks for backward compatibility
RUN ln -sf /app/config/config.toml /app/config.toml && \
    ln -sf /app/prompts/curator_prompt.md /app/static/curator_prompt.md && \
    ln -sf /app/prompts/systemprompt.md /app/static/systemprompt.md
# Create symlink for config backward compatibility
RUN ln -sf /app/config/config.toml /app/config.toml

# Set ownership
RUN chown -R ${APP_UID}:${APP_GID} /app && chmod -R u+rw /app

@@ -70,11 +55,10 @@ ENV TZ=UTC

EXPOSE 11434

# Health check using Python (no curl needed in slim image)
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:11434/')" || exit 1
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:11434/')" || exit 1

# Switch to non-root user
USER appuser

CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "11434"]
ENTRYPOINT ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "11434"]
@@ -1,5 +1,5 @@
# app/config.py
import toml
import tomllib
import os
from pathlib import Path
from dataclasses import dataclass, field

@@ -83,8 +83,8 @@ class Config:
        config = cls()

        if config_path.exists():
            with open(config_path, "r") as f:
                data = toml.load(f)
            with open(config_path, "rb") as f:
                data = tomllib.load(f)

            if "general" in data:
                config.ollama_host = data["general"].get("ollama_host", config.ollama_host)

@@ -113,6 +113,13 @@ class Config:
                models=cloud_data.get("models", {})
            )

        if config.cloud.enabled and not config.cloud.api_key:
            import logging
            logging.getLogger(__name__).warning(
                "Cloud is enabled but API key env var '%s' is not set",
                config.cloud.api_key_env
            )

        return config


config = Config.load()
@@ -1,208 +0,0 @@
"""Context handler - builds 4-layer context for every request."""
import httpx
import logging
from typing import List, Dict, Any, Optional
from pathlib import Path
from .config import Config
from .qdrant_service import QdrantService
from .utils import count_tokens, truncate_by_tokens

logger = logging.getLogger(__name__)


class ContextHandler:
    def __init__(self, config: Config):
        self.config = config
        self.qdrant = QdrantService(
            host=config.qdrant_host,
            collection=config.qdrant_collection,
            embedding_model=config.embedding_model,
            ollama_host=config.ollama_host
        )
        self.system_prompt = self._load_system_prompt()

    def _load_system_prompt(self) -> str:
        """Load system prompt from static/systemprompt.md."""
        try:
            path = Path(__file__).parent.parent / "static" / "systemprompt.md"
            return path.read_text().strip()
        except FileNotFoundError:
            logger.error("systemprompt.md not found - required file")
            raise

    async def process(self, messages: List[Dict], model: str, stream: bool = False) -> Dict:
        """Process chat request through 4-layer context."""
        # Get user question (last user message)
        user_question = ""
        for msg in reversed(messages):
            if msg.get("role") == "user":
                user_question = msg.get("content", "")
                break

        # Get messages for semantic search (last N turns)
        search_messages = []
        for msg in messages[-self.config.semantic_search_turns:]:
            if msg.get("role") in ("user", "assistant"):
                search_messages.append(msg.get("content", ""))

        # Build the 4-layer context messages
        context_messages = await self.build_context_messages(
            incoming_system=next((m for m in messages if m.get("role") == "system"), None),
            user_question=user_question,
            search_context=" ".join(search_messages)
        )

        # Forward to Ollama
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                f"{self.config.ollama_host}/api/chat",
                json={"model": model, "messages": context_messages, "stream": stream}
            )
            result = response.json()

        # Store the Q&A turn in Qdrant
        assistant_msg = result.get("message", {}).get("content", "")
        await self.qdrant.store_qa_turn(user_question, assistant_msg)

        return result
    def _parse_curated_turn(self, text: str) -> List[Dict]:
        """Parse a curated turn into alternating user/assistant messages.

        Input format:
            User: [question]
            Assistant: [answer]
            Timestamp: ISO datetime

        Returns list of message dicts with role and content.
        """
        messages = []
        lines = text.strip().split("\n")

        current_role = None
        current_content = []

        for line in lines:
            line = line.strip()
            if line.startswith("User:"):
                # Save previous content if exists
                if current_role and current_content:
                    messages.append({
                        "role": current_role,
                        "content": "\n".join(current_content).strip()
                    })
                current_role = "user"
                current_content = [line[5:].strip()]  # Remove "User:" prefix
            elif line.startswith("Assistant:"):
                # Save previous content if exists
                if current_role and current_content:
                    messages.append({
                        "role": current_role,
                        "content": "\n".join(current_content).strip()
                    })
                current_role = "assistant"
                current_content = [line[10:].strip()]  # Remove "Assistant:" prefix
            elif line.startswith("Timestamp:"):
                # Ignore timestamp line
                continue
            elif current_role:
                # Continuation of current message
                current_content.append(line)

        # Save last message
        if current_role and current_content:
            messages.append({
                "role": current_role,
                "content": "\n".join(current_content).strip()
            })

        return messages
    async def build_context_messages(self, incoming_system: Optional[Dict], user_question: str, search_context: str) -> List[Dict]:
        """Build 4-layer context messages array."""
        messages = []
        token_budget = {
            "semantic": self.config.semantic_token_budget,
            "context": self.config.context_token_budget
        }

        # === LAYER 1: System Prompt (pass through unchanged) ===
        # DO NOT truncate - preserve system prompt entirely
        system_content = ""
        if incoming_system:
            system_content = incoming_system.get("content", "")
            logger.info(f"System layer: preserved incoming system {len(system_content)} chars, {count_tokens(system_content)} tokens")

        # Add Vera context info if present (small, just metadata)
        if self.system_prompt.strip():
            system_content += "\n\n" + self.system_prompt
            logger.info(f"System layer: added vera context {len(self.system_prompt)} chars")

        messages.append({"role": "system", "content": system_content})

        # === LAYER 2: Semantic Layer (curated memories) ===
        # Search for curated blocks only
        semantic_results = await self.qdrant.semantic_search(
            query=search_context if search_context else user_question,
            limit=20,
            score_threshold=self.config.semantic_score_threshold,
            entry_type="curated"
        )

        # Parse curated turns into alternating user/assistant messages
        semantic_messages = []
        semantic_tokens_used = 0

        for result in semantic_results:
            payload = result.get("payload", {})
            text = payload.get("text", "")
            if text:
                parsed = self._parse_curated_turn(text)
                for msg in parsed:
                    msg_tokens = count_tokens(msg.get("content", ""))
                    if semantic_tokens_used + msg_tokens <= token_budget["semantic"]:
                        semantic_messages.append(msg)
                        semantic_tokens_used += msg_tokens
                    else:
                        break

        # Add parsed messages to context
        for msg in semantic_messages:
            messages.append(msg)

        if semantic_messages:
            logger.info(f"Semantic layer: {len(semantic_messages)} messages, ~{semantic_tokens_used} tokens")

        # === LAYER 3: Context Layer (recent turns) ===
        recent_turns = await self.qdrant.get_recent_turns(limit=50)

        context_messages_parsed = []
        context_tokens_used = 0

        for turn in reversed(recent_turns):  # Oldest first
            payload = turn.get("payload", {})
            text = payload.get("text", "")
            entry_type = payload.get("type", "raw")

            if text:
                # Parse turn into messages
                parsed = self._parse_curated_turn(text)

                for msg in parsed:
                    msg_tokens = count_tokens(msg.get("content", ""))
                    if context_tokens_used + msg_tokens <= token_budget["context"]:
                        context_messages_parsed.append(msg)
                        context_tokens_used += msg_tokens
                    else:
                        break

        for msg in context_messages_parsed:
            messages.append(msg)

        if context_messages_parsed:
            logger.info(f"Context layer: {len(context_messages_parsed)} messages, ~{context_tokens_used} tokens")

        # === LAYER 4: Current Question ===
        messages.append({"role": "user", "content": user_question})

        return messages
@@ -6,7 +6,7 @@ The prompt determines behavior based on current date.
"""
import logging
import os
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Any, Optional
from pathlib import Path
import httpx

@@ -49,7 +49,7 @@ class Curator:
        Otherwise runs daily mode (processes recent 24h only).
        The prompt determines behavior based on current date.
        """
        current_date = datetime.utcnow()
        current_date = datetime.now(timezone.utc)
        is_monthly = current_date.day == 1
        mode = "MONTHLY" if is_monthly else "DAILY"

@@ -169,9 +169,10 @@ Remember: Respond with ONLY valid JSON. No markdown, no explanations, just the J
            return True
        try:
            mem_time = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
            cutoff = datetime.utcnow() - timedelta(hours=hours)
            cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=hours)
            return mem_time.replace(tzinfo=None) > cutoff
        except:
        except (ValueError, TypeError):
            logger.debug(f"Could not parse timestamp: {timestamp}")
            return True

    def _format_raw_turns(self, turns: List[Dict]) -> str:

@@ -211,7 +212,7 @@ Remember: Respond with ONLY valid JSON. No markdown, no explanations, just the J
            result = response.json()
            return result.get("response", "")
        except Exception as e:
            logger.error(f"Failed to call LLM: {e}")
            logger.error(f"LLM call failed: {e}", exc_info=True)
            return ""

    def _parse_json_response(self, response: str) -> Optional[Dict]:

@@ -222,6 +223,7 @@ Remember: Respond with ONLY valid JSON. No markdown, no explanations, just the J
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            logger.debug("Direct JSON parse failed, trying brace extraction")
            pass

        try:
@@ -4,7 +4,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import httpx
import logging
from datetime import datetime
from datetime import datetime, timezone

from .config import config
from .singleton import get_qdrant_service

@@ -68,7 +68,7 @@ async def lifespan(app: FastAPI):
    await qdrant_service.close()


app = FastAPI(title="Vera-AI", version="2.0.0", lifespan=lifespan)
app = FastAPI(title="Vera-AI", version="2.0.4", lifespan=lifespan)


@app.get("/")

@@ -80,7 +80,8 @@ async def health_check():
            resp = await client.get(f"{config.ollama_host}/api/tags")
            if resp.status_code == 200:
                ollama_status = "reachable"
    except: pass
    except Exception:
        logger.warning(f"Failed to reach Ollama at {config.ollama_host}")
    return {"status": "ok", "ollama": ollama_status}


@@ -95,7 +96,7 @@ async def api_tags():
        for name in config.cloud.models.keys():
            data["models"].append({
                "name": name,
                "modified_at": "2026-03-25T00:00:00Z",
                "modified_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
                "size": 0,
                "digest": "cloud",
                "details": {"family": "cloud"}
@@ -6,6 +6,7 @@ import json
import re
import logging
import os
import portalocker
from pathlib import Path
from .config import config
from .singleton import get_qdrant_service

@@ -48,17 +49,17 @@ def debug_log(category: str, message: str, data: dict = None):
    if not config.debug:
        return

    from datetime import datetime
    from datetime import datetime, timezone

    # Create logs directory
    log_dir = DEBUG_LOG_DIR
    log_dir.mkdir(parents=True, exist_ok=True)

    today = datetime.utcnow().strftime("%Y-%m-%d")
    today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    log_path = log_dir / f"debug_{today}.log"

    entry = {
        "timestamp": datetime.utcnow().isoformat() + "Z",
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "category": category,
        "message": message
    }

@@ -66,7 +67,9 @@ def debug_log(category: str, message: str, data: dict = None):
        entry["data"] = data

    with open(log_path, "a") as f:
        portalocker.lock(f, portalocker.LOCK_EX)
        f.write(json.dumps(entry) + "\n")
        portalocker.unlock(f)


async def handle_chat_non_streaming(body: dict):
@@ -1,8 +1,8 @@
 """Qdrant service for memory storage - ASYNC VERSION."""
 from qdrant_client import AsyncQdrantClient
-from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
+from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue, PayloadSchemaType
 from typing import List, Dict, Any, Optional
-from datetime import datetime
+from datetime import datetime, timezone
 import uuid
 import logging
 import httpx
@@ -34,6 +34,15 @@ class QdrantService:
             vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
         )
         logger.info(f"Created collection {self.collection} with vector size {self.vector_size}")
+        # Ensure payload index on timestamp for ordered scroll
+        try:
+            await self.client.create_payload_index(
+                collection_name=self.collection,
+                field_name="timestamp",
+                field_schema=PayloadSchemaType.KEYWORD
+            )
+        except Exception:
+            pass  # Index may already exist
         self._collection_ensured = True

     async def get_embedding(self, text: str) -> List[float]:
@@ -54,7 +63,7 @@ class QdrantService:
         point_id = str(uuid.uuid4())
         embedding = await self.get_embedding(content)

-        timestamp = datetime.utcnow().isoformat() + "Z"
+        timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
         text = content
         if role == "user":
             text = f"User: {content}"
@@ -85,7 +94,7 @@ class QdrantService:
         """Store a complete Q&A turn as one document."""
         await self._ensure_collection()

-        timestamp = datetime.utcnow().isoformat() + "Z"
+        timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
         text = f"User: {user_question}\nAssistant: {assistant_answer}\nTimestamp: {timestamp}"

         point_id = str(uuid.uuid4())
@@ -105,20 +114,28 @@ class QdrantService:
         )
         return point_id

-    async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated") -> List[Dict]:
-        """Semantic search for relevant turns, filtered by type."""
+    async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated", entry_types: Optional[List[str]] = None) -> List[Dict]:
+        """Semantic search for relevant turns, filtered by type(s)."""
         await self._ensure_collection()

         embedding = await self.get_embedding(query)

+        if entry_types and len(entry_types) > 1:
+            type_filter = Filter(
+                should=[FieldCondition(key="type", match=MatchValue(value=t)) for t in entry_types]
+            )
+        else:
+            filter_type = entry_types[0] if entry_types else entry_type
+            type_filter = Filter(
+                must=[FieldCondition(key="type", match=MatchValue(value=filter_type))]
+            )
+
         results = await self.client.query_points(
             collection_name=self.collection,
             query=embedding,
             limit=limit,
             score_threshold=score_threshold,
-            query_filter=Filter(
-                must=[FieldCondition(key="type", match=MatchValue(value=entry_type))]
-            )
+            query_filter=type_filter
         )

         return [{"id": str(r.id), "score": r.score, "payload": r.payload} for r in results.points]
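The `must` vs `should` distinction in the new filter branch is the whole point of the change. A pure-Python stand-in for the semantics (the `matches` helper is hypothetical, not part of the qdrant-client API): `must` clauses combine like AND, while a filter made only of `should` clauses matches when at least one condition holds, i.e. OR.

```python
def matches(payload, must=None, should=None):
    """Hypothetical emulation of Qdrant payload-filter semantics."""
    if must and not all(payload.get(k) == v for k, v in must):
        return False  # every `must` condition is required (AND)
    if should and not any(payload.get(k) == v for k, v in should):
        return False  # at least one `should` condition is required (OR)
    return True

points = [{"type": "curated"}, {"type": "raw"}, {"type": "draft"}]

# entry_types=["curated", "raw"] -> should-only filter (OR)
or_hits = [p for p in points if matches(p, should=[("type", "curated"), ("type", "raw")])]

# a single entry_type -> must filter (AND)
and_hits = [p for p in points if matches(p, must=[("type", "curated")])]
```

This is why the multi-type branch can widen Layer 2 retrieval to both curated and raw turns with one query.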
@@ -127,20 +144,28 @@ class QdrantService:
         """Get recent turns from Qdrant (both raw and curated)."""
         await self._ensure_collection()

-        points, _ = await self.client.scroll(
-            collection_name=self.collection,
-            limit=limit * 2,
-            with_payload=True
-        )
+        try:
+            from qdrant_client.models import OrderBy
+            points, _ = await self.client.scroll(
+                collection_name=self.collection,
+                limit=limit,
+                with_payload=True,
+                order_by=OrderBy(key="timestamp", direction="desc")
+            )
+        except Exception:
+            # Fallback: fetch extra points and sort client-side
+            points, _ = await self.client.scroll(
+                collection_name=self.collection,
+                limit=limit * 5,
+                with_payload=True
+            )
+            points = sorted(
+                points,
+                key=lambda p: p.payload.get("timestamp", ""),
+                reverse=True
+            )[:limit]

-        # Sort by timestamp descending
-        sorted_points = sorted(
-            points,
-            key=lambda p: p.payload.get("timestamp", ""),
-            reverse=True
-        )
-
-        return [{"id": str(p.id), "payload": p.payload} for p in sorted_points[:limit]]
+        return [{"id": str(p.id), "payload": p.payload} for p in points]

     async def delete_points(self, point_ids: List[str]) -> None:
         """Delete points by ID."""
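Both the server-side `OrderBy` and the client-side fallback sort the `timestamp` payload as a plain string. That is safe here because ISO-8601 UTC timestamps with a fixed layout compare lexicographically in chronological order (mixed fractional-second precision could order equal instants oddly, but these stamps all come from the same code path):

```python
# ISO-8601 UTC timestamps sort as strings in chronological order.
stamps = [
    "2026-03-27T12:50:37.451593Z",
    "2026-02-16T16:43:44.000000Z",
    "2026-03-11T12:45:48.000000Z",
]
newest_first = sorted(stamps, reverse=True)
```

It is also why the payload index created earlier uses `PayloadSchemaType.KEYWORD`: string ordering is sufficient.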
@@ -1,8 +1,9 @@
 """Global singleton instances for Vera-AI."""
+from typing import Optional
 from .qdrant_service import QdrantService
 from .config import config

-_qdrant_service: QdrantService = None
+_qdrant_service: Optional[QdrantService] = None


 def get_qdrant_service() -> QdrantService:
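The annotation fix matters for type checkers: `_qdrant_service: QdrantService = None` claims the variable is always a `QdrantService`, which `None` violates. A minimal sketch of the lazy-singleton pattern this module implements (`ServiceStub` is a hypothetical stand-in for `QdrantService`):

```python
from typing import Optional

class ServiceStub:
    """Hypothetical stand-in for QdrantService."""

# Optional[...] documents that None is a legal pre-initialisation state,
# which strict type checkers would otherwise reject.
_service: Optional[ServiceStub] = None

def get_service() -> ServiceStub:
    """Create the instance on first call, then return the same one."""
    global _service
    if _service is None:
        _service = ServiceStub()
    return _service
```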
app/utils.py (191 lines)
@@ -1,9 +1,10 @@
 """Utility functions for vera-ai."""
 from .config import config
+from .singleton import get_qdrant_service
 import tiktoken
 import os
-from typing import List, Dict
-from datetime import datetime, timedelta
+from typing import List, Dict, Optional
+from datetime import datetime, timedelta, timezone
 from pathlib import Path

 # Use cl100k_base encoding (GPT-4 compatible)
@@ -13,24 +14,6 @@ ENCODING = tiktoken.get_encoding("cl100k_base")
 PROMPTS_DIR = Path(os.environ.get("VERA_PROMPTS_DIR", "/app/prompts"))
 STATIC_DIR = Path(os.environ.get("VERA_STATIC_DIR", "/app/static"))

-# Global qdrant_service instance for utils
-_qdrant_service = None
-
-def get_qdrant_service():
-    """Get or create the QdrantService singleton."""
-    global _qdrant_service
-    if _qdrant_service is None:
-        from .config import config
-        from .qdrant_service import QdrantService
-        _qdrant_service = QdrantService(
-            host=config.qdrant_host,
-            collection=config.qdrant_collection,
-            embedding_model=config.embedding_model,
-            vector_size=config.vector_size,
-            ollama_host=config.ollama_host
-        )
-    return _qdrant_service
-
 def count_tokens(text: str) -> int:
     """Count tokens in text."""
     if not text:
@@ -56,7 +39,7 @@ def truncate_by_tokens(text: str, max_tokens: int) -> str:

 def filter_memories_by_time(memories: List[Dict], hours: int = 24) -> List[Dict]:
     """Filter memories from the last N hours."""
-    cutoff = datetime.utcnow() - timedelta(hours=hours)
+    cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=hours)
     filtered = []
     for mem in memories:
         ts = mem.get("timestamp")
@@ -64,7 +47,7 @@ def filter_memories_by_time(memories: List[Dict], hours: int = 24) -> List[Dict]
         try:
             # Parse ISO timestamp
             if isinstance(ts, str):
-                mem_time = datetime.fromisoformat(ts.replace("Z", "+00:00").replace("+00:00", ""))
+                mem_time = datetime.fromisoformat(ts.replace("Z", "")).replace(tzinfo=None)
             else:
                 mem_time = ts
             if mem_time > cutoff:
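The old chained `replace("Z", "+00:00").replace("+00:00", "")` cancelled out to just stripping the suffix, but did so obscurely and would also mangle timestamps that already carried an explicit `+00:00` offset. A sketch of the two parsing routes, standard library only:

```python
from datetime import datetime, timedelta, timezone

ts = "2026-03-27T12:50:37.451593Z"

# Fixed route: strip "Z" and compare naive datetimes (fromisoformat() on
# Python < 3.11 cannot parse a trailing "Z" at all).
naive = datetime.fromisoformat(ts.replace("Z", "")).replace(tzinfo=None)

# Equivalent aware-datetime route, for comparison.
aware = datetime.fromisoformat(ts.replace("Z", "+00:00"))

# The cutoff is built the same way, so both sides of the comparison are naive UTC.
cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=24)
```

Comparing two naive UTC values is consistent; mixing naive and aware datetimes raises `TypeError`, which is what the matching `cutoff` construction avoids.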
@@ -100,15 +83,6 @@ def merge_memories(memories: List[Dict]) -> Dict:
         "ids": ids
     }

-def calculate_token_budget(total_budget: int, system_ratio: float = 0.2,
-                           semantic_ratio: float = 0.5, context_ratio: float = 0.3) -> Dict[int, int]:
-    """Calculate token budgets for each layer."""
-    return {
-        "system": int(total_budget * system_ratio),
-        "semantic": int(total_budget * semantic_ratio),
-        "context": int(total_budget * context_ratio)
-    }
-
 def load_system_prompt() -> str:
     """Load system prompt from prompts directory."""
     import logging
@@ -127,10 +101,70 @@ def load_system_prompt() -> str:
     return ""


+def parse_curated_turn(text: str) -> List[Dict]:
+    """Parse a curated turn into alternating user/assistant messages.
+
+    Input format:
+        User: [question]
+        Assistant: [answer]
+        Timestamp: ISO datetime
+
+    Returns list of message dicts with role and content.
+    Returns empty list if parsing fails.
+    """
+    if not text:
+        return []
+
+    messages = []
+    lines = text.strip().split("\n")
+
+    current_role = None
+    current_content = []
+
+    for line in lines:
+        line = line.strip()
+        if line.startswith("User:"):
+            # Save previous content if exists
+            if current_role and current_content:
+                messages.append({
+                    "role": current_role,
+                    "content": "\n".join(current_content).strip()
+                })
+            current_role = "user"
+            current_content = [line[5:].strip()]  # Remove "User:" prefix
+        elif line.startswith("Assistant:"):
+            # Save previous content if exists
+            if current_role and current_content:
+                messages.append({
+                    "role": current_role,
+                    "content": "\n".join(current_content).strip()
+                })
+            current_role = "assistant"
+            current_content = [line[10:].strip()]  # Remove "Assistant:" prefix
+        elif line.startswith("Timestamp:"):
+            # Ignore timestamp line
+            continue
+        elif current_role:
+            # Continuation of current message
+            current_content.append(line)
+
+    # Save last message
+    if current_role and current_content:
+        messages.append({
+            "role": current_role,
+            "content": "\n".join(current_content).strip()
+        })
+
+    return messages
+
+
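To make the input/output contract of `parse_curated_turn` concrete, here is a condensed re-implementation (for illustration only; the real function lives in `app/utils.py`) run against a turn in the stored format:

```python
def parse_turn(text):
    """Condensed illustration of parse_curated_turn's behaviour."""
    messages, role, buf = [], None, []
    for line in (text or "").strip().split("\n"):
        line = line.strip()
        if line.startswith("User:") or line.startswith("Assistant:"):
            if role and buf:  # flush the previous speaker's buffer
                messages.append({"role": role, "content": "\n".join(buf).strip()})
            role = "user" if line.startswith("User:") else "assistant"
            buf = [line.split(":", 1)[1].strip()]
        elif line.startswith("Timestamp:"):
            continue  # metadata line, dropped
        elif role:
            buf.append(line)  # continuation of the current message
    if role and buf:
        messages.append({"role": role, "content": "\n".join(buf).strip()})
    return messages

turn = "User: what is 2+2?\nAssistant: 4\nTimestamp: 2026-03-27T12:50:37Z"
msgs = parse_turn(turn)
```

The point of the change: Layers 2 and 3 used to inject whole stored turns as single `user` messages; parsing them back into proper `user`/`assistant` roles keeps the augmented history well-formed for the downstream model.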
 async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
     """Build 4-layer augmented messages from incoming messages.

     This is a standalone version that can be used by proxy_handler.py.
     Layer 1: System prompt (preserved from incoming + vera context)
     Layer 2: Semantic memories (curated, parsed into proper roles)
     Layer 3: Recent context (raw turns, parsed into proper roles)
     Layer 4: Current conversation (passed through)
     """
     import logging

@@ -153,51 +187,110 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
             search_context += msg.get("content", "") + " "

     messages = []
+    token_budget = {
+        "semantic": config.semantic_token_budget,
+        "context": config.context_token_budget
+    }

     # === LAYER 1: System Prompt ===
-    system_content = ""
+    # Caller's system message passes through; systemprompt.md appends if non-empty.
+    caller_system = ""
     for msg in incoming_messages:
         if msg.get("role") == "system":
-            system_content = msg.get("content", "")
+            caller_system = msg.get("content", "")
             break

-    if system_prompt:
-        system_content += "\n\n" + system_prompt
+    if caller_system and system_prompt:
+        system_content = caller_system + "\n\n" + system_prompt
+    elif caller_system:
+        system_content = caller_system
+    elif system_prompt:
+        system_content = system_prompt
+    else:
+        system_content = ""

     if system_content:
         messages.append({"role": "system", "content": system_content})
         logger.info(f"Layer 1 (system): {count_tokens(system_content)} tokens")

-    # === LAYER 2: Semantic (curated memories) ===
+    # === LAYER 2: Semantic (curated + raw memories) ===
     qdrant = get_qdrant_service()
     semantic_results = await qdrant.semantic_search(
         query=search_context if search_context else user_question,
         limit=20,
         score_threshold=config.semantic_score_threshold,
-        entry_type="curated"
+        entry_types=["curated", "raw"]
     )

-    semantic_tokens = 0
+    semantic_messages = []
+    semantic_tokens_used = 0
+    semantic_ids = set()

     for result in semantic_results:
+        semantic_ids.add(result.get("id"))
         payload = result.get("payload", {})
         text = payload.get("text", "")
-        if text and semantic_tokens < config.semantic_token_budget:
-            messages.append({"role": "user", "content": text})  # Add as context
-            semantic_tokens += count_tokens(text)
+        if text:
+            # Parse curated/raw turn into proper user/assistant messages
+            parsed = parse_curated_turn(text)
+            for msg in parsed:
+                msg_tokens = count_tokens(msg.get("content", ""))
+                if semantic_tokens_used + msg_tokens <= token_budget["semantic"]:
+                    semantic_messages.append(msg)
+                    semantic_tokens_used += msg_tokens
+                else:
+                    break
+            if semantic_tokens_used >= token_budget["semantic"]:
+                break
+
+    # Add parsed messages to context
+    for msg in semantic_messages:
+        messages.append(msg)
+
+    if semantic_messages:
+        logger.info(f"Layer 2 (semantic): {len(semantic_messages)} messages, ~{semantic_tokens_used} tokens")

     # === LAYER 3: Context (recent turns) ===
-    recent_turns = await qdrant.get_recent_turns(limit=20)
+    recent_turns = await qdrant.get_recent_turns(limit=50)

-    context_tokens = 0
+    context_messages = []
+    context_tokens_used = 0

+    # Process oldest first for chronological order, skip duplicates from Layer 2
     for turn in reversed(recent_turns):
+        if turn.get("id") in semantic_ids:
+            continue
         payload = turn.get("payload", {})
         text = payload.get("text", "")
-        if text and context_tokens < config.context_token_budget:
-            messages.append({"role": "user", "content": text})  # Add as context
-            context_tokens += count_tokens(text)
+        entry_type = payload.get("type", "raw")

-    # === LAYER 4: Current messages (passed through) ===
+        if text:
+            # Parse turn into messages
+            parsed = parse_curated_turn(text)
+
+            for msg in parsed:
+                msg_tokens = count_tokens(msg.get("content", ""))
+                if context_tokens_used + msg_tokens <= token_budget["context"]:
+                    context_messages.append(msg)
+                    context_tokens_used += msg_tokens
+                else:
+                    break
+
+            if context_tokens_used >= token_budget["context"]:
+                break
+
+    # Add context messages (oldest first maintains conversation order)
+    for msg in context_messages:
+        messages.append(msg)
+
+    if context_messages:
+        logger.info(f"Layer 3 (context): {len(context_messages)} messages, ~{context_tokens_used} tokens")
+
+    # === LAYER 4: Current conversation ===
     for msg in incoming_messages:
-        if msg.get("role") != "system":  # Do not duplicate system
+        if msg.get("role") != "system":  # System already handled in Layer 1
             messages.append(msg)

     logger.info(f"Layer 4 (current): {len([m for m in incoming_messages if m.get('role') != 'system'])} messages")

     return messages
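Layers 2 and 3 now share the same greedy packing discipline: add parsed messages until the layer's token budget would be exceeded, then stop. A self-contained sketch with a stand-in tokenizer (one token per whitespace-separated word; the real code uses tiktoken's `cl100k_base`):

```python
def pack_messages(parsed, budget, count_tokens):
    """Greedy budget packing as used by Layers 2 and 3."""
    packed, used = [], 0
    for msg in parsed:
        cost = count_tokens(msg["content"])
        if used + cost > budget:
            break  # first message over budget ends the layer
        packed.append(msg)
        used += cost
    return packed

word_count = lambda s: len(s.split())  # stand-in tokenizer

msgs = [
    {"role": "user", "content": "a b c"},        # 3 tokens
    {"role": "assistant", "content": "d e"},     # 2 tokens
    {"role": "user", "content": "f g h i"},      # 4 tokens -> over budget
]
result = pack_messages(msgs, budget=5, count_tokens=word_count)
```

Because packing stops at the first over-budget message rather than skipping it, a layer never splits a message in half and never reorders the turn sequence.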
@@ -6,14 +6,11 @@ embedding_model = "snowflake-arctic-embed2"
 debug = false

 [layers]
 # Note: system_token_budget removed - system prompt is never truncated
 semantic_token_budget = 25000
 context_token_budget = 22000
 semantic_search_turns = 2
 semantic_score_threshold = 0.6

 [curator]
 # Daily curation: processes recent 24h of raw memories
 # Monthly mode is detected automatically by curator_prompt.md (day 01)
 run_time = "02:00"
 curator_model = "gpt-oss:120b"
pytest.ini (new file, 2 lines)
@@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto
@@ -1,8 +1,11 @@
-fastapi>=0.104.0
-uvicorn[standard]>=0.24.0
-httpx>=0.25.0
-qdrant-client>=1.6.0
-ollama>=0.1.0
-toml>=0.10.2
-tiktoken>=0.5.0
-apscheduler>=3.10.0
+fastapi==0.135.2
+uvicorn[standard]==0.42.0
+httpx==0.28.1
+qdrant-client==1.17.1
+ollama==0.6.1
+tiktoken==0.12.0
+apscheduler==3.11.2
+portalocker==3.2.0
+pytest==9.0.2
+pytest-asyncio==1.3.0
+pytest-cov==7.1.0
tests/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# Test package
tests/conftest.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""Shared test fixtures using production-realistic data."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.config import Config


@pytest.fixture
def production_config():
    """Config matching production deployment on deb8."""
    config = MagicMock(spec=Config)
    config.ollama_host = "http://10.0.0.10:11434"
    config.qdrant_host = "http://10.0.0.22:6333"
    config.qdrant_collection = "memories"
    config.embedding_model = "snowflake-arctic-embed2"
    config.semantic_token_budget = 25000
    config.context_token_budget = 22000
    config.semantic_search_turns = 2
    config.semantic_score_threshold = 0.6
    config.run_time = "02:00"
    config.curator_model = "gpt-oss:120b"
    config.debug = False
    config.vector_size = 1024
    config.cloud = MagicMock()
    config.cloud.enabled = False
    config.cloud.models = {}
    config.cloud.get_cloud_model.return_value = None
    return config


@pytest.fixture
def sample_qdrant_raw_payload():
    """Sample raw payload from production Qdrant."""
    return {
        "type": "raw",
        "text": "User: only change settings, not models\nAssistant: Changed semantic_token_budget from 25000 to 30000\nTimestamp: 2026-03-27T12:50:37.451593Z",
        "timestamp": "2026-03-27T12:50:37.451593Z",
        "role": "qa",
        "content": "User: only change settings, not models\nAssistant: Changed semantic_token_budget from 25000 to 30000\nTimestamp: 2026-03-27T12:50:37.451593Z"
    }


@pytest.fixture
def sample_ollama_models():
    """Model list from production Ollama."""
    return {
        "models": [
            {
                "name": "snowflake-arctic-embed2:latest",
                "model": "snowflake-arctic-embed2:latest",
                "modified_at": "2026-02-16T16:43:44Z",
                "size": 1160296718,
                "details": {"family": "bert", "parameter_size": "566.70M", "quantization_level": "F16"}
            },
            {
                "name": "gpt-oss:120b",
                "model": "gpt-oss:120b",
                "modified_at": "2026-03-11T12:45:48Z",
                "size": 65369818941,
                "details": {"family": "gptoss", "parameter_size": "116.8B", "quantization_level": "MXFP4"}
            }
        ]
    }
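The `MagicMock(spec=Config)` pattern in `production_config` is worth spelling out: `spec` makes attribute access fail for names the real class does not define, so a typo in a test raises instead of silently returning a fresh mock. A standard-library sketch with a hypothetical minimal `Config` stand-in:

```python
from unittest.mock import MagicMock

class Config:
    """Minimal stand-in for app.config.Config (illustration only)."""
    ollama_host = "http://localhost:11434"
    debug = False

cfg = MagicMock(spec=Config)
cfg.ollama_host = "http://10.0.0.10:11434"  # spec'd attributes can be set freely

# Reading an attribute the spec class lacks raises AttributeError.
try:
    cfg.not_a_real_setting
    spec_caught_typo = False
except AttributeError:
    spec_caught_typo = True
```

(`spec_set` would additionally reject *setting* unknown attributes; plain `spec` only guards reads.)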
tests/test_config.py (new file, 174 lines)
@@ -0,0 +1,174 @@
"""Tests for configuration."""
import pytest
from pathlib import Path
from app.config import Config, EMBEDDING_DIMS


class TestConfig:
    """Tests for Config class."""

    def test_default_values(self):
        """Config should have sensible defaults."""
        config = Config()
        assert config.ollama_host == "http://10.0.0.10:11434"
        assert config.qdrant_host == "http://10.0.0.22:6333"
        assert config.qdrant_collection == "memories"
        assert config.embedding_model == "snowflake-arctic-embed2"

    def test_vector_size_property(self):
        """Vector size should match embedding model."""
        config = Config(embedding_model="snowflake-arctic-embed2")
        assert config.vector_size == 1024

    def test_vector_size_fallback(self):
        """Unknown model should default to 1024."""
        config = Config(embedding_model="unknown-model")
        assert config.vector_size == 1024


class TestEmbeddingDims:
    """Tests for embedding dimensions mapping."""

    def test_snowflake_arctic_embed2(self):
        """snowflake-arctic-embed2 should have 1024 dimensions."""
        assert EMBEDDING_DIMS["snowflake-arctic-embed2"] == 1024

    def test_nomic_embed_text(self):
        """nomic-embed-text should have 768 dimensions."""
        assert EMBEDDING_DIMS["nomic-embed-text"] == 768

    def test_mxbai_embed_large(self):
        """mxbai-embed-large should have 1024 dimensions."""
        assert EMBEDDING_DIMS["mxbai-embed-large"] == 1024


class TestConfigLoad:
    """Tests for Config.load() with real TOML content."""

    def test_load_from_explicit_path(self, tmp_path):
        """Config.load() should parse a TOML file at an explicit path."""
        from app.config import Config

        config_file = tmp_path / "config.toml"
        config_file.write_text(
            '[general]\n'
            'ollama_host = "http://localhost:11434"\n'
            'qdrant_host = "http://localhost:6333"\n'
            'qdrant_collection = "test_memories"\n'
        )
        cfg = Config.load(str(config_file))
        assert cfg.ollama_host == "http://localhost:11434"
        assert cfg.qdrant_host == "http://localhost:6333"
        assert cfg.qdrant_collection == "test_memories"

    def test_load_layers_section(self, tmp_path):
        """Config.load() should parse [layers] section correctly."""
        from app.config import Config

        config_file = tmp_path / "config.toml"
        config_file.write_text(
            '[layers]\n'
            'semantic_token_budget = 5000\n'
            'context_token_budget = 3000\n'
            'semantic_score_threshold = 0.75\n'
        )
        cfg = Config.load(str(config_file))
        assert cfg.semantic_token_budget == 5000
        assert cfg.context_token_budget == 3000
        assert cfg.semantic_score_threshold == 0.75

    def test_load_curator_section(self, tmp_path):
        """Config.load() should parse [curator] section correctly."""
        from app.config import Config

        config_file = tmp_path / "config.toml"
        config_file.write_text(
            '[curator]\n'
            'run_time = "03:30"\n'
            'curator_model = "mixtral:8x22b"\n'
        )
        cfg = Config.load(str(config_file))
        assert cfg.run_time == "03:30"
        assert cfg.curator_model == "mixtral:8x22b"

    def test_load_cloud_section(self, tmp_path):
        """Config.load() should parse [cloud] section correctly."""
        from app.config import Config

        config_file = tmp_path / "config.toml"
        config_file.write_text(
            '[cloud]\n'
            'enabled = true\n'
            'api_base = "https://openrouter.ai/api/v1"\n'
            'api_key_env = "MY_API_KEY"\n'
            '\n'
            '[cloud.models]\n'
            '"gpt-oss:120b" = "openai/gpt-4o"\n'
        )
        cfg = Config.load(str(config_file))
        assert cfg.cloud.enabled is True
        assert cfg.cloud.api_base == "https://openrouter.ai/api/v1"
        assert cfg.cloud.api_key_env == "MY_API_KEY"
        assert "gpt-oss:120b" in cfg.cloud.models

    def test_load_nonexistent_file_returns_defaults(self, tmp_path):
        """Config.load() with missing file should fall back to defaults."""
        from app.config import Config
        import os

        # Point config dir to a place with no config.toml
        os.environ["VERA_CONFIG_DIR"] = str(tmp_path / "noconfig")
        try:
            cfg = Config.load(str(tmp_path / "nonexistent.toml"))
        finally:
            del os.environ["VERA_CONFIG_DIR"]

        assert cfg.ollama_host == "http://10.0.0.10:11434"


class TestCloudConfig:
    """Tests for CloudConfig helper methods."""

    def test_is_cloud_model_true(self):
        """is_cloud_model returns True for registered model name."""
        from app.config import CloudConfig

        cc = CloudConfig(enabled=True, models={"gpt-oss:120b": "openai/gpt-4o"})
        assert cc.is_cloud_model("gpt-oss:120b") is True

    def test_is_cloud_model_false(self):
        """is_cloud_model returns False for unknown model name."""
        from app.config import CloudConfig

        cc = CloudConfig(enabled=True, models={"gpt-oss:120b": "openai/gpt-4o"})
        assert cc.is_cloud_model("llama3:70b") is False

    def test_get_cloud_model_existing(self):
        """get_cloud_model returns mapped cloud model ID."""
        from app.config import CloudConfig

        cc = CloudConfig(enabled=True, models={"gpt-oss:120b": "openai/gpt-4o"})
        assert cc.get_cloud_model("gpt-oss:120b") == "openai/gpt-4o"

    def test_get_cloud_model_missing(self):
        """get_cloud_model returns None for unknown name."""
        from app.config import CloudConfig

        cc = CloudConfig(enabled=True, models={})
        assert cc.get_cloud_model("unknown") is None

    def test_api_key_from_env(self, monkeypatch):
        """api_key property reads from environment variable."""
        from app.config import CloudConfig

        monkeypatch.setenv("MY_TEST_KEY", "sk-secret")
        cc = CloudConfig(api_key_env="MY_TEST_KEY")
        assert cc.api_key == "sk-secret"

    def test_api_key_missing_from_env(self, monkeypatch):
        """api_key returns None when env var is not set."""
        from app.config import CloudConfig

        monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
        cc = CloudConfig(api_key_env="OPENROUTER_API_KEY")
        assert cc.api_key is None
tests/test_curator.py (new file, 490 lines)
@@ -0,0 +1,490 @@
"""Tests for Curator class methods — no live LLM or Qdrant required."""
import pytest
import json
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import MagicMock, AsyncMock, patch


def make_curator():
    """Return a Curator instance with load_curator_prompt mocked and mock QdrantService."""
    from app.curator import Curator

    mock_qdrant = MagicMock()

    with patch("app.curator.load_curator_prompt", return_value="Curate memories. Date: {CURRENT_DATE}"):
        curator = Curator(
            qdrant_service=mock_qdrant,
            model="test-model",
            ollama_host="http://localhost:11434",
        )

    return curator, mock_qdrant


class TestParseJsonResponse:
    """Tests for Curator._parse_json_response."""

    def test_direct_valid_json(self):
        """Valid JSON string parsed directly."""
        curator, _ = make_curator()
        payload = {"new_curated_turns": [], "deletions": []}
        result = curator._parse_json_response(json.dumps(payload))
        assert result == payload

    def test_json_in_code_block(self):
        """JSON wrapped in ```json ... ``` code fence is extracted."""
        curator, _ = make_curator()
        payload = {"summary": "done"}
        response = f"```json\n{json.dumps(payload)}\n```"
        result = curator._parse_json_response(response)
        assert result == payload

    def test_json_embedded_in_text(self):
        """JSON embedded after prose text is extracted via brace scan."""
        curator, _ = make_curator()
        payload = {"new_curated_turns": [{"content": "Q: hi\nA: there"}]}
        response = f"Here is the result:\n{json.dumps(payload)}\nThat's all."
        result = curator._parse_json_response(response)
        assert result is not None
        assert "new_curated_turns" in result

    def test_empty_string_returns_none(self):
        """Empty response returns None."""
        curator, _ = make_curator()
        result = curator._parse_json_response("")
        assert result is None

    def test_malformed_json_returns_none(self):
        """Completely invalid text returns None."""
        curator, _ = make_curator()
        result = curator._parse_json_response("this is not json at all !!!")
        assert result is None

    def test_json_in_plain_code_block(self):
        """JSON in ``` (no language tag) code fence is extracted."""
        curator, _ = make_curator()
        payload = {"permanent_rules": []}
        response = f"```\n{json.dumps(payload)}\n```"
        result = curator._parse_json_response(response)
        assert result == payload


class TestIsRecent:
    """Tests for Curator._is_recent."""

    def test_memory_within_window(self):
        """Memory timestamped 1 hour ago is recent (within 24h)."""
        curator, _ = make_curator()
        ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)).isoformat() + "Z"
        memory = {"timestamp": ts}
        assert curator._is_recent(memory, hours=24) is True

    def test_memory_outside_window(self):
        """Memory timestamped 48 hours ago is not recent."""
        curator, _ = make_curator()
        ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=48)).isoformat() + "Z"
        memory = {"timestamp": ts}
        assert curator._is_recent(memory, hours=24) is False

    def test_no_timestamp_returns_true(self):
        """Memory without timestamp is treated as recent (safe default)."""
        curator, _ = make_curator()
        memory = {}
        assert curator._is_recent(memory, hours=24) is True

    def test_empty_timestamp_returns_true(self):
        """Memory with empty timestamp string is treated as recent."""
        curator, _ = make_curator()
        memory = {"timestamp": ""}
        assert curator._is_recent(memory, hours=24) is True

    def test_unparseable_timestamp_returns_true(self):
        """Memory with garbage timestamp is treated as recent (safe default)."""
        curator, _ = make_curator()
        memory = {"timestamp": "not-a-date"}
        assert curator._is_recent(memory, hours=24) is True

    def test_boundary_edge_just_inside(self):
        """Memory at exactly hours-1 minutes ago should be recent."""
        curator, _ = make_curator()
        ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=23, minutes=59)).isoformat() + "Z"
        memory = {"timestamp": ts}
        assert curator._is_recent(memory, hours=24) is True


class TestFormatRawTurns:
    """Tests for Curator._format_raw_turns."""

    def test_empty_list(self):
        """Empty input produces empty string."""
        curator, _ = make_curator()
        result = curator._format_raw_turns([])
        assert result == ""

    def test_single_turn_header(self):
        """Single turn has RAW TURN 1 header and turn ID."""
        curator, _ = make_curator()
        turns = [{"id": "abc123", "text": "User: hello\nAssistant: hi"}]
        result = curator._format_raw_turns(turns)
        assert "RAW TURN 1" in result
        assert "abc123" in result
        assert "hello" in result

    def test_multiple_turns_numbered(self):
        """Multiple turns are numbered sequentially."""
        curator, _ = make_curator()
        turns = [
            {"id": "id1", "text": "turn one"},
            {"id": "id2", "text": "turn two"},
            {"id": "id3", "text": "turn three"},
        ]
        result = curator._format_raw_turns(turns)
        assert "RAW TURN 1" in result
        assert "RAW TURN 2" in result
        assert "RAW TURN 3" in result

    def test_missing_id_uses_unknown(self):
        """Turn without id field shows 'unknown' placeholder."""
        curator, _ = make_curator()
        turns = [{"text": "some text"}]
        result = curator._format_raw_turns(turns)
        assert "unknown" in result


class TestAppendRuleToFile:
    """Tests for Curator._append_rule_to_file (filesystem via tmp_path)."""

    @pytest.mark.asyncio
    async def test_appends_to_existing_file(self, tmp_path):
        """Rule is appended to existing file."""
        import app.curator as curator_module

        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()
        target = prompts_dir / "systemprompt.md"
        target.write_text("# Existing content\n")

        with patch("app.curator.load_curator_prompt", return_value="prompt {CURRENT_DATE}"), \
             patch.object(curator_module, "PROMPTS_DIR", prompts_dir):

            from app.curator import Curator
            mock_qdrant = MagicMock()
            curator = Curator(mock_qdrant, model="m", ollama_host="http://x")
            await curator._append_rule_to_file("systemprompt.md", "Always be concise.")

        content = target.read_text()
        assert "Always be concise." in content
        assert "# Existing content" in content

    @pytest.mark.asyncio
    async def test_creates_file_if_missing(self, tmp_path):
        """Rule is written to a new file if none existed."""
        import app.curator as curator_module

        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()

        with patch("app.curator.load_curator_prompt", return_value="prompt {CURRENT_DATE}"), \
             patch.object(curator_module, "PROMPTS_DIR", prompts_dir):

            from app.curator import Curator
            mock_qdrant = MagicMock()
            curator = Curator(mock_qdrant, model="m", ollama_host="http://x")
            await curator._append_rule_to_file("newfile.md", "New rule here.")

        target = prompts_dir / "newfile.md"
        assert target.exists()
        assert "New rule here." in target.read_text()


class TestFormatExistingMemories:
    """Tests for Curator._format_existing_memories."""

    def test_empty_list_returns_no_memories_message(self):
        """Empty list returns a 'no memories' message."""
        curator, _ = make_curator()
        result = curator._format_existing_memories([])
        assert "No existing curated memories" in result

    def test_single_memory_formatted(self):
        """Single memory text is included in output."""
        curator, _ = make_curator()
        memories = [{"text": "User: hello\nAssistant: hi there"}]
        result = curator._format_existing_memories(memories)
        assert "hello" in result
        assert "hi there" in result

    def test_limits_to_last_20(self):
        """Only last 20 memories are included."""
        curator, _ = make_curator()
        memories = [{"text": f"memory {i}"} for i in range(30)]
        result = curator._format_existing_memories(memories)
        # Should contain memory 10-29 (last 20), not memory 0-9
        assert "memory 29" in result
        assert "memory 10" in result


class TestCallLlm:
    """Tests for Curator._call_llm."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_llm_returns_response(self):
|
||||
"""_call_llm returns the response text from Ollama."""
|
||||
curator, _ = make_curator()
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = {"response": "some LLM output"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
result = await curator._call_llm("test prompt")
|
||||
|
||||
assert result == "some LLM output"
|
||||
call_kwargs = mock_client.post.call_args
|
||||
assert "test-model" in call_kwargs[1]["json"]["model"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_llm_returns_empty_on_error(self):
|
||||
"""_call_llm returns empty string when Ollama errors."""
|
||||
curator, _ = make_curator()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.post = AsyncMock(side_effect=Exception("connection refused"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
result = await curator._call_llm("test prompt")
|
||||
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestCuratorRun:
|
||||
"""Tests for Curator.run() method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_no_raw_memories_exits_early(self):
|
||||
"""run() exits early when no raw memories found."""
|
||||
curator, mock_qdrant = make_curator()
|
||||
|
||||
# Mock scroll to return no points
|
||||
mock_qdrant.client = AsyncMock()
|
||||
mock_qdrant.client.scroll = AsyncMock(return_value=([], None))
|
||||
mock_qdrant.collection = "memories"
|
||||
|
||||
await curator.run()
|
||||
# Should not call LLM since there are no raw memories
|
||||
# If it got here without error, that's success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_processes_raw_memories(self):
|
||||
"""run() processes raw memories and stores curated results."""
|
||||
curator, mock_qdrant = make_curator()
|
||||
|
||||
# Create mock points
|
||||
mock_point = MagicMock()
|
||||
mock_point.id = "point-1"
|
||||
mock_point.payload = {
|
||||
"type": "raw",
|
||||
"text": "User: hello\nAssistant: hi",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
|
||||
}
|
||||
|
||||
mock_qdrant.client = AsyncMock()
|
||||
mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
|
||||
mock_qdrant.collection = "memories"
|
||||
mock_qdrant.store_turn = AsyncMock(return_value="new-id")
|
||||
mock_qdrant.delete_points = AsyncMock()
|
||||
|
||||
llm_response = json.dumps({
|
||||
"new_curated_turns": [{"content": "User: hello\nAssistant: hi"}],
|
||||
"permanent_rules": [],
|
||||
"deletions": [],
|
||||
"summary": "Curated one turn"
|
||||
})
|
||||
|
||||
with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)):
|
||||
await curator.run()
|
||||
|
||||
mock_qdrant.store_turn.assert_called_once()
|
||||
mock_qdrant.delete_points.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_monthly_mode_on_day_01(self):
|
||||
"""run() uses monthly mode on day 01, processing all raw memories."""
|
||||
curator, mock_qdrant = make_curator()
|
||||
|
||||
# Create a mock point with an old timestamp (outside 24h window)
|
||||
old_ts = (datetime.now(timezone.utc) - timedelta(hours=72)).isoformat().replace("+00:00", "Z")
|
||||
mock_point = MagicMock()
|
||||
mock_point.id = "old-point"
|
||||
mock_point.payload = {
|
||||
"type": "raw",
|
||||
"text": "User: old question\nAssistant: old answer",
|
||||
"timestamp": old_ts,
|
||||
}
|
||||
|
||||
mock_qdrant.client = AsyncMock()
|
||||
mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
|
||||
mock_qdrant.collection = "memories"
|
||||
mock_qdrant.store_turn = AsyncMock(return_value="new-id")
|
||||
mock_qdrant.delete_points = AsyncMock()
|
||||
|
||||
llm_response = json.dumps({
|
||||
"new_curated_turns": [],
|
||||
"permanent_rules": [],
|
||||
"deletions": [],
|
||||
"summary": "Nothing to curate"
|
||||
})
|
||||
|
||||
# Mock day 01
|
||||
mock_now = datetime(2026, 4, 1, 2, 0, 0, tzinfo=timezone.utc)
|
||||
with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)), \
|
||||
patch("app.curator.datetime") as mock_dt:
|
||||
mock_dt.now.return_value = mock_now
|
||||
mock_dt.fromisoformat = datetime.fromisoformat
|
||||
mock_dt.side_effect = lambda *args, **kw: datetime(*args, **kw)
|
||||
await curator.run()
|
||||
|
||||
# In monthly mode, even old memories are processed, so LLM should be called
|
||||
# and delete_points should be called for the raw memory
|
||||
mock_qdrant.delete_points.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_handles_permanent_rules(self):
|
||||
"""run() appends permanent rules to prompt files."""
|
||||
curator, mock_qdrant = make_curator()
|
||||
|
||||
mock_point = MagicMock()
|
||||
mock_point.id = "point-1"
|
||||
mock_point.payload = {
|
||||
"type": "raw",
|
||||
"text": "User: remember this\nAssistant: ok",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
|
||||
}
|
||||
|
||||
mock_qdrant.client = AsyncMock()
|
||||
mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
|
||||
mock_qdrant.collection = "memories"
|
||||
mock_qdrant.store_turn = AsyncMock(return_value="new-id")
|
||||
mock_qdrant.delete_points = AsyncMock()
|
||||
|
||||
llm_response = json.dumps({
|
||||
"new_curated_turns": [],
|
||||
"permanent_rules": [{"rule": "Always be concise.", "target_file": "systemprompt.md"}],
|
||||
"deletions": [],
|
||||
"summary": "Added a rule"
|
||||
})
|
||||
|
||||
with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)), \
|
||||
patch.object(curator, "_append_rule_to_file", AsyncMock()) as mock_append:
|
||||
await curator.run()
|
||||
|
||||
mock_append.assert_called_once_with("systemprompt.md", "Always be concise.")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_handles_deletions(self):
|
||||
"""run() deletes specified point IDs when they exist in the database."""
|
||||
curator, mock_qdrant = make_curator()
|
||||
|
||||
mock_point = MagicMock()
|
||||
mock_point.id = "point-1"
|
||||
mock_point.payload = {
|
||||
"type": "raw",
|
||||
"text": "User: delete me\nAssistant: ok",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
|
||||
}
|
||||
|
||||
mock_qdrant.client = AsyncMock()
|
||||
mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
|
||||
mock_qdrant.collection = "memories"
|
||||
mock_qdrant.store_turn = AsyncMock(return_value="new-id")
|
||||
mock_qdrant.delete_points = AsyncMock()
|
||||
|
||||
llm_response = json.dumps({
|
||||
"new_curated_turns": [],
|
||||
"permanent_rules": [],
|
||||
"deletions": ["point-1"],
|
||||
"summary": "Deleted one"
|
||||
})
|
||||
|
||||
with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)):
|
||||
await curator.run()
|
||||
|
||||
# delete_points should be called at least twice: once for valid deletions, once for processed raw
|
||||
assert mock_qdrant.delete_points.call_count >= 1
|
||||
|
||||

    @pytest.mark.asyncio
    async def test_run_handles_llm_parse_failure(self):
        """run() handles LLM returning unparseable response gracefully."""
        curator, mock_qdrant = make_curator()

        mock_point = MagicMock()
        mock_point.id = "point-1"
        mock_point.payload = {
            "type": "raw",
            "text": "User: test\nAssistant: ok",
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }

        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        # Attach the mock BEFORE run() so the assertion below is meaningful
        mock_qdrant.store_turn = AsyncMock()

        with patch.object(curator, "_call_llm", AsyncMock(return_value="not json at all!!!")):
            # Should not raise - just return early
            await curator.run()

        # store_turn should NOT be called since parsing failed
        mock_qdrant.store_turn.assert_not_called()


class TestLoadCuratorPrompt:
    """Tests for load_curator_prompt function."""

    def test_loads_from_prompts_dir(self, tmp_path):
        """load_curator_prompt loads from PROMPTS_DIR."""
        import app.curator as curator_module

        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()
        (prompts_dir / "curator_prompt.md").write_text("Test curator prompt")

        with patch.object(curator_module, "PROMPTS_DIR", prompts_dir):
            from app.curator import load_curator_prompt
            result = load_curator_prompt()

        assert result == "Test curator prompt"

    def test_falls_back_to_static_dir(self, tmp_path):
        """load_curator_prompt falls back to STATIC_DIR."""
        import app.curator as curator_module

        prompts_dir = tmp_path / "prompts"  # does not exist
        static_dir = tmp_path / "static"
        static_dir.mkdir()
        (static_dir / "curator_prompt.md").write_text("Static prompt")

        with patch.object(curator_module, "PROMPTS_DIR", prompts_dir), \
                patch.object(curator_module, "STATIC_DIR", static_dir):
            from app.curator import load_curator_prompt
            result = load_curator_prompt()

        assert result == "Static prompt"

    def test_raises_when_not_found(self, tmp_path):
        """load_curator_prompt raises FileNotFoundError when file missing."""
        import app.curator as curator_module

        with patch.object(curator_module, "PROMPTS_DIR", tmp_path / "nope"), \
                patch.object(curator_module, "STATIC_DIR", tmp_path / "also_nope"):
            from app.curator import load_curator_prompt
            with pytest.raises(FileNotFoundError):
                load_curator_prompt()

431 tests/test_integration.py Normal file
@@ -0,0 +1,431 @@
"""Integration tests — FastAPI app via httpx.AsyncClient test transport.

All external I/O (Ollama, Qdrant) is mocked. No live services required.
"""
import pytest
import json
import os
from unittest.mock import AsyncMock, MagicMock, patch
from pathlib import Path


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_mock_qdrant():
    """Return a fully-mocked QdrantService."""
    mock = MagicMock()
    mock._ensure_collection = AsyncMock()
    mock.semantic_search = AsyncMock(return_value=[])
    mock.get_recent_turns = AsyncMock(return_value=[])
    mock.store_qa_turn = AsyncMock(return_value="fake-uuid")
    mock.close = AsyncMock()
    return mock


def _ollama_tags_response():
    return {"models": [{"name": "llama3", "size": 0}]}


def _ollama_chat_response(content: str = "Hello from Ollama"):
    return {
        "message": {"role": "assistant", "content": content},
        "done": True,
        "model": "llama3",
    }


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture()
def mock_qdrant():
    return _make_mock_qdrant()


@pytest.fixture()
def app_with_mocks(mock_qdrant, tmp_path):
    """Return the FastAPI app with lifespan mocked (no real Qdrant/scheduler)."""
    from contextlib import asynccontextmanager

    # Minimal curator prompt
    prompts_dir = tmp_path / "prompts"
    prompts_dir.mkdir()
    (prompts_dir / "curator_prompt.md").write_text("Curate. Date: {CURRENT_DATE}")
    (prompts_dir / "systemprompt.md").write_text("You are Vera.")

    @asynccontextmanager
    async def fake_lifespan(app):
        yield

    import app.main as main_module

    with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}), \
            patch("app.main.get_qdrant_service", return_value=mock_qdrant), \
            patch("app.singleton.get_qdrant_service", return_value=mock_qdrant), \
            patch("app.main.Curator") as MockCurator, \
            patch("app.main.scheduler") as mock_scheduler:

        mock_scheduler.add_job = MagicMock()
        mock_scheduler.start = MagicMock()
        mock_scheduler.shutdown = MagicMock()

        mock_curator_instance = MagicMock()
        mock_curator_instance.run = AsyncMock()
        MockCurator.return_value = mock_curator_instance

        from fastapi import FastAPI
        from fastapi.testclient import TestClient

        # Import fresh — use the real routes but swap lifespan
        from app.main import app as vera_app
        vera_app.router.lifespan_context = fake_lifespan

        yield vera_app, mock_qdrant


# ---------------------------------------------------------------------------
# Health check
# ---------------------------------------------------------------------------

class TestHealthCheck:
    def test_health_ollama_reachable(self, app_with_mocks):
        """GET / returns status ok and ollama=reachable when Ollama is up."""
        from fastapi.testclient import TestClient

        vera_app, mock_qdrant = app_with_mocks

        mock_resp = MagicMock()
        mock_resp.status_code = 200

        mock_client_instance = AsyncMock()
        mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
        mock_client_instance.__aexit__ = AsyncMock(return_value=False)
        mock_client_instance.get = AsyncMock(return_value=mock_resp)

        with patch("httpx.AsyncClient", return_value=mock_client_instance):
            with TestClient(vera_app, raise_server_exceptions=True) as client:
                resp = client.get("/")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ok"
        assert body["ollama"] == "reachable"

    def test_health_ollama_unreachable(self, app_with_mocks):
        """GET / returns ollama=unreachable when Ollama is down."""
        import httpx
        from fastapi.testclient import TestClient

        vera_app, _ = app_with_mocks

        mock_client_instance = AsyncMock()
        mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
        mock_client_instance.__aexit__ = AsyncMock(return_value=False)
        mock_client_instance.get = AsyncMock(side_effect=httpx.ConnectError("refused"))

        with patch("httpx.AsyncClient", return_value=mock_client_instance):
            with TestClient(vera_app, raise_server_exceptions=True) as client:
                resp = client.get("/")

        assert resp.status_code == 200
        assert resp.json()["ollama"] == "unreachable"


# ---------------------------------------------------------------------------
# /api/tags
# ---------------------------------------------------------------------------

class TestApiTags:
    def test_returns_model_list(self, app_with_mocks):
        """GET /api/tags proxies Ollama tags."""
        from fastapi.testclient import TestClient

        vera_app, _ = app_with_mocks

        mock_resp = MagicMock()
        mock_resp.json.return_value = _ollama_tags_response()

        mock_client_instance = AsyncMock()
        mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
        mock_client_instance.__aexit__ = AsyncMock(return_value=False)
        mock_client_instance.get = AsyncMock(return_value=mock_resp)

        with patch("httpx.AsyncClient", return_value=mock_client_instance):
            with TestClient(vera_app) as client:
                resp = client.get("/api/tags")

        assert resp.status_code == 200
        data = resp.json()
        assert "models" in data
        assert any(m["name"] == "llama3" for m in data["models"])

    def test_cloud_models_injected(self, tmp_path):
        """Cloud models appear in /api/tags when cloud is enabled."""
        from fastapi.testclient import TestClient
        from contextlib import asynccontextmanager

        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()
        (prompts_dir / "curator_prompt.md").write_text("Curate.")
        (prompts_dir / "systemprompt.md").write_text("")

        mock_qdrant = _make_mock_qdrant()

        @asynccontextmanager
        async def fake_lifespan(app):
            yield

        from app.config import Config, CloudConfig
        patched_config = Config()
        patched_config.cloud = CloudConfig(
            enabled=True,
            models={"gpt-oss:120b": "openai/gpt-4o"},
        )

        mock_resp = MagicMock()
        mock_resp.json.return_value = {"models": []}

        mock_client_instance = AsyncMock()
        mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
        mock_client_instance.__aexit__ = AsyncMock(return_value=False)
        mock_client_instance.get = AsyncMock(return_value=mock_resp)

        import app.main as main_module

        with patch.dict(os.environ, {"VERA_PROMPTS_DIR": str(prompts_dir)}), \
                patch("app.main.config", patched_config), \
                patch("app.main.get_qdrant_service", return_value=mock_qdrant), \
                patch("app.main.scheduler") as mock_scheduler, \
                patch("app.main.Curator") as MockCurator:

            mock_scheduler.add_job = MagicMock()
            mock_scheduler.start = MagicMock()
            mock_scheduler.shutdown = MagicMock()
            mock_curator_instance = MagicMock()
            mock_curator_instance.run = AsyncMock()
            MockCurator.return_value = mock_curator_instance

            from app.main import app as vera_app
            vera_app.router.lifespan_context = fake_lifespan

            with patch("httpx.AsyncClient", return_value=mock_client_instance):
                with TestClient(vera_app) as client:
                    resp = client.get("/api/tags")

        data = resp.json()
        names = [m["name"] for m in data["models"]]
        assert "gpt-oss:120b" in names


# ---------------------------------------------------------------------------
# POST /api/chat (non-streaming)
# ---------------------------------------------------------------------------

class TestApiChatNonStreaming:
    def test_non_streaming_round_trip(self, app_with_mocks):
        """POST /api/chat with stream=False returns Ollama response."""
        from fastapi.testclient import TestClient
        import app.utils as utils_module
        import app.proxy_handler as ph_module

        vera_app, mock_qdrant = app_with_mocks

        ollama_data = _ollama_chat_response("The answer is 42.")

        mock_post_resp = MagicMock()
        mock_post_resp.json.return_value = ollama_data

        mock_client_instance = AsyncMock()
        mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
        mock_client_instance.__aexit__ = AsyncMock(return_value=False)
        mock_client_instance.post = AsyncMock(return_value=mock_post_resp)

        with patch.object(utils_module, "load_system_prompt", return_value=""), \
                patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant), \
                patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
                patch("httpx.AsyncClient", return_value=mock_client_instance):

            with TestClient(vera_app) as client:
                resp = client.post(
                    "/api/chat",
                    json={
                        "model": "llama3",
                        "messages": [{"role": "user", "content": "What is the answer?"}],
                        "stream": False,
                    },
                )

        assert resp.status_code == 200
        body = resp.json()
        assert body["message"]["content"] == "The answer is 42."

    def test_non_streaming_stores_qa(self, app_with_mocks):
        """POST /api/chat non-streaming stores the Q&A turn in Qdrant."""
        from fastapi.testclient import TestClient
        import app.utils as utils_module

        vera_app, mock_qdrant = app_with_mocks

        ollama_data = _ollama_chat_response("42.")

        mock_post_resp = MagicMock()
        mock_post_resp.json.return_value = ollama_data

        mock_client_instance = AsyncMock()
        mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
        mock_client_instance.__aexit__ = AsyncMock(return_value=False)
        mock_client_instance.post = AsyncMock(return_value=mock_post_resp)

        with patch.object(utils_module, "load_system_prompt", return_value=""), \
                patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant), \
                patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
                patch("httpx.AsyncClient", return_value=mock_client_instance):

            with TestClient(vera_app) as client:
                client.post(
                    "/api/chat",
                    json={
                        "model": "llama3",
                        "messages": [{"role": "user", "content": "What is 6*7?"}],
                        "stream": False,
                    },
                )

        mock_qdrant.store_qa_turn.assert_called_once()
        args = mock_qdrant.store_qa_turn.call_args[0]
        assert "6*7" in args[0]
        assert "42." in args[1]


# ---------------------------------------------------------------------------
# POST /api/chat (streaming)
# ---------------------------------------------------------------------------

class TestApiChatStreaming:
    def test_streaming_response_passthrough(self, app_with_mocks):
        """POST /api/chat with stream=True streams Ollama chunks."""
        from fastapi.testclient import TestClient
        import app.utils as utils_module
        import app.proxy_handler as ph_module

        vera_app, mock_qdrant = app_with_mocks

        chunk1 = json.dumps({"message": {"content": "Hello"}, "done": False}).encode()
        chunk2 = json.dumps({"message": {"content": " world"}, "done": True}).encode()

        async def fake_aiter_bytes():
            yield chunk1
            yield chunk2

        mock_stream_resp = MagicMock()
        mock_stream_resp.aiter_bytes = fake_aiter_bytes
        mock_stream_resp.status_code = 200
        mock_stream_resp.headers = {"content-type": "application/x-ndjson"}

        mock_client_instance = AsyncMock()
        mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
        mock_client_instance.__aexit__ = AsyncMock(return_value=False)
        mock_client_instance.post = AsyncMock(return_value=mock_stream_resp)

        with patch.object(utils_module, "load_system_prompt", return_value=""), \
                patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant), \
                patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
                patch("httpx.AsyncClient", return_value=mock_client_instance):

            with TestClient(vera_app) as client:
                resp = client.post(
                    "/api/chat",
                    json={
                        "model": "llama3",
                        "messages": [{"role": "user", "content": "Say hello"}],
                        "stream": True,
                    },
                )

        assert resp.status_code == 200
        # Response body should contain the streamed chunks passed through
        body_text = resp.text
        assert "Hello" in body_text


# ---------------------------------------------------------------------------
# Health check edge cases
# ---------------------------------------------------------------------------

class TestHealthCheckEdgeCases:
    def test_health_ollama_timeout(self, app_with_mocks):
        """GET / handles Ollama timeout gracefully."""
        import httpx
        from fastapi.testclient import TestClient

        vera_app, _ = app_with_mocks

        mock_client_instance = AsyncMock()
        mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
        mock_client_instance.__aexit__ = AsyncMock(return_value=False)
        mock_client_instance.get = AsyncMock(side_effect=httpx.TimeoutException("timeout"))

        with patch("httpx.AsyncClient", return_value=mock_client_instance):
            with TestClient(vera_app, raise_server_exceptions=True) as client:
                resp = client.get("/")

        assert resp.status_code == 200
        assert resp.json()["ollama"] == "unreachable"


# ---------------------------------------------------------------------------
# POST /curator/run
# ---------------------------------------------------------------------------

class TestTriggerCurator:
    def test_trigger_curator_endpoint(self, app_with_mocks):
        """POST /curator/run triggers curation and returns status."""
        from fastapi.testclient import TestClient
        import app.main as main_module

        vera_app, _ = app_with_mocks

        mock_curator = MagicMock()
        mock_curator.run = AsyncMock()

        with patch.object(main_module, "curator", mock_curator):
            with TestClient(vera_app) as client:
                resp = client.post("/curator/run")

        assert resp.status_code == 200
        assert resp.json()["status"] == "curation completed"
        mock_curator.run.assert_called_once()


# ---------------------------------------------------------------------------
# Proxy catch-all
# ---------------------------------------------------------------------------

class TestProxyAll:
    def test_non_chat_api_proxied(self, app_with_mocks):
        """Non-chat API paths are proxied to Ollama."""
        from fastapi.testclient import TestClient

        vera_app, _ = app_with_mocks

        async def fake_aiter_bytes():
            yield b'{"status": "ok"}'

        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"content-type": "application/json"}
        mock_resp.aiter_bytes = fake_aiter_bytes

        mock_client_instance = AsyncMock()
        mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
        mock_client_instance.__aexit__ = AsyncMock(return_value=False)
        mock_client_instance.request = AsyncMock(return_value=mock_resp)

        with patch("httpx.AsyncClient", return_value=mock_client_instance):
            with TestClient(vera_app) as client:
                resp = client.get("/api/show")

        assert resp.status_code == 200

312 tests/test_proxy_handler.py Normal file
@@ -0,0 +1,312 @@
"""Tests for proxy_handler — no live Ollama or Qdrant required."""
import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch


class TestCleanMessageContent:
    """Tests for clean_message_content."""

    def test_passthrough_plain_message(self):
        """Plain text without wrapper is returned unchanged."""
        from app.proxy_handler import clean_message_content

        content = "What is the capital of France?"
        assert clean_message_content(content) == content

    def test_strips_memory_context_wrapper(self):
        """[Memory context] wrapper is stripped, actual user_msg returned."""
        from app.proxy_handler import clean_message_content

        content = (
            "[Memory context]\n"
            "some context here\n"
            "- user_msg: What is the capital of France?\n\n"
        )
        result = clean_message_content(content)
        assert result == "What is the capital of France?"

    def test_strips_timestamp_prefix(self):
        """ISO timestamp prefix like [2024-01-01T00:00:00] is removed."""
        from app.proxy_handler import clean_message_content

        content = "[2024-01-01T12:34:56] Tell me a joke"
        result = clean_message_content(content)
        assert result == "Tell me a joke"

    def test_empty_string_returned_as_is(self):
        """Empty string input returns empty string."""
        from app.proxy_handler import clean_message_content

        assert clean_message_content("") == ""

    def test_none_input_returned_as_is(self):
        """None/falsy input is returned unchanged."""
        from app.proxy_handler import clean_message_content

        assert clean_message_content(None) is None

    def test_list_content_raises_type_error(self):
        """Non-string content (list) causes TypeError — the function expects strings."""
        import pytest
        from app.proxy_handler import clean_message_content

        # The function passes lists to re.search, which requires str/bytes.
        # Document this behavior so we know it's a known limitation.
        content = [{"type": "text", "text": "hello"}]
        with pytest.raises(TypeError):
            clean_message_content(content)


class TestHandleChatNonStreaming:
    """Tests for handle_chat_non_streaming — fully mocked external I/O."""

    @pytest.mark.asyncio
    async def test_returns_json_response(self):
        """Should return a JSONResponse with the Ollama result merged with the model field."""
        from app.proxy_handler import handle_chat_non_streaming

        ollama_resp_data = {
            "message": {"role": "assistant", "content": "Paris."},
            "done": True,
        }

        mock_httpx_resp = MagicMock()
        mock_httpx_resp.json.return_value = ollama_resp_data

        mock_client = AsyncMock()
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)
        mock_client.post = AsyncMock(return_value=mock_httpx_resp)

        mock_qdrant = MagicMock()
        mock_qdrant.store_qa_turn = AsyncMock(return_value="fake-uuid")

        augmented = [{"role": "user", "content": "What is the capital of France?"}]

        with patch("app.proxy_handler.build_augmented_messages", AsyncMock(return_value=augmented)), \
                patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
                patch("httpx.AsyncClient", return_value=mock_client):

            body = {
                "model": "llama3",
                "messages": [{"role": "user", "content": "What is the capital of France?"}],
                "stream": False,
            }
            response = await handle_chat_non_streaming(body)

        # FastAPI JSONResponse
        from fastapi.responses import JSONResponse
        assert isinstance(response, JSONResponse)
        response_body = json.loads(response.body)
        assert response_body["message"]["content"] == "Paris."
        assert response_body["model"] == "llama3"

    @pytest.mark.asyncio
    async def test_stores_qa_turn_when_answer_present(self):
        """store_qa_turn should be called with user question and assistant answer."""
        from app.proxy_handler import handle_chat_non_streaming

        ollama_resp_data = {
            "message": {"role": "assistant", "content": "Berlin."},
            "done": True,
        }

        mock_httpx_resp = MagicMock()
        mock_httpx_resp.json.return_value = ollama_resp_data

        mock_client = AsyncMock()
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.post = AsyncMock(return_value=mock_httpx_resp)
|
||||
|
||||
mock_qdrant = MagicMock()
|
||||
mock_qdrant.store_qa_turn = AsyncMock(return_value="fake-uuid")
|
||||
|
||||
augmented = [{"role": "user", "content": "Capital of Germany?"}]
|
||||
|
||||
with patch("app.proxy_handler.build_augmented_messages", AsyncMock(return_value=augmented)), \
|
||||
patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
|
||||
patch("httpx.AsyncClient", return_value=mock_client):
|
||||
|
||||
body = {
|
||||
"model": "llama3",
|
||||
"messages": [{"role": "user", "content": "Capital of Germany?"}],
|
||||
"stream": False,
|
||||
}
|
||||
await handle_chat_non_streaming(body)
|
||||
|
||||
mock_qdrant.store_qa_turn.assert_called_once()
|
||||
call_args = mock_qdrant.store_qa_turn.call_args
|
||||
assert "Capital of Germany?" in call_args[0][0]
|
||||
assert "Berlin." in call_args[0][1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_store_when_empty_answer(self):
|
||||
"""store_qa_turn should NOT be called when the assistant answer is empty."""
|
||||
from app.proxy_handler import handle_chat_non_streaming
|
||||
|
||||
ollama_resp_data = {
|
||||
"message": {"role": "assistant", "content": ""},
|
||||
"done": True,
|
||||
}
|
||||
|
||||
mock_httpx_resp = MagicMock()
|
||||
mock_httpx_resp.json.return_value = ollama_resp_data
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.post = AsyncMock(return_value=mock_httpx_resp)
|
||||
|
||||
mock_qdrant = MagicMock()
|
||||
mock_qdrant.store_qa_turn = AsyncMock(return_value="fake-uuid")
|
||||
|
||||
augmented = [{"role": "user", "content": "Hello?"}]
|
||||
|
||||
with patch("app.proxy_handler.build_augmented_messages", AsyncMock(return_value=augmented)), \
|
||||
patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
|
||||
patch("httpx.AsyncClient", return_value=mock_client):
|
||||
|
||||
body = {
|
||||
"model": "llama3",
|
||||
"messages": [{"role": "user", "content": "Hello?"}],
|
||||
"stream": False,
|
||||
}
|
||||
await handle_chat_non_streaming(body)
|
||||
|
||||
mock_qdrant.store_qa_turn.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleans_memory_context_from_user_message(self):
|
||||
"""User message with [Memory context] wrapper should be cleaned before storing."""
|
||||
from app.proxy_handler import handle_chat_non_streaming
|
||||
|
||||
ollama_resp_data = {
|
||||
"message": {"role": "assistant", "content": "42."},
|
||||
"done": True,
|
||||
}
|
||||
|
||||
mock_httpx_resp = MagicMock()
|
||||
mock_httpx_resp.json.return_value = ollama_resp_data
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.post = AsyncMock(return_value=mock_httpx_resp)
|
||||
|
||||
mock_qdrant = MagicMock()
|
||||
mock_qdrant.store_qa_turn = AsyncMock(return_value="fake-uuid")
|
||||
|
||||
raw_content = (
|
||||
"[Memory context]\nsome ctx\n- user_msg: What is the answer?\n\n"
|
||||
)
|
||||
augmented = [{"role": "user", "content": "What is the answer?"}]
|
||||
|
||||
with patch("app.proxy_handler.build_augmented_messages", AsyncMock(return_value=augmented)), \
|
||||
patch("app.proxy_handler.get_qdrant_service", return_value=mock_qdrant), \
|
||||
patch("httpx.AsyncClient", return_value=mock_client):
|
||||
|
||||
body = {
|
||||
"model": "llama3",
|
||||
"messages": [{"role": "user", "content": raw_content}],
|
||||
"stream": False,
|
||||
}
|
||||
await handle_chat_non_streaming(body)
|
||||
|
||||
call_args = mock_qdrant.store_qa_turn.call_args
|
||||
stored_question = call_args[0][0]
|
||||
# The wrapper should be stripped
|
||||
assert "Memory context" not in stored_question
|
||||
assert "What is the answer?" in stored_question
|
||||
|
||||
|
||||
class TestDebugLog:
|
||||
"""Tests for debug_log function."""
|
||||
|
||||
def test_debug_log_writes_json_when_enabled(self, tmp_path):
|
||||
"""Debug log appends valid JSON line to file when debug=True."""
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.debug = True
|
||||
|
||||
with patch("app.proxy_handler.config", mock_config), \
|
||||
patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path):
|
||||
from app.proxy_handler import debug_log
|
||||
debug_log("test_cat", "test message", {"key": "value"})
|
||||
|
||||
log_files = list(tmp_path.glob("debug_*.log"))
|
||||
assert len(log_files) == 1
|
||||
content = log_files[0].read_text().strip()
|
||||
entry = json.loads(content)
|
||||
assert entry["category"] == "test_cat"
|
||||
assert entry["message"] == "test message"
|
||||
assert entry["data"]["key"] == "value"
|
||||
|
||||
def test_debug_log_skips_when_disabled(self, tmp_path):
|
||||
"""Debug log does nothing when debug=False."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.debug = False
|
||||
|
||||
with patch("app.proxy_handler.config", mock_config), \
|
||||
patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path):
|
||||
from app.proxy_handler import debug_log
|
||||
debug_log("test_cat", "test message")
|
||||
|
||||
log_files = list(tmp_path.glob("debug_*.log"))
|
||||
assert len(log_files) == 0
|
||||
|
||||
def test_debug_log_without_data(self, tmp_path):
|
||||
"""Debug log works without optional data parameter."""
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.debug = True
|
||||
|
||||
with patch("app.proxy_handler.config", mock_config), \
|
||||
patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path):
|
||||
from app.proxy_handler import debug_log
|
||||
debug_log("simple_cat", "no data here")
|
||||
|
||||
log_files = list(tmp_path.glob("debug_*.log"))
|
||||
assert len(log_files) == 1
|
||||
entry = json.loads(log_files[0].read_text().strip())
|
||||
assert "data" not in entry
|
||||
assert entry["category"] == "simple_cat"
|
||||
|
||||
|
||||
class TestForwardToOllama:
|
||||
"""Tests for forward_to_ollama function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forwards_request_to_ollama(self):
|
||||
"""forward_to_ollama proxies request to Ollama host."""
|
||||
from app.proxy_handler import forward_to_ollama
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.body = AsyncMock(return_value=b'{"model": "llama3"}')
|
||||
mock_request.method = "POST"
|
||||
mock_request.headers = {"content-type": "application/json", "content-length": "20"}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.request = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
result = await forward_to_ollama(mock_request, "/api/show")
|
||||
|
||||
assert result == mock_resp
|
||||
mock_client.request.assert_called_once()
|
||||
call_kwargs = mock_client.request.call_args
|
||||
assert call_kwargs[1]["method"] == "POST"
|
||||
assert "/api/show" in call_kwargs[1]["url"]
|
||||
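Taken together, the `clean_message_content` tests above pin down a small contract: unwrap `[Memory context]` blocks down to the `- user_msg:` payload, strip a leading ISO timestamp, pass falsy input through, and raise `TypeError` on non-strings. A minimal sketch that satisfies them (this is not the project's actual implementation; the regexes and the `- user_msg:` marker are inferred from the test strings):

```python
import re


def clean_message_content(content):
    """Illustrative sketch of the contract the tests above describe."""
    if not content:
        # Falsy input ("" or None) is returned unchanged.
        return content
    # re.search on a non-string (e.g. a list) raises TypeError, as tested.
    m = re.search(r"-\s*user_msg:\s*(.+)", content)
    if content.lstrip().startswith("[Memory context]") and m:
        # Keep only the original user message from the wrapper.
        return m.group(1).strip()
    # Strip a leading ISO timestamp like [2024-01-01T12:34:56].
    return re.sub(r"^\[\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\]\s*", "", content)
```

This sketch deliberately mirrors the tested edge cases rather than aiming for generality.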
256
tests/test_qdrant_service.py
Normal file
@@ -0,0 +1,256 @@
"""Tests for QdrantService — all Qdrant and Ollama calls are mocked."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch


def make_qdrant_service():
    """Create a QdrantService with mocked AsyncQdrantClient."""
    with patch("app.qdrant_service.AsyncQdrantClient") as MockClient:
        mock_client = AsyncMock()
        MockClient.return_value = mock_client

        from app.qdrant_service import QdrantService
        svc = QdrantService(
            host="http://localhost:6333",
            collection="test_memories",
            embedding_model="snowflake-arctic-embed2",
            vector_size=1024,
            ollama_host="http://localhost:11434",
        )

    return svc, mock_client


class TestEnsureCollection:
    """Tests for _ensure_collection."""

    @pytest.mark.asyncio
    async def test_creates_collection_when_missing(self):
        """Creates collection if it does not exist."""
        svc, mock_client = make_qdrant_service()
        mock_client.get_collection = AsyncMock(side_effect=Exception("not found"))
        mock_client.create_collection = AsyncMock()

        await svc._ensure_collection()

        mock_client.create_collection.assert_called_once()
        assert svc._collection_ensured is True

    @pytest.mark.asyncio
    async def test_skips_if_collection_exists(self):
        """Does not create collection if it already exists."""
        svc, mock_client = make_qdrant_service()
        mock_client.get_collection = AsyncMock(return_value=MagicMock())
        mock_client.create_collection = AsyncMock()

        await svc._ensure_collection()

        mock_client.create_collection.assert_not_called()
        assert svc._collection_ensured is True

    @pytest.mark.asyncio
    async def test_skips_if_already_ensured(self):
        """Skips entirely if _collection_ensured is True."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True
        mock_client.get_collection = AsyncMock()

        await svc._ensure_collection()

        mock_client.get_collection.assert_not_called()


class TestGetEmbedding:
    """Tests for get_embedding."""

    @pytest.mark.asyncio
    async def test_returns_embedding_vector(self):
        """Returns embedding from Ollama response."""
        svc, _ = make_qdrant_service()
        fake_embedding = [0.1] * 1024

        mock_resp = MagicMock()
        mock_resp.json.return_value = {"embedding": fake_embedding}

        mock_client = AsyncMock()
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)
        mock_client.post = AsyncMock(return_value=mock_resp)

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await svc.get_embedding("test text")

        assert result == fake_embedding
        assert len(result) == 1024


class TestStoreTurn:
    """Tests for store_turn."""

    @pytest.mark.asyncio
    async def test_stores_raw_user_turn(self):
        """Stores a user turn with proper payload."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True
        mock_client.upsert = AsyncMock()

        fake_embedding = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
            point_id = await svc.store_turn(role="user", content="hello world")

        assert isinstance(point_id, str)
        mock_client.upsert.assert_called_once()
        call_args = mock_client.upsert.call_args
        point = call_args[1]["points"][0]
        assert point.payload["type"] == "raw"
        assert point.payload["role"] == "user"
        assert "User: hello world" in point.payload["text"]

    @pytest.mark.asyncio
    async def test_stores_curated_turn(self):
        """Stores a curated turn without role prefix in text."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True
        mock_client.upsert = AsyncMock()

        fake_embedding = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
            point_id = await svc.store_turn(
                role="curated",
                content="User: q\nAssistant: a",
                entry_type="curated"
            )

        call_args = mock_client.upsert.call_args
        point = call_args[1]["points"][0]
        assert point.payload["type"] == "curated"
        # Curated text should be the content directly, not prefixed
        assert point.payload["text"] == "User: q\nAssistant: a"

    @pytest.mark.asyncio
    async def test_stores_with_topic_and_metadata(self):
        """Stores turn with optional topic and metadata."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True
        mock_client.upsert = AsyncMock()

        fake_embedding = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
            await svc.store_turn(
                role="assistant",
                content="some response",
                topic="python",
                metadata={"source": "test"}
            )

        call_args = mock_client.upsert.call_args
        point = call_args[1]["points"][0]
        assert point.payload["topic"] == "python"
        assert point.payload["source"] == "test"


class TestStoreQaTurn:
    """Tests for store_qa_turn."""

    @pytest.mark.asyncio
    async def test_stores_qa_turn(self):
        """Stores a complete Q&A turn."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True
        mock_client.upsert = AsyncMock()

        fake_embedding = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
            point_id = await svc.store_qa_turn("What is Python?", "A programming language.")

        assert isinstance(point_id, str)
        call_args = mock_client.upsert.call_args
        point = call_args[1]["points"][0]
        assert point.payload["type"] == "raw"
        assert point.payload["role"] == "qa"
        assert "What is Python?" in point.payload["text"]
        assert "A programming language." in point.payload["text"]


class TestSemanticSearch:
    """Tests for semantic_search."""

    @pytest.mark.asyncio
    async def test_returns_matching_results(self):
        """Returns formatted search results."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True

        mock_point = MagicMock()
        mock_point.id = "result-1"
        mock_point.score = 0.85
        mock_point.payload = {"text": "User: hello\nAssistant: hi", "type": "curated"}

        mock_query_result = MagicMock()
        mock_query_result.points = [mock_point]
        mock_client.query_points = AsyncMock(return_value=mock_query_result)

        fake_embedding = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=fake_embedding)):
            results = await svc.semantic_search("hello", limit=10, score_threshold=0.6)

        assert len(results) == 1
        assert results[0]["id"] == "result-1"
        assert results[0]["score"] == 0.85
        assert results[0]["payload"]["type"] == "curated"


class TestGetRecentTurns:
    """Tests for get_recent_turns."""

    @pytest.mark.asyncio
    async def test_returns_sorted_recent_turns(self):
        """Returns turns sorted by timestamp descending."""
        svc, mock_client = make_qdrant_service()
        svc._collection_ensured = True

        mock_point1 = MagicMock()
        mock_point1.id = "old"
        mock_point1.payload = {"timestamp": "2026-01-01T00:00:00Z", "text": "old turn"}

        mock_point2 = MagicMock()
        mock_point2.id = "new"
        mock_point2.payload = {"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"}

        # OrderBy returns server-sorted results (newest first)
        mock_client.scroll = AsyncMock(return_value=([mock_point2, mock_point1], None))

        results = await svc.get_recent_turns(limit=2)

        assert len(results) == 2
        # Newest first (server-sorted via OrderBy)
        assert results[0]["id"] == "new"
        assert results[1]["id"] == "old"


class TestDeletePoints:
    """Tests for delete_points."""

    @pytest.mark.asyncio
    async def test_deletes_by_ids(self):
        """Deletes points by their IDs."""
        svc, mock_client = make_qdrant_service()
        mock_client.delete = AsyncMock()

        await svc.delete_points(["id1", "id2"])

        mock_client.delete.assert_called_once()


class TestClose:
    """Tests for close."""

    @pytest.mark.asyncio
    async def test_closes_client(self):
        """Closes the async Qdrant client."""
        svc, mock_client = make_qdrant_service()
        mock_client.close = AsyncMock()

        await svc.close()

        mock_client.close.assert_called_once()
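The `TestEnsureCollection` cases above exercise a lazy ensure-once pattern: check the flag, probe the server at most once per process, create the collection only when the probe fails. A self-contained sketch of that pattern, with a hypothetical `LazyCollection` class standing in for `QdrantService`:

```python
class LazyCollection:
    """Minimal sketch of the ensure-once pattern the tests above exercise."""

    def __init__(self, client, name="memories"):
        self.client = client
        self.name = name
        self._collection_ensured = False

    async def _ensure_collection(self):
        if self._collection_ensured:
            return  # already verified in this process; skip the round-trip
        try:
            await self.client.get_collection(self.name)
        except Exception:
            # get_collection raised, so the collection is missing: create it.
            await self.client.create_collection(self.name)
        self._collection_ensured = True
```

The flag makes the check idempotent and cheap on every call after the first; the trade-off is that a collection dropped externally mid-process will not be recreated until restart.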
437
tests/test_utils.py
Normal file
@@ -0,0 +1,437 @@
"""Tests for utility functions."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.utils import count_tokens, truncate_by_tokens, parse_curated_turn, build_augmented_messages, count_messages_tokens


class TestCountTokens:
    """Tests for count_tokens function."""

    def test_empty_string(self):
        """Empty string should return 0 tokens."""
        assert count_tokens("") == 0

    def test_simple_text(self):
        """Simple text should count tokens correctly."""
        text = "Hello, world!"
        assert count_tokens(text) > 0

    def test_longer_text(self):
        """Longer text should have more tokens."""
        short = "Hello"
        long = "Hello, this is a longer sentence with more words."
        assert count_tokens(long) > count_tokens(short)


class TestTruncateByTokens:
    """Tests for truncate_by_tokens function."""

    def test_no_truncation_needed(self):
        """Text shorter than limit should not be truncated."""
        text = "Short text"
        result = truncate_by_tokens(text, max_tokens=100)
        assert result == text

    def test_truncation_applied(self):
        """Text longer than limit should be truncated."""
        text = "This is a longer piece of text that will need to be truncated"
        result = truncate_by_tokens(text, max_tokens=5)
        assert count_tokens(result) <= 5

    def test_empty_string(self):
        """Empty string should return empty string."""
        assert truncate_by_tokens("", max_tokens=10) == ""


class TestParseCuratedTurn:
    """Tests for parse_curated_turn function."""

    def test_empty_string(self):
        """Empty string should return empty list."""
        assert parse_curated_turn("") == []

    def test_single_turn(self):
        """Single Q&A turn should parse correctly."""
        text = "User: What is Python?\nAssistant: A programming language."
        result = parse_curated_turn(text)
        assert len(result) == 2
        assert result[0]["role"] == "user"
        assert result[0]["content"] == "What is Python?"
        assert result[1]["role"] == "assistant"
        assert result[1]["content"] == "A programming language."

    def test_multiple_turns(self):
        """Multiple Q&A turns should parse correctly."""
        text = """User: What is Python?
Assistant: A programming language.
User: Is it popular?
Assistant: Yes, very popular."""
        result = parse_curated_turn(text)
        assert len(result) == 4

    def test_timestamp_ignored(self):
        """Timestamp lines should be ignored."""
        text = "User: Question?\nAssistant: Answer.\nTimestamp: 2024-01-01T00:00:00Z"
        result = parse_curated_turn(text)
        assert len(result) == 2
        for msg in result:
            assert "Timestamp" not in msg["content"]

    def test_multiline_content(self):
        """Multiline content should be preserved."""
        text = "User: Line 1\nLine 2\nLine 3\nAssistant: Response"
        result = parse_curated_turn(text)
        assert "Line 1" in result[0]["content"]
        assert "Line 2" in result[0]["content"]
        assert "Line 3" in result[0]["content"]


class TestCountMessagesTokens:
    """Tests for count_messages_tokens function."""

    def test_empty_list(self):
        """Empty message list returns 0."""
        assert count_messages_tokens([]) == 0

    def test_single_message(self):
        """Single message counts tokens of its content."""
        msgs = [{"role": "user", "content": "Hello world"}]
        result = count_messages_tokens(msgs)
        assert result > 0

    def test_multiple_messages(self):
        """Multiple messages sum up their token counts."""
        msgs = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there, how can I help you today?"},
        ]
        result = count_messages_tokens(msgs)
        assert result > count_messages_tokens([msgs[0]])

    def test_message_without_content(self):
        """Message without content field contributes 0 tokens."""
        msgs = [{"role": "system"}]
        assert count_messages_tokens(msgs) == 0


class TestLoadSystemPrompt:
    """Tests for load_system_prompt function."""

    def test_loads_from_prompts_dir(self, tmp_path):
        """Loads systemprompt.md from PROMPTS_DIR."""
        import app.utils as utils_module

        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()
        (prompts_dir / "systemprompt.md").write_text("You are Vera.")

        with patch.object(utils_module, "PROMPTS_DIR", prompts_dir):
            result = utils_module.load_system_prompt()

        assert result == "You are Vera."

    def test_falls_back_to_static_dir(self, tmp_path):
        """Falls back to STATIC_DIR when PROMPTS_DIR has no file."""
        import app.utils as utils_module

        prompts_dir = tmp_path / "no_prompts"  # does not exist
        static_dir = tmp_path / "static"
        static_dir.mkdir()
        (static_dir / "systemprompt.md").write_text("Static Vera.")

        with patch.object(utils_module, "PROMPTS_DIR", prompts_dir), \
             patch.object(utils_module, "STATIC_DIR", static_dir):
            result = utils_module.load_system_prompt()

        assert result == "Static Vera."

    def test_returns_empty_when_not_found(self, tmp_path):
        """Returns empty string when systemprompt.md not found anywhere."""
        import app.utils as utils_module

        with patch.object(utils_module, "PROMPTS_DIR", tmp_path / "nope"), \
             patch.object(utils_module, "STATIC_DIR", tmp_path / "also_nope"):
            result = utils_module.load_system_prompt()

        assert result == ""


class TestFilterMemoriesByTime:
    """Tests for filter_memories_by_time function."""

    def test_includes_recent_memory(self):
        """Memory with timestamp in the last 24h should be included."""
        from datetime import datetime, timedelta, timezone
        from app.utils import filter_memories_by_time

        ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)).isoformat()
        memories = [{"timestamp": ts, "text": "recent"}]
        result = filter_memories_by_time(memories, hours=24)
        assert len(result) == 1

    def test_excludes_old_memory(self):
        """Memory older than cutoff should be excluded."""
        from datetime import datetime, timedelta, timezone
        from app.utils import filter_memories_by_time

        ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=48)).isoformat()
        memories = [{"timestamp": ts, "text": "old"}]
        result = filter_memories_by_time(memories, hours=24)
        assert len(result) == 0

    def test_includes_memory_without_timestamp(self):
        """Memory with no timestamp should always be included."""
        from app.utils import filter_memories_by_time

        memories = [{"text": "no ts"}]
        result = filter_memories_by_time(memories, hours=24)
        assert len(result) == 1

    def test_includes_memory_with_bad_timestamp(self):
        """Memory with unparseable timestamp should be included (safe default)."""
        from app.utils import filter_memories_by_time

        memories = [{"timestamp": "not-a-date", "text": "bad ts"}]
        result = filter_memories_by_time(memories, hours=24)
        assert len(result) == 1

    def test_z_suffix_old_timestamp_excluded(self):
        """Regression: chained .replace() was not properly handling Z suffix on old timestamps."""
        from datetime import datetime, timedelta, timezone
        from app.utils import filter_memories_by_time

        old_ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=48)).isoformat() + "Z"
        memories = [{"timestamp": old_ts, "text": "old with Z"}]
        result = filter_memories_by_time(memories, hours=24)
        assert len(result) == 0, f"Old Z-suffixed timestamp should be excluded but wasn't: {old_ts}"

    def test_empty_list(self):
        """Empty input returns empty list."""
        from app.utils import filter_memories_by_time

        assert filter_memories_by_time([], hours=24) == []

    def test_z_suffix_timestamp(self):
        """ISO timestamp with Z suffix should be handled correctly."""
        from datetime import datetime, timedelta, timezone
        from app.utils import filter_memories_by_time

        ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)).isoformat() + "Z"
        memories = [{"timestamp": ts, "text": "recent with Z"}]
        result = filter_memories_by_time(memories, hours=24)
        assert len(result) == 1
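The Z-suffix regression above comes down to timestamp normalization before comparison: `datetime.fromisoformat` rejects a trailing `Z` on older Pythons, so it must be rewritten as an explicit UTC offset. A sketch of a predicate that passes all of these cases, assuming, as the tests do, that missing or unparseable timestamps count as recent (`is_recent` is an illustrative name, not the project's helper):

```python
from datetime import datetime, timedelta, timezone


def is_recent(ts: str, hours: int = 24) -> bool:
    """True when an ISO timestamp (optionally Z-suffixed) falls within the window.

    Missing or unparseable timestamps are treated as recent (safe default),
    mirroring the tests above. Illustrative sketch only.
    """
    if not ts:
        return True
    try:
        # fromisoformat() rejects a trailing "Z" before Python 3.11; normalize it.
        dt = datetime.fromisoformat(ts.replace("Z", "+00:00"))
    except ValueError:
        return True
    if dt.tzinfo is not None:
        # Compare everything as naive UTC.
        dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
    cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=hours)
    return dt >= cutoff
```

Normalizing to a single naive-UTC representation avoids the aware-vs-naive `TypeError` that mixed timestamp formats otherwise trigger on comparison.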

class TestMergeMemories:
    """Tests for merge_memories function."""

    def test_empty_list(self):
        """Empty list returns empty text and ids."""
        from app.utils import merge_memories

        result = merge_memories([])
        assert result == {"text": "", "ids": []}

    def test_single_memory_with_text(self):
        """Single memory with text field is merged."""
        from app.utils import merge_memories

        memories = [{"id": "abc", "text": "hello world", "role": ""}]
        result = merge_memories(memories)
        assert "hello world" in result["text"]
        assert "abc" in result["ids"]

    def test_memory_with_content_field(self):
        """Memory using content field (no text) is merged."""
        from app.utils import merge_memories

        memories = [{"id": "xyz", "content": "from content field"}]
        result = merge_memories(memories)
        assert "from content field" in result["text"]

    def test_role_included_in_output(self):
        """Role prefix should appear in merged text when role is set."""
        from app.utils import merge_memories

        memories = [{"id": "1", "text": "question", "role": "user"}]
        result = merge_memories(memories)
        assert "[user]:" in result["text"]

    def test_multiple_memories_joined(self):
        """Multiple memories are joined with double newline."""
        from app.utils import merge_memories

        memories = [
            {"id": "1", "text": "first"},
            {"id": "2", "text": "second"},
        ]
        result = merge_memories(memories)
        assert "first" in result["text"]
        assert "second" in result["text"]
        assert len(result["ids"]) == 2
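The `merge_memories` behavior tested above (text/content field fallback, an optional `[role]:` prefix, double-newline joining, collected ids) can be sketched as follows; this is an illustration of the tested contract, not the project's source:

```python
def merge_memories(memories):
    """Join memory entries into one context blob (sketch of the tested contract)."""
    parts, ids = [], []
    for mem in memories:
        # Prefer "text", fall back to "content", skip entries with neither.
        text = mem.get("text") or mem.get("content") or ""
        if not text:
            continue
        role = mem.get("role") or ""
        parts.append(f"[{role}]: {text}" if role else text)
        if mem.get("id"):
            ids.append(mem["id"])
    return {"text": "\n\n".join(parts), "ids": ids}
```

Returning the ids alongside the merged text lets a caller later curate or delete exactly the points that were injected into the prompt.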
class TestBuildAugmentedMessages:
    """Tests for build_augmented_messages function (mocked I/O)."""

    def _make_qdrant_mock(self):
        """Return a mocked QdrantService with async search/recent-turn methods."""
        from unittest.mock import AsyncMock, MagicMock

        mock_qdrant = MagicMock()
        mock_qdrant.semantic_search = AsyncMock(return_value=[])
        mock_qdrant.get_recent_turns = AsyncMock(return_value=[])
        return mock_qdrant

    def test_system_layer_prepended(self, monkeypatch, tmp_path):
        """System prompt from file should be prepended to messages."""
        import asyncio
        from unittest.mock import patch
        import app.utils as utils_module

        # Write a temp system prompt
        prompt_file = tmp_path / "systemprompt.md"
        prompt_file.write_text("You are Vera.")

        mock_qdrant = self._make_qdrant_mock()

        with patch.object(utils_module, "load_system_prompt", return_value="You are Vera."), \
             patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
            result = asyncio.get_event_loop().run_until_complete(
                utils_module.build_augmented_messages(
                    [{"role": "user", "content": "Hello"}]
                )
            )

        system_msgs = [m for m in result if m["role"] == "system"]
        assert len(system_msgs) == 1
        assert "You are Vera." in system_msgs[0]["content"]

    def test_incoming_user_message_preserved(self, monkeypatch):
        """Incoming user message should appear in output."""
        import asyncio
        from unittest.mock import patch
        import app.utils as utils_module

        mock_qdrant = self._make_qdrant_mock()

        with patch.object(utils_module, "load_system_prompt", return_value=""), \
             patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
            result = asyncio.get_event_loop().run_until_complete(
                utils_module.build_augmented_messages(
                    [{"role": "user", "content": "What is 2+2?"}]
                )
            )

        user_msgs = [m for m in result if m.get("role") == "user"]
        assert any("2+2" in m["content"] for m in user_msgs)

    def test_no_system_message_when_no_prompt(self, monkeypatch):
        """No system message added when both incoming and file prompt are empty."""
        import asyncio
        from unittest.mock import patch
        import app.utils as utils_module

        mock_qdrant = self._make_qdrant_mock()

        with patch.object(utils_module, "load_system_prompt", return_value=""), \
             patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
            result = asyncio.get_event_loop().run_until_complete(
                utils_module.build_augmented_messages(
                    [{"role": "user", "content": "Hi"}]
                )
            )

        system_msgs = [m for m in result if m.get("role") == "system"]
        assert len(system_msgs) == 0

    def test_semantic_results_injected(self, monkeypatch):
        """Curated memories from semantic search should appear in output."""
        import asyncio
        from unittest.mock import patch, AsyncMock, MagicMock
        import app.utils as utils_module

        mock_qdrant = MagicMock()
        mock_qdrant.semantic_search = AsyncMock(return_value=[
            {"payload": {"text": "User: Old question?\nAssistant: Old answer."}}
        ])
        mock_qdrant.get_recent_turns = AsyncMock(return_value=[])

        with patch.object(utils_module, "load_system_prompt", return_value=""), \
             patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
            result = asyncio.get_event_loop().run_until_complete(
                utils_module.build_augmented_messages(
                    [{"role": "user", "content": "Tell me"}]
                )
            )

        contents = [m["content"] for m in result]
        assert any("Old question" in c or "Old answer" in c for c in contents)

    @pytest.mark.asyncio
    async def test_system_prompt_appends_to_caller_system(self):
        """systemprompt.md content appends to caller's system message."""
        import app.utils as utils_module

        mock_qdrant = self._make_qdrant_mock()
|
||||
|
||||
with patch.object(utils_module, "load_system_prompt", return_value="Vera memory context"), \
|
||||
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
|
||||
incoming = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
result = await build_augmented_messages(incoming)
|
||||
|
||||
system_msg = result[0]
|
||||
assert system_msg["role"] == "system"
|
||||
assert system_msg["content"] == "You are a helpful assistant.\n\nVera memory context"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_system_prompt_passthrough(self):
|
||||
"""When systemprompt.md is empty, only caller's system message passes through."""
|
||||
import app.utils as utils_module
|
||||
|
||||
mock_qdrant = self._make_qdrant_mock()
|
||||
|
||||
with patch.object(utils_module, "load_system_prompt", return_value=""), \
|
||||
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
|
||||
incoming = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
result = await build_augmented_messages(incoming)
|
||||
|
||||
system_msg = result[0]
|
||||
assert system_msg["content"] == "You are a helpful assistant."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_caller_system_with_vera_prompt(self):
|
||||
"""When caller sends no system message but systemprompt.md exists, use vera prompt."""
|
||||
import app.utils as utils_module
|
||||
|
||||
mock_qdrant = self._make_qdrant_mock()
|
||||
|
||||
with patch.object(utils_module, "load_system_prompt", return_value="Vera memory context"), \
|
||||
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
|
||||
incoming = [{"role": "user", "content": "Hello"}]
|
||||
result = await build_augmented_messages(incoming)
|
||||
|
||||
system_msg = result[0]
|
||||
assert system_msg["role"] == "system"
|
||||
assert system_msg["content"] == "Vera memory context"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_system_anywhere(self):
|
||||
"""When neither caller nor systemprompt.md provides system content, no system message."""
|
||||
import app.utils as utils_module
|
||||
|
||||
mock_qdrant = self._make_qdrant_mock()
|
||||
|
||||
with patch.object(utils_module, "load_system_prompt", return_value=""), \
|
||||
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
|
||||
incoming = [{"role": "user", "content": "Hello"}]
|
||||
result = await build_augmented_messages(incoming)
|
||||
|
||||
# First message should be user, not system
|
||||
assert result[0]["role"] == "user"
|
||||