Compare commits
9 Commits
9774875173
...
v2.0.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
346f2c26fe | ||
|
|
de7f3a78ab | ||
|
|
6154e7974e | ||
|
|
cbe12f0ebd | ||
|
|
9fa5d08ce0 | ||
|
|
90dd87edeb | ||
|
|
2801a63b11 | ||
|
|
355986a59f | ||
|
|
600f9deec1 |
@@ -1,5 +1,5 @@
|
|||||||
# app/config.py
|
# app/config.py
|
||||||
import toml
|
import tomllib
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@@ -83,8 +83,8 @@ class Config:
|
|||||||
config = cls()
|
config = cls()
|
||||||
|
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
with open(config_path, "r") as f:
|
with open(config_path, "rb") as f:
|
||||||
data = toml.load(f)
|
data = tomllib.load(f)
|
||||||
|
|
||||||
if "general" in data:
|
if "general" in data:
|
||||||
config.ollama_host = data["general"].get("ollama_host", config.ollama_host)
|
config.ollama_host = data["general"].get("ollama_host", config.ollama_host)
|
||||||
@@ -112,7 +112,14 @@ class Config:
|
|||||||
api_key_env=cloud_data.get("api_key_env", "OPENROUTER_API_KEY"),
|
api_key_env=cloud_data.get("api_key_env", "OPENROUTER_API_KEY"),
|
||||||
models=cloud_data.get("models", {})
|
models=cloud_data.get("models", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.cloud.enabled and not config.cloud.api_key:
|
||||||
|
import logging
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
"Cloud is enabled but API key env var '%s' is not set",
|
||||||
|
config.cloud.api_key_env
|
||||||
|
)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
config = Config.load()
|
config = Config.load()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ The prompt determines behavior based on current date.
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import httpx
|
import httpx
|
||||||
@@ -49,7 +49,7 @@ class Curator:
|
|||||||
Otherwise runs daily mode (processes recent 24h only).
|
Otherwise runs daily mode (processes recent 24h only).
|
||||||
The prompt determines behavior based on current date.
|
The prompt determines behavior based on current date.
|
||||||
"""
|
"""
|
||||||
current_date = datetime.utcnow()
|
current_date = datetime.now(timezone.utc)
|
||||||
is_monthly = current_date.day == 1
|
is_monthly = current_date.day == 1
|
||||||
mode = "MONTHLY" if is_monthly else "DAILY"
|
mode = "MONTHLY" if is_monthly else "DAILY"
|
||||||
|
|
||||||
@@ -169,7 +169,7 @@ Remember: Respond with ONLY valid JSON. No markdown, no explanations, just the J
|
|||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
mem_time = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
mem_time = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
||||||
cutoff = datetime.utcnow() - timedelta(hours=hours)
|
cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=hours)
|
||||||
return mem_time.replace(tzinfo=None) > cutoff
|
return mem_time.replace(tzinfo=None) > cutoff
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
logger.debug(f"Could not parse timestamp: {timestamp}")
|
logger.debug(f"Could not parse timestamp: {timestamp}")
|
||||||
@@ -212,7 +212,7 @@ Remember: Respond with ONLY valid JSON. No markdown, no explanations, just the J
|
|||||||
result = response.json()
|
result = response.json()
|
||||||
return result.get("response", "")
|
return result.get("response", "")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to call LLM: {e}")
|
logger.error(f"LLM call failed: {e}", exc_info=True)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _parse_json_response(self, response: str) -> Optional[Dict]:
|
def _parse_json_response(self, response: str) -> Optional[Dict]:
|
||||||
@@ -223,6 +223,7 @@ Remember: Respond with ONLY valid JSON. No markdown, no explanations, just the J
|
|||||||
try:
|
try:
|
||||||
return json.loads(response)
|
return json.loads(response)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
logger.debug("Direct JSON parse failed, trying brace extraction")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import httpx
|
import httpx
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from .config import config
|
from .config import config
|
||||||
from .singleton import get_qdrant_service
|
from .singleton import get_qdrant_service
|
||||||
@@ -68,7 +68,7 @@ async def lifespan(app: FastAPI):
|
|||||||
await qdrant_service.close()
|
await qdrant_service.close()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Vera-AI", version="2.0.0", lifespan=lifespan)
|
app = FastAPI(title="Vera-AI", version="2.0.4", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
@@ -96,7 +96,7 @@ async def api_tags():
|
|||||||
for name in config.cloud.models.keys():
|
for name in config.cloud.models.keys():
|
||||||
data["models"].append({
|
data["models"].append({
|
||||||
"name": name,
|
"name": name,
|
||||||
"modified_at": "2026-03-25T00:00:00Z",
|
"modified_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
|
||||||
"size": 0,
|
"size": 0,
|
||||||
"digest": "cloud",
|
"digest": "cloud",
|
||||||
"details": {"family": "cloud"}
|
"details": {"family": "cloud"}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import json
|
|||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import portalocker
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .config import config
|
from .config import config
|
||||||
from .singleton import get_qdrant_service
|
from .singleton import get_qdrant_service
|
||||||
@@ -48,17 +49,17 @@ def debug_log(category: str, message: str, data: dict = None):
|
|||||||
if not config.debug:
|
if not config.debug:
|
||||||
return
|
return
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
# Create logs directory
|
# Create logs directory
|
||||||
log_dir = DEBUG_LOG_DIR
|
log_dir = DEBUG_LOG_DIR
|
||||||
log_dir.mkdir(parents=True, exist_ok=True)
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
today = datetime.utcnow().strftime("%Y-%m-%d")
|
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
log_path = log_dir / f"debug_{today}.log"
|
log_path = log_dir / f"debug_{today}.log"
|
||||||
|
|
||||||
entry = {
|
entry = {
|
||||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
"timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
|
||||||
"category": category,
|
"category": category,
|
||||||
"message": message
|
"message": message
|
||||||
}
|
}
|
||||||
@@ -66,7 +67,9 @@ def debug_log(category: str, message: str, data: dict = None):
|
|||||||
entry["data"] = data
|
entry["data"] = data
|
||||||
|
|
||||||
with open(log_path, "a") as f:
|
with open(log_path, "a") as f:
|
||||||
|
portalocker.lock(f, portalocker.LOCK_EX)
|
||||||
f.write(json.dumps(entry) + "\n")
|
f.write(json.dumps(entry) + "\n")
|
||||||
|
portalocker.unlock(f)
|
||||||
|
|
||||||
|
|
||||||
async def handle_chat_non_streaming(body: dict):
|
async def handle_chat_non_streaming(body: dict):
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Qdrant service for memory storage - ASYNC VERSION."""
|
"""Qdrant service for memory storage - ASYNC VERSION."""
|
||||||
from qdrant_client import AsyncQdrantClient
|
from qdrant_client import AsyncQdrantClient
|
||||||
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
|
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue, PayloadSchemaType
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
import httpx
|
import httpx
|
||||||
@@ -34,6 +34,15 @@ class QdrantService:
|
|||||||
vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
|
vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
|
||||||
)
|
)
|
||||||
logger.info(f"Created collection {self.collection} with vector size {self.vector_size}")
|
logger.info(f"Created collection {self.collection} with vector size {self.vector_size}")
|
||||||
|
# Ensure payload index on timestamp for ordered scroll
|
||||||
|
try:
|
||||||
|
await self.client.create_payload_index(
|
||||||
|
collection_name=self.collection,
|
||||||
|
field_name="timestamp",
|
||||||
|
field_schema=PayloadSchemaType.KEYWORD
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Index may already exist
|
||||||
self._collection_ensured = True
|
self._collection_ensured = True
|
||||||
|
|
||||||
async def get_embedding(self, text: str) -> List[float]:
|
async def get_embedding(self, text: str) -> List[float]:
|
||||||
@@ -54,7 +63,7 @@ class QdrantService:
|
|||||||
point_id = str(uuid.uuid4())
|
point_id = str(uuid.uuid4())
|
||||||
embedding = await self.get_embedding(content)
|
embedding = await self.get_embedding(content)
|
||||||
|
|
||||||
timestamp = datetime.utcnow().isoformat() + "Z"
|
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||||
text = content
|
text = content
|
||||||
if role == "user":
|
if role == "user":
|
||||||
text = f"User: {content}"
|
text = f"User: {content}"
|
||||||
@@ -85,7 +94,7 @@ class QdrantService:
|
|||||||
"""Store a complete Q&A turn as one document."""
|
"""Store a complete Q&A turn as one document."""
|
||||||
await self._ensure_collection()
|
await self._ensure_collection()
|
||||||
|
|
||||||
timestamp = datetime.utcnow().isoformat() + "Z"
|
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||||
text = f"User: {user_question}\nAssistant: {assistant_answer}\nTimestamp: {timestamp}"
|
text = f"User: {user_question}\nAssistant: {assistant_answer}\nTimestamp: {timestamp}"
|
||||||
|
|
||||||
point_id = str(uuid.uuid4())
|
point_id = str(uuid.uuid4())
|
||||||
@@ -105,20 +114,28 @@ class QdrantService:
|
|||||||
)
|
)
|
||||||
return point_id
|
return point_id
|
||||||
|
|
||||||
async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated") -> List[Dict]:
|
async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated", entry_types: Optional[List[str]] = None) -> List[Dict]:
|
||||||
"""Semantic search for relevant turns, filtered by type."""
|
"""Semantic search for relevant turns, filtered by type(s)."""
|
||||||
await self._ensure_collection()
|
await self._ensure_collection()
|
||||||
|
|
||||||
embedding = await self.get_embedding(query)
|
embedding = await self.get_embedding(query)
|
||||||
|
|
||||||
|
if entry_types and len(entry_types) > 1:
|
||||||
|
type_filter = Filter(
|
||||||
|
should=[FieldCondition(key="type", match=MatchValue(value=t)) for t in entry_types]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
filter_type = entry_types[0] if entry_types else entry_type
|
||||||
|
type_filter = Filter(
|
||||||
|
must=[FieldCondition(key="type", match=MatchValue(value=filter_type))]
|
||||||
|
)
|
||||||
|
|
||||||
results = await self.client.query_points(
|
results = await self.client.query_points(
|
||||||
collection_name=self.collection,
|
collection_name=self.collection,
|
||||||
query=embedding,
|
query=embedding,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
query_filter=Filter(
|
query_filter=type_filter
|
||||||
must=[FieldCondition(key="type", match=MatchValue(value=entry_type))]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return [{"id": str(r.id), "score": r.score, "payload": r.payload} for r in results.points]
|
return [{"id": str(r.id), "score": r.score, "payload": r.payload} for r in results.points]
|
||||||
@@ -126,21 +143,29 @@ class QdrantService:
|
|||||||
async def get_recent_turns(self, limit: int = 20) -> List[Dict]:
|
async def get_recent_turns(self, limit: int = 20) -> List[Dict]:
|
||||||
"""Get recent turns from Qdrant (both raw and curated)."""
|
"""Get recent turns from Qdrant (both raw and curated)."""
|
||||||
await self._ensure_collection()
|
await self._ensure_collection()
|
||||||
|
|
||||||
points, _ = await self.client.scroll(
|
try:
|
||||||
collection_name=self.collection,
|
from qdrant_client.models import OrderBy
|
||||||
limit=limit * 2,
|
points, _ = await self.client.scroll(
|
||||||
with_payload=True
|
collection_name=self.collection,
|
||||||
)
|
limit=limit,
|
||||||
|
with_payload=True,
|
||||||
# Sort by timestamp descending
|
order_by=OrderBy(key="timestamp", direction="desc")
|
||||||
sorted_points = sorted(
|
)
|
||||||
points,
|
except Exception:
|
||||||
key=lambda p: p.payload.get("timestamp", ""),
|
# Fallback: fetch extra points and sort client-side
|
||||||
reverse=True
|
points, _ = await self.client.scroll(
|
||||||
)
|
collection_name=self.collection,
|
||||||
|
limit=limit * 5,
|
||||||
return [{"id": str(p.id), "payload": p.payload} for p in sorted_points[:limit]]
|
with_payload=True
|
||||||
|
)
|
||||||
|
points = sorted(
|
||||||
|
points,
|
||||||
|
key=lambda p: p.payload.get("timestamp", ""),
|
||||||
|
reverse=True
|
||||||
|
)[:limit]
|
||||||
|
|
||||||
|
return [{"id": str(p.id), "payload": p.payload} for p in points]
|
||||||
|
|
||||||
async def delete_points(self, point_ids: List[str]) -> None:
|
async def delete_points(self, point_ids: List[str]) -> None:
|
||||||
"""Delete points by ID."""
|
"""Delete points by ID."""
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Global singleton instances for Vera-AI."""
|
"""Global singleton instances for Vera-AI."""
|
||||||
|
from typing import Optional
|
||||||
from .qdrant_service import QdrantService
|
from .qdrant_service import QdrantService
|
||||||
from .config import config
|
from .config import config
|
||||||
|
|
||||||
_qdrant_service: QdrantService = None
|
_qdrant_service: Optional[QdrantService] = None
|
||||||
|
|
||||||
|
|
||||||
def get_qdrant_service() -> QdrantService:
|
def get_qdrant_service() -> QdrantService:
|
||||||
|
|||||||
67
app/utils.py
67
app/utils.py
@@ -1,9 +1,10 @@
|
|||||||
"""Utility functions for vera-ai."""
|
"""Utility functions for vera-ai."""
|
||||||
from .config import config
|
from .config import config
|
||||||
|
from .singleton import get_qdrant_service
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import os
|
import os
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Use cl100k_base encoding (GPT-4 compatible)
|
# Use cl100k_base encoding (GPT-4 compatible)
|
||||||
@@ -13,24 +14,6 @@ ENCODING = tiktoken.get_encoding("cl100k_base")
|
|||||||
PROMPTS_DIR = Path(os.environ.get("VERA_PROMPTS_DIR", "/app/prompts"))
|
PROMPTS_DIR = Path(os.environ.get("VERA_PROMPTS_DIR", "/app/prompts"))
|
||||||
STATIC_DIR = Path(os.environ.get("VERA_STATIC_DIR", "/app/static"))
|
STATIC_DIR = Path(os.environ.get("VERA_STATIC_DIR", "/app/static"))
|
||||||
|
|
||||||
# Global qdrant_service instance for utils
|
|
||||||
_qdrant_service = None
|
|
||||||
|
|
||||||
def get_qdrant_service():
|
|
||||||
"""Get or create the QdrantService singleton."""
|
|
||||||
global _qdrant_service
|
|
||||||
if _qdrant_service is None:
|
|
||||||
from .config import config
|
|
||||||
from .qdrant_service import QdrantService
|
|
||||||
_qdrant_service = QdrantService(
|
|
||||||
host=config.qdrant_host,
|
|
||||||
collection=config.qdrant_collection,
|
|
||||||
embedding_model=config.embedding_model,
|
|
||||||
vector_size=config.vector_size,
|
|
||||||
ollama_host=config.ollama_host
|
|
||||||
)
|
|
||||||
return _qdrant_service
|
|
||||||
|
|
||||||
def count_tokens(text: str) -> int:
|
def count_tokens(text: str) -> int:
|
||||||
"""Count tokens in text."""
|
"""Count tokens in text."""
|
||||||
if not text:
|
if not text:
|
||||||
@@ -56,7 +39,7 @@ def truncate_by_tokens(text: str, max_tokens: int) -> str:
|
|||||||
|
|
||||||
def filter_memories_by_time(memories: List[Dict], hours: int = 24) -> List[Dict]:
|
def filter_memories_by_time(memories: List[Dict], hours: int = 24) -> List[Dict]:
|
||||||
"""Filter memories from the last N hours."""
|
"""Filter memories from the last N hours."""
|
||||||
cutoff = datetime.utcnow() - timedelta(hours=hours)
|
cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=hours)
|
||||||
filtered = []
|
filtered = []
|
||||||
for mem in memories:
|
for mem in memories:
|
||||||
ts = mem.get("timestamp")
|
ts = mem.get("timestamp")
|
||||||
@@ -64,7 +47,7 @@ def filter_memories_by_time(memories: List[Dict], hours: int = 24) -> List[Dict]
|
|||||||
try:
|
try:
|
||||||
# Parse ISO timestamp
|
# Parse ISO timestamp
|
||||||
if isinstance(ts, str):
|
if isinstance(ts, str):
|
||||||
mem_time = datetime.fromisoformat(ts.replace("Z", "+00:00").replace("+00:00", ""))
|
mem_time = datetime.fromisoformat(ts.replace("Z", "")).replace(tzinfo=None)
|
||||||
else:
|
else:
|
||||||
mem_time = ts
|
mem_time = ts
|
||||||
if mem_time > cutoff:
|
if mem_time > cutoff:
|
||||||
@@ -100,15 +83,6 @@ def merge_memories(memories: List[Dict]) -> Dict:
|
|||||||
"ids": ids
|
"ids": ids
|
||||||
}
|
}
|
||||||
|
|
||||||
def calculate_token_budget(total_budget: int, system_ratio: float = 0.2,
|
|
||||||
semantic_ratio: float = 0.5, context_ratio: float = 0.3) -> Dict[int, int]:
|
|
||||||
"""Calculate token budgets for each layer."""
|
|
||||||
return {
|
|
||||||
"system": int(total_budget * system_ratio),
|
|
||||||
"semantic": int(total_budget * semantic_ratio),
|
|
||||||
"context": int(total_budget * context_ratio)
|
|
||||||
}
|
|
||||||
|
|
||||||
def load_system_prompt() -> str:
|
def load_system_prompt() -> str:
|
||||||
"""Load system prompt from prompts directory."""
|
"""Load system prompt from prompts directory."""
|
||||||
import logging
|
import logging
|
||||||
@@ -219,36 +193,45 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# === LAYER 1: System Prompt ===
|
# === LAYER 1: System Prompt ===
|
||||||
system_content = ""
|
# Caller's system message passes through; systemprompt.md appends if non-empty.
|
||||||
|
caller_system = ""
|
||||||
for msg in incoming_messages:
|
for msg in incoming_messages:
|
||||||
if msg.get("role") == "system":
|
if msg.get("role") == "system":
|
||||||
system_content = msg.get("content", "")
|
caller_system = msg.get("content", "")
|
||||||
break
|
break
|
||||||
|
|
||||||
if system_prompt:
|
if caller_system and system_prompt:
|
||||||
system_content += "\n\n" + system_prompt
|
system_content = caller_system + "\n\n" + system_prompt
|
||||||
|
elif caller_system:
|
||||||
|
system_content = caller_system
|
||||||
|
elif system_prompt:
|
||||||
|
system_content = system_prompt
|
||||||
|
else:
|
||||||
|
system_content = ""
|
||||||
|
|
||||||
if system_content:
|
if system_content:
|
||||||
messages.append({"role": "system", "content": system_content})
|
messages.append({"role": "system", "content": system_content})
|
||||||
logger.info(f"Layer 1 (system): {count_tokens(system_content)} tokens")
|
logger.info(f"Layer 1 (system): {count_tokens(system_content)} tokens")
|
||||||
|
|
||||||
# === LAYER 2: Semantic (curated memories) ===
|
# === LAYER 2: Semantic (curated + raw memories) ===
|
||||||
qdrant = get_qdrant_service()
|
qdrant = get_qdrant_service()
|
||||||
semantic_results = await qdrant.semantic_search(
|
semantic_results = await qdrant.semantic_search(
|
||||||
query=search_context if search_context else user_question,
|
query=search_context if search_context else user_question,
|
||||||
limit=20,
|
limit=20,
|
||||||
score_threshold=config.semantic_score_threshold,
|
score_threshold=config.semantic_score_threshold,
|
||||||
entry_type="curated"
|
entry_types=["curated", "raw"]
|
||||||
)
|
)
|
||||||
|
|
||||||
semantic_messages = []
|
semantic_messages = []
|
||||||
semantic_tokens_used = 0
|
semantic_tokens_used = 0
|
||||||
|
semantic_ids = set()
|
||||||
|
|
||||||
for result in semantic_results:
|
for result in semantic_results:
|
||||||
|
semantic_ids.add(result.get("id"))
|
||||||
payload = result.get("payload", {})
|
payload = result.get("payload", {})
|
||||||
text = payload.get("text", "")
|
text = payload.get("text", "")
|
||||||
if text:
|
if text:
|
||||||
# Parse curated turn into proper user/assistant messages
|
# Parse curated/raw turn into proper user/assistant messages
|
||||||
parsed = parse_curated_turn(text)
|
parsed = parse_curated_turn(text)
|
||||||
for msg in parsed:
|
for msg in parsed:
|
||||||
msg_tokens = count_tokens(msg.get("content", ""))
|
msg_tokens = count_tokens(msg.get("content", ""))
|
||||||
@@ -273,8 +256,10 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
|
|||||||
context_messages = []
|
context_messages = []
|
||||||
context_tokens_used = 0
|
context_tokens_used = 0
|
||||||
|
|
||||||
# Process oldest first for chronological order
|
# Process oldest first for chronological order, skip duplicates from Layer 2
|
||||||
for turn in reversed(recent_turns):
|
for turn in reversed(recent_turns):
|
||||||
|
if turn.get("id") in semantic_ids:
|
||||||
|
continue
|
||||||
payload = turn.get("payload", {})
|
payload = turn.get("payload", {})
|
||||||
text = payload.get("text", "")
|
text = payload.get("text", "")
|
||||||
entry_type = payload.get("type", "raw")
|
entry_type = payload.get("type", "raw")
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
fastapi>=0.104.0
|
fastapi==0.135.2
|
||||||
uvicorn[standard]>=0.24.0
|
uvicorn[standard]==0.42.0
|
||||||
httpx>=0.25.0
|
httpx==0.28.1
|
||||||
qdrant-client>=1.6.0
|
qdrant-client==1.17.1
|
||||||
ollama>=0.1.0
|
ollama==0.6.1
|
||||||
toml>=0.10.2
|
tiktoken==0.12.0
|
||||||
tiktoken>=0.5.0
|
apscheduler==3.11.2
|
||||||
apscheduler>=3.10.0
|
portalocker==3.2.0
|
||||||
pytest>=7.0.0
|
pytest==9.0.2
|
||||||
pytest-asyncio>=0.21.0
|
pytest-asyncio==1.3.0
|
||||||
pytest-cov>=4.0.0
|
pytest-cov==7.1.0
|
||||||
|
|||||||
62
tests/conftest.py
Normal file
62
tests/conftest.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Shared test fixtures using production-realistic data."""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from app.config import Config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def production_config():
|
||||||
|
"""Config matching production deployment on deb8."""
|
||||||
|
config = MagicMock(spec=Config)
|
||||||
|
config.ollama_host = "http://10.0.0.10:11434"
|
||||||
|
config.qdrant_host = "http://10.0.0.22:6333"
|
||||||
|
config.qdrant_collection = "memories"
|
||||||
|
config.embedding_model = "snowflake-arctic-embed2"
|
||||||
|
config.semantic_token_budget = 25000
|
||||||
|
config.context_token_budget = 22000
|
||||||
|
config.semantic_search_turns = 2
|
||||||
|
config.semantic_score_threshold = 0.6
|
||||||
|
config.run_time = "02:00"
|
||||||
|
config.curator_model = "gpt-oss:120b"
|
||||||
|
config.debug = False
|
||||||
|
config.vector_size = 1024
|
||||||
|
config.cloud = MagicMock()
|
||||||
|
config.cloud.enabled = False
|
||||||
|
config.cloud.models = {}
|
||||||
|
config.cloud.get_cloud_model.return_value = None
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_qdrant_raw_payload():
|
||||||
|
"""Sample raw payload from production Qdrant."""
|
||||||
|
return {
|
||||||
|
"type": "raw",
|
||||||
|
"text": "User: only change settings, not models\nAssistant: Changed semantic_token_budget from 25000 to 30000\nTimestamp: 2026-03-27T12:50:37.451593Z",
|
||||||
|
"timestamp": "2026-03-27T12:50:37.451593Z",
|
||||||
|
"role": "qa",
|
||||||
|
"content": "User: only change settings, not models\nAssistant: Changed semantic_token_budget from 25000 to 30000\nTimestamp: 2026-03-27T12:50:37.451593Z"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_ollama_models():
|
||||||
|
"""Model list from production Ollama."""
|
||||||
|
return {
|
||||||
|
"models": [
|
||||||
|
{
|
||||||
|
"name": "snowflake-arctic-embed2:latest",
|
||||||
|
"model": "snowflake-arctic-embed2:latest",
|
||||||
|
"modified_at": "2026-02-16T16:43:44Z",
|
||||||
|
"size": 1160296718,
|
||||||
|
"details": {"family": "bert", "parameter_size": "566.70M", "quantization_level": "F16"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "gpt-oss:120b",
|
||||||
|
"model": "gpt-oss:120b",
|
||||||
|
"modified_at": "2026-03-11T12:45:48Z",
|
||||||
|
"size": 65369818941,
|
||||||
|
"details": {"family": "gptoss", "parameter_size": "116.8B", "quantization_level": "MXFP4"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -2,9 +2,9 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
def make_curator():
|
def make_curator():
|
||||||
@@ -77,14 +77,14 @@ class TestIsRecent:
|
|||||||
def test_memory_within_window(self):
|
def test_memory_within_window(self):
|
||||||
"""Memory timestamped 1 hour ago is recent (within 24h)."""
|
"""Memory timestamped 1 hour ago is recent (within 24h)."""
|
||||||
curator, _ = make_curator()
|
curator, _ = make_curator()
|
||||||
ts = (datetime.utcnow() - timedelta(hours=1)).isoformat() + "Z"
|
ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)).isoformat() + "Z"
|
||||||
memory = {"timestamp": ts}
|
memory = {"timestamp": ts}
|
||||||
assert curator._is_recent(memory, hours=24) is True
|
assert curator._is_recent(memory, hours=24) is True
|
||||||
|
|
||||||
def test_memory_outside_window(self):
|
def test_memory_outside_window(self):
|
||||||
"""Memory timestamped 48 hours ago is not recent."""
|
"""Memory timestamped 48 hours ago is not recent."""
|
||||||
curator, _ = make_curator()
|
curator, _ = make_curator()
|
||||||
ts = (datetime.utcnow() - timedelta(hours=48)).isoformat() + "Z"
|
ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=48)).isoformat() + "Z"
|
||||||
memory = {"timestamp": ts}
|
memory = {"timestamp": ts}
|
||||||
assert curator._is_recent(memory, hours=24) is False
|
assert curator._is_recent(memory, hours=24) is False
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ class TestIsRecent:
|
|||||||
def test_boundary_edge_just_inside(self):
|
def test_boundary_edge_just_inside(self):
|
||||||
"""Memory at exactly hours-1 minutes ago should be recent."""
|
"""Memory at exactly hours-1 minutes ago should be recent."""
|
||||||
curator, _ = make_curator()
|
curator, _ = make_curator()
|
||||||
ts = (datetime.utcnow() - timedelta(hours=23, minutes=59)).isoformat() + "Z"
|
ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=23, minutes=59)).isoformat() + "Z"
|
||||||
memory = {"timestamp": ts}
|
memory = {"timestamp": ts}
|
||||||
assert curator._is_recent(memory, hours=24) is True
|
assert curator._is_recent(memory, hours=24) is True
|
||||||
|
|
||||||
@@ -197,3 +197,294 @@ class TestAppendRuleToFile:
|
|||||||
target = prompts_dir / "newfile.md"
|
target = prompts_dir / "newfile.md"
|
||||||
assert target.exists()
|
assert target.exists()
|
||||||
assert "New rule here." in target.read_text()
|
assert "New rule here." in target.read_text()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatExistingMemories:
|
||||||
|
"""Tests for Curator._format_existing_memories."""
|
||||||
|
|
||||||
|
def test_empty_list_returns_no_memories_message(self):
|
||||||
|
"""Empty list returns a 'no memories' message."""
|
||||||
|
curator, _ = make_curator()
|
||||||
|
result = curator._format_existing_memories([])
|
||||||
|
assert "No existing curated memories" in result
|
||||||
|
|
||||||
|
def test_single_memory_formatted(self):
|
||||||
|
"""Single memory text is included in output."""
|
||||||
|
curator, _ = make_curator()
|
||||||
|
memories = [{"text": "User: hello\nAssistant: hi there"}]
|
||||||
|
result = curator._format_existing_memories(memories)
|
||||||
|
assert "hello" in result
|
||||||
|
assert "hi there" in result
|
||||||
|
|
||||||
|
def test_limits_to_last_20(self):
|
||||||
|
"""Only last 20 memories are included."""
|
||||||
|
curator, _ = make_curator()
|
||||||
|
memories = [{"text": f"memory {i}"} for i in range(30)]
|
||||||
|
result = curator._format_existing_memories(memories)
|
||||||
|
# Should contain memory 10-29 (last 20), not memory 0-9
|
||||||
|
assert "memory 29" in result
|
||||||
|
assert "memory 10" in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallLlm:
|
||||||
|
"""Tests for Curator._call_llm."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_llm_returns_response(self):
|
||||||
|
"""_call_llm returns the response text from Ollama."""
|
||||||
|
curator, _ = make_curator()
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = {"response": "some LLM output"}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
result = await curator._call_llm("test prompt")
|
||||||
|
|
||||||
|
assert result == "some LLM output"
|
||||||
|
call_kwargs = mock_client.post.call_args
|
||||||
|
assert "test-model" in call_kwargs[1]["json"]["model"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_llm_returns_empty_on_error(self):
|
||||||
|
"""_call_llm returns empty string when Ollama errors."""
|
||||||
|
curator, _ = make_curator()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client.post = AsyncMock(side_effect=Exception("connection refused"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
result = await curator._call_llm("test prompt")
|
||||||
|
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestCuratorRun:
    """Tests for Curator.run() method."""

    @pytest.mark.asyncio
    async def test_run_no_raw_memories_exits_early(self):
        """run() exits early when no raw memories found."""
        curator, mock_qdrant = make_curator()

        # Mock scroll to return no points
        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([], None))
        mock_qdrant.collection = "memories"

        # Spy on the LLM so the early exit is asserted explicitly
        # (the original test only checked that run() did not raise).
        with patch.object(curator, "_call_llm", AsyncMock()) as mock_llm:
            await curator.run()

        # Should not call LLM since there are no raw memories
        mock_llm.assert_not_called()

    @pytest.mark.asyncio
    async def test_run_processes_raw_memories(self):
        """run() processes raw memories and stores curated results."""
        curator, mock_qdrant = make_curator()

        # Create mock points
        mock_point = MagicMock()
        mock_point.id = "point-1"
        mock_point.payload = {
            "type": "raw",
            "text": "User: hello\nAssistant: hi",
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }

        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        mock_qdrant.store_turn = AsyncMock(return_value="new-id")
        mock_qdrant.delete_points = AsyncMock()

        llm_response = json.dumps({
            "new_curated_turns": [{"content": "User: hello\nAssistant: hi"}],
            "permanent_rules": [],
            "deletions": [],
            "summary": "Curated one turn"
        })

        with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)):
            await curator.run()

        mock_qdrant.store_turn.assert_called_once()
        mock_qdrant.delete_points.assert_called()

    @pytest.mark.asyncio
    async def test_run_monthly_mode_on_day_01(self):
        """run() uses monthly mode on day 01, processing all raw memories."""
        curator, mock_qdrant = make_curator()

        # Create a mock point with an old timestamp (outside 24h window)
        old_ts = (datetime.now(timezone.utc) - timedelta(hours=72)).isoformat().replace("+00:00", "Z")
        mock_point = MagicMock()
        mock_point.id = "old-point"
        mock_point.payload = {
            "type": "raw",
            "text": "User: old question\nAssistant: old answer",
            "timestamp": old_ts,
        }

        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        mock_qdrant.store_turn = AsyncMock(return_value="new-id")
        mock_qdrant.delete_points = AsyncMock()

        llm_response = json.dumps({
            "new_curated_turns": [],
            "permanent_rules": [],
            "deletions": [],
            "summary": "Nothing to curate"
        })

        # Mock day 01
        mock_now = datetime(2026, 4, 1, 2, 0, 0, tzinfo=timezone.utc)
        with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)), \
             patch("app.curator.datetime") as mock_dt:
            mock_dt.now.return_value = mock_now
            # Keep real parsing/construction behavior on the patched class.
            mock_dt.fromisoformat = datetime.fromisoformat
            mock_dt.side_effect = lambda *args, **kw: datetime(*args, **kw)
            await curator.run()

        # In monthly mode, even old memories are processed, so LLM should be called
        # and delete_points should be called for the raw memory
        mock_qdrant.delete_points.assert_called()

    @pytest.mark.asyncio
    async def test_run_handles_permanent_rules(self):
        """run() appends permanent rules to prompt files."""
        curator, mock_qdrant = make_curator()

        mock_point = MagicMock()
        mock_point.id = "point-1"
        mock_point.payload = {
            "type": "raw",
            "text": "User: remember this\nAssistant: ok",
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }

        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        mock_qdrant.store_turn = AsyncMock(return_value="new-id")
        mock_qdrant.delete_points = AsyncMock()

        llm_response = json.dumps({
            "new_curated_turns": [],
            "permanent_rules": [{"rule": "Always be concise.", "target_file": "systemprompt.md"}],
            "deletions": [],
            "summary": "Added a rule"
        })

        with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)), \
             patch.object(curator, "_append_rule_to_file", AsyncMock()) as mock_append:
            await curator.run()

        mock_append.assert_called_once_with("systemprompt.md", "Always be concise.")

    @pytest.mark.asyncio
    async def test_run_handles_deletions(self):
        """run() deletes specified point IDs when they exist in the database."""
        curator, mock_qdrant = make_curator()

        mock_point = MagicMock()
        mock_point.id = "point-1"
        mock_point.payload = {
            "type": "raw",
            "text": "User: delete me\nAssistant: ok",
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }

        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        mock_qdrant.store_turn = AsyncMock(return_value="new-id")
        mock_qdrant.delete_points = AsyncMock()

        llm_response = json.dumps({
            "new_curated_turns": [],
            "permanent_rules": [],
            "deletions": ["point-1"],
            "summary": "Deleted one"
        })

        with patch.object(curator, "_call_llm", AsyncMock(return_value=llm_response)):
            await curator.run()

        # delete_points is invoked for the requested deletions and/or the
        # processed raw batch; "point-1" may be deduplicated across the two,
        # so only require at least one call (the old comment claiming "twice"
        # contradicted this assertion).
        assert mock_qdrant.delete_points.call_count >= 1

    @pytest.mark.asyncio
    async def test_run_handles_llm_parse_failure(self):
        """run() handles LLM returning unparseable response gracefully."""
        curator, mock_qdrant = make_curator()

        mock_point = MagicMock()
        mock_point.id = "point-1"
        mock_point.payload = {
            "type": "raw",
            "text": "User: test\nAssistant: ok",
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }

        mock_qdrant.client = AsyncMock()
        mock_qdrant.client.scroll = AsyncMock(return_value=([mock_point], None))
        mock_qdrant.collection = "memories"
        # BUG FIX: install the spy BEFORE run(). The original test replaced
        # store_turn with a fresh AsyncMock AFTER run() returned, which made
        # the assert_not_called() below pass vacuously.
        mock_qdrant.store_turn = AsyncMock()

        with patch.object(curator, "_call_llm", AsyncMock(return_value="not json at all!!!")):
            # Should not raise - just return early
            await curator.run()

        # store_turn should NOT be called since parsing failed
        mock_qdrant.store_turn.assert_not_called()
|
||||||
|
class TestLoadCuratorPrompt:
    """Tests for load_curator_prompt function."""

    def test_loads_from_prompts_dir(self, tmp_path):
        """The prompt is read from PROMPTS_DIR when present there."""
        import app.curator as curator_module

        prompt_home = tmp_path / "prompts"
        prompt_home.mkdir()
        (prompt_home / "curator_prompt.md").write_text("Test curator prompt")

        with patch.object(curator_module, "PROMPTS_DIR", prompt_home):
            from app.curator import load_curator_prompt
            loaded = load_curator_prompt()

        assert loaded == "Test curator prompt"

    def test_falls_back_to_static_dir(self, tmp_path):
        """STATIC_DIR is used when PROMPTS_DIR has no prompt file."""
        import app.curator as curator_module

        prompt_home = tmp_path / "prompts"  # does not exist
        fallback_home = tmp_path / "static"
        fallback_home.mkdir()
        (fallback_home / "curator_prompt.md").write_text("Static prompt")

        with patch.object(curator_module, "PROMPTS_DIR", prompt_home), \
             patch.object(curator_module, "STATIC_DIR", fallback_home):
            from app.curator import load_curator_prompt
            loaded = load_curator_prompt()

        assert loaded == "Static prompt"

    def test_raises_when_not_found(self, tmp_path):
        """FileNotFoundError is raised when no prompt file exists anywhere."""
        import app.curator as curator_module

        with patch.object(curator_module, "PROMPTS_DIR", tmp_path / "nope"), \
             patch.object(curator_module, "STATIC_DIR", tmp_path / "also_nope"):
            from app.curator import load_curator_prompt
            with pytest.raises(FileNotFoundError):
                load_curator_prompt()
|
||||||
|
|||||||
@@ -349,3 +349,83 @@ class TestApiChatStreaming:
|
|||||||
# Response body should contain both chunks concatenated
|
# Response body should contain both chunks concatenated
|
||||||
body_text = resp.text
|
body_text = resp.text
|
||||||
assert "Hello" in body_text or len(body_text) > 0
|
assert "Hello" in body_text or len(body_text) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Health check edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestHealthCheckEdgeCases:
    def test_health_ollama_timeout(self, app_with_mocks):
        """GET / reports Ollama as unreachable when the probe times out."""
        import httpx
        from fastapi.testclient import TestClient

        vera_app, _ = app_with_mocks

        stub_client = AsyncMock()
        stub_client.__aenter__ = AsyncMock(return_value=stub_client)
        stub_client.__aexit__ = AsyncMock(return_value=False)
        stub_client.get = AsyncMock(side_effect=httpx.TimeoutException("timeout"))

        with patch("httpx.AsyncClient", return_value=stub_client), \
             TestClient(vera_app, raise_server_exceptions=True) as client:
            resp = client.get("/")

        assert resp.status_code == 200
        assert resp.json()["ollama"] == "unreachable"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /curator/run
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTriggerCurator:
    def test_trigger_curator_endpoint(self, app_with_mocks):
        """POST /curator/run invokes the curator once and reports completion."""
        from fastapi.testclient import TestClient
        import app.main as main_module

        vera_app, _ = app_with_mocks

        curator_stub = MagicMock()
        curator_stub.run = AsyncMock()

        with patch.object(main_module, "curator", curator_stub), \
             TestClient(vera_app) as client:
            resp = client.post("/curator/run")

        assert resp.status_code == 200
        assert resp.json()["status"] == "curation completed"
        curator_stub.run.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Proxy catch-all
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestProxyAll:
    def test_non_chat_api_proxied(self, app_with_mocks):
        """Requests outside the chat API are forwarded to Ollama unchanged."""
        from fastapi.testclient import TestClient

        vera_app, _ = app_with_mocks

        async def body_chunks():
            yield b'{"status": "ok"}'

        upstream_resp = MagicMock()
        upstream_resp.status_code = 200
        upstream_resp.headers = {"content-type": "application/json"}
        upstream_resp.aiter_bytes = body_chunks

        stub_client = AsyncMock()
        stub_client.__aenter__ = AsyncMock(return_value=stub_client)
        stub_client.__aexit__ = AsyncMock(return_value=False)
        stub_client.request = AsyncMock(return_value=upstream_resp)

        with patch("httpx.AsyncClient", return_value=stub_client), \
             TestClient(vera_app) as client:
            resp = client.get("/api/show")

        assert resp.status_code == 200
|
||||||
|
|||||||
@@ -219,3 +219,94 @@ class TestHandleChatNonStreaming:
|
|||||||
# The wrapper should be stripped
|
# The wrapper should be stripped
|
||||||
assert "Memory context" not in stored_question
|
assert "Memory context" not in stored_question
|
||||||
assert "What is the answer?" in stored_question
|
assert "What is the answer?" in stored_question
|
||||||
|
|
||||||
|
|
||||||
|
class TestDebugLog:
    """Tests for debug_log function."""

    def test_debug_log_writes_json_when_enabled(self, tmp_path):
        """With debug=True, one valid JSON line is appended to the log file."""
        import json
        from unittest.mock import patch, MagicMock

        cfg_stub = MagicMock()
        cfg_stub.debug = True

        with patch("app.proxy_handler.config", cfg_stub), \
             patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path):
            from app.proxy_handler import debug_log
            debug_log("test_cat", "test message", {"key": "value"})

        written = list(tmp_path.glob("debug_*.log"))
        assert len(written) == 1
        record = json.loads(written[0].read_text().strip())
        assert record["category"] == "test_cat"
        assert record["message"] == "test message"
        assert record["data"]["key"] == "value"

    def test_debug_log_skips_when_disabled(self, tmp_path):
        """With debug=False, no log file is produced."""
        from unittest.mock import patch, MagicMock

        cfg_stub = MagicMock()
        cfg_stub.debug = False

        with patch("app.proxy_handler.config", cfg_stub), \
             patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path):
            from app.proxy_handler import debug_log
            debug_log("test_cat", "test message")

        written = list(tmp_path.glob("debug_*.log"))
        assert len(written) == 0

    def test_debug_log_without_data(self, tmp_path):
        """The optional data parameter may be omitted; no 'data' key is logged."""
        import json
        from unittest.mock import patch, MagicMock

        cfg_stub = MagicMock()
        cfg_stub.debug = True

        with patch("app.proxy_handler.config", cfg_stub), \
             patch("app.proxy_handler.DEBUG_LOG_DIR", tmp_path):
            from app.proxy_handler import debug_log
            debug_log("simple_cat", "no data here")

        written = list(tmp_path.glob("debug_*.log"))
        assert len(written) == 1
        record = json.loads(written[0].read_text().strip())
        assert "data" not in record
        assert record["category"] == "simple_cat"
|
||||||
|
|
||||||
|
|
||||||
|
class TestForwardToOllama:
    """Tests for forward_to_ollama function."""

    @pytest.mark.asyncio
    async def test_forwards_request_to_ollama(self):
        """The incoming method, body, and path are relayed to the Ollama host."""
        from app.proxy_handler import forward_to_ollama
        from unittest.mock import patch, AsyncMock, MagicMock

        incoming = AsyncMock()
        incoming.body = AsyncMock(return_value=b'{"model": "llama3"}')
        incoming.method = "POST"
        incoming.headers = {"content-type": "application/json", "content-length": "20"}

        upstream_resp = MagicMock()
        upstream_resp.status_code = 200

        stub_client = AsyncMock()
        stub_client.__aenter__ = AsyncMock(return_value=stub_client)
        stub_client.__aexit__ = AsyncMock(return_value=False)
        stub_client.request = AsyncMock(return_value=upstream_resp)

        with patch("httpx.AsyncClient", return_value=stub_client):
            outcome = await forward_to_ollama(incoming, "/api/show")

        assert outcome == upstream_resp
        stub_client.request.assert_called_once()
        sent = stub_client.request.call_args
        assert sent[1]["method"] == "POST"
        assert "/api/show" in sent[1]["url"]
|
||||||
|
|||||||
256
tests/test_qdrant_service.py
Normal file
256
tests/test_qdrant_service.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
"""Tests for QdrantService — all Qdrant and Ollama calls are mocked."""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
def make_qdrant_service():
    """Build a QdrantService whose AsyncQdrantClient is replaced by a mock.

    Returns a (service, mock_client) pair for use in the tests below.
    """
    with patch("app.qdrant_service.AsyncQdrantClient") as client_cls:
        stub = AsyncMock()
        client_cls.return_value = stub

        from app.qdrant_service import QdrantService
        service = QdrantService(
            host="http://localhost:6333",
            collection="test_memories",
            embedding_model="snowflake-arctic-embed2",
            vector_size=1024,
            ollama_host="http://localhost:11434",
        )

        return service, stub
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnsureCollection:
    """Tests for _ensure_collection."""

    @pytest.mark.asyncio
    async def test_creates_collection_when_missing(self):
        """A missing collection is created and the ensured flag is set."""
        svc, stub = make_qdrant_service()
        stub.get_collection = AsyncMock(side_effect=Exception("not found"))
        stub.create_collection = AsyncMock()

        await svc._ensure_collection()

        stub.create_collection.assert_called_once()
        assert svc._collection_ensured is True

    @pytest.mark.asyncio
    async def test_skips_if_collection_exists(self):
        """An existing collection is left alone; the ensured flag is still set."""
        svc, stub = make_qdrant_service()
        stub.get_collection = AsyncMock(return_value=MagicMock())
        stub.create_collection = AsyncMock()

        await svc._ensure_collection()

        stub.create_collection.assert_not_called()
        assert svc._collection_ensured is True

    @pytest.mark.asyncio
    async def test_skips_if_already_ensured(self):
        """When the flag is already set, no Qdrant call is made at all."""
        svc, stub = make_qdrant_service()
        svc._collection_ensured = True
        stub.get_collection = AsyncMock()

        await svc._ensure_collection()

        stub.get_collection.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetEmbedding:
    """Tests for get_embedding."""

    @pytest.mark.asyncio
    async def test_returns_embedding_vector(self):
        """The vector from Ollama's JSON response is returned unchanged."""
        svc, _ = make_qdrant_service()
        expected_vector = [0.1] * 1024

        ollama_resp = MagicMock()
        ollama_resp.json.return_value = {"embedding": expected_vector}

        http_stub = AsyncMock()
        http_stub.__aenter__ = AsyncMock(return_value=http_stub)
        http_stub.__aexit__ = AsyncMock(return_value=False)
        http_stub.post = AsyncMock(return_value=ollama_resp)

        with patch("httpx.AsyncClient", return_value=http_stub):
            vector = await svc.get_embedding("test text")

        assert vector == expected_vector
        assert len(vector) == 1024
|
||||||
|
|
||||||
|
|
||||||
|
class TestStoreTurn:
    """Tests for store_turn."""

    @pytest.mark.asyncio
    async def test_stores_raw_user_turn(self):
        """A user turn is upserted with a raw-typed, role-prefixed payload."""
        svc, stub = make_qdrant_service()
        svc._collection_ensured = True
        stub.upsert = AsyncMock()

        unit_vector = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=unit_vector)):
            new_id = await svc.store_turn(role="user", content="hello world")

        assert isinstance(new_id, str)
        stub.upsert.assert_called_once()
        stored = stub.upsert.call_args[1]["points"][0]
        assert stored.payload["type"] == "raw"
        assert stored.payload["role"] == "user"
        assert "User: hello world" in stored.payload["text"]

    @pytest.mark.asyncio
    async def test_stores_curated_turn(self):
        """A curated turn's text is stored verbatim, without a role prefix."""
        svc, stub = make_qdrant_service()
        svc._collection_ensured = True
        stub.upsert = AsyncMock()

        unit_vector = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=unit_vector)):
            await svc.store_turn(
                role="curated",
                content="User: q\nAssistant: a",
                entry_type="curated"
            )

        stored = stub.upsert.call_args[1]["points"][0]
        assert stored.payload["type"] == "curated"
        assert stored.payload["text"] == "User: q\nAssistant: a"

    @pytest.mark.asyncio
    async def test_stores_with_topic_and_metadata(self):
        """Optional topic and metadata entries land in the stored payload."""
        svc, stub = make_qdrant_service()
        svc._collection_ensured = True
        stub.upsert = AsyncMock()

        unit_vector = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=unit_vector)):
            await svc.store_turn(
                role="assistant",
                content="some response",
                topic="python",
                metadata={"source": "test"}
            )

        stored = stub.upsert.call_args[1]["points"][0]
        assert stored.payload["topic"] == "python"
        assert stored.payload["source"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
class TestStoreQaTurn:
    """Tests for store_qa_turn."""

    @pytest.mark.asyncio
    async def test_stores_qa_turn(self):
        """Question and answer are both embedded in one raw qa-role payload."""
        svc, stub = make_qdrant_service()
        svc._collection_ensured = True
        stub.upsert = AsyncMock()

        unit_vector = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=unit_vector)):
            new_id = await svc.store_qa_turn("What is Python?", "A programming language.")

        assert isinstance(new_id, str)
        stored = stub.upsert.call_args[1]["points"][0]
        assert stored.payload["type"] == "raw"
        assert stored.payload["role"] == "qa"
        assert "What is Python?" in stored.payload["text"]
        assert "A programming language." in stored.payload["text"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticSearch:
    """Tests for semantic_search."""

    @pytest.mark.asyncio
    async def test_returns_matching_results(self):
        """Query hits come back as dicts with id, score, and payload."""
        svc, stub = make_qdrant_service()
        svc._collection_ensured = True

        hit = MagicMock()
        hit.id = "result-1"
        hit.score = 0.85
        hit.payload = {"text": "User: hello\nAssistant: hi", "type": "curated"}

        query_result = MagicMock()
        query_result.points = [hit]
        stub.query_points = AsyncMock(return_value=query_result)

        unit_vector = [0.1] * 1024
        with patch.object(svc, "get_embedding", AsyncMock(return_value=unit_vector)):
            found = await svc.semantic_search("hello", limit=10, score_threshold=0.6)

        assert len(found) == 1
        assert found[0]["id"] == "result-1"
        assert found[0]["score"] == 0.85
        assert found[0]["payload"]["type"] == "curated"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetRecentTurns:
    """Tests for get_recent_turns."""

    @pytest.mark.asyncio
    async def test_returns_sorted_recent_turns(self):
        """Server-ordered scroll results are preserved newest-first."""
        svc, stub = make_qdrant_service()
        svc._collection_ensured = True

        older = MagicMock()
        older.id = "old"
        older.payload = {"timestamp": "2026-01-01T00:00:00Z", "text": "old turn"}

        newer = MagicMock()
        newer.id = "new"
        newer.payload = {"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"}

        # OrderBy makes the server return newest first; the mock mirrors that.
        stub.scroll = AsyncMock(return_value=([newer, older], None))

        turns = await svc.get_recent_turns(limit=2)

        assert len(turns) == 2
        assert turns[0]["id"] == "new"
        assert turns[1]["id"] == "old"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeletePoints:
    """Tests for delete_points."""

    @pytest.mark.asyncio
    async def test_deletes_by_ids(self):
        """A list of point IDs triggers exactly one client delete call."""
        svc, stub = make_qdrant_service()
        stub.delete = AsyncMock()

        await svc.delete_points(["id1", "id2"])

        stub.delete.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestClose:
    """Tests for close."""

    @pytest.mark.asyncio
    async def test_closes_client(self):
        """close() shuts down the underlying async Qdrant client once."""
        svc, stub = make_qdrant_service()
        stub.close = AsyncMock()

        await svc.close()

        stub.close.assert_called_once()
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Tests for utility functions."""
|
"""Tests for utility functions."""
|
||||||
import pytest
|
import pytest
|
||||||
from app.utils import count_tokens, truncate_by_tokens, parse_curated_turn
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from app.utils import count_tokens, truncate_by_tokens, parse_curated_turn, build_augmented_messages, count_messages_tokens
|
||||||
|
|
||||||
|
|
||||||
class TestCountTokens:
|
class TestCountTokens:
|
||||||
@@ -85,25 +86,95 @@ Assistant: Yes, very popular."""
|
|||||||
assert "Line 3" in result[0]["content"]
|
assert "Line 3" in result[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCountMessagesTokens:
    """Tests for count_messages_tokens function."""

    def test_empty_list(self):
        """An empty message list counts as 0 tokens."""
        assert count_messages_tokens([]) == 0

    def test_single_message(self):
        """A single non-empty message yields a positive count."""
        conversation = [{"role": "user", "content": "Hello world"}]
        assert count_messages_tokens(conversation) > 0

    def test_multiple_messages(self):
        """Counts accumulate across messages."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there, how can I help you today?"},
        ]
        assert count_messages_tokens(conversation) > count_messages_tokens([conversation[0]])

    def test_message_without_content(self):
        """A message lacking a content field contributes nothing."""
        assert count_messages_tokens([{"role": "system"}]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadSystemPrompt:
    """Tests for load_system_prompt function."""

    def test_loads_from_prompts_dir(self, tmp_path):
        """systemprompt.md is read from PROMPTS_DIR when present there."""
        import app.utils as utils_module

        prompt_home = tmp_path / "prompts"
        prompt_home.mkdir()
        (prompt_home / "systemprompt.md").write_text("You are Vera.")

        with patch.object(utils_module, "PROMPTS_DIR", prompt_home):
            assert utils_module.load_system_prompt() == "You are Vera."

    def test_falls_back_to_static_dir(self, tmp_path):
        """STATIC_DIR is consulted when PROMPTS_DIR has no file."""
        import app.utils as utils_module

        prompt_home = tmp_path / "no_prompts"  # does not exist
        fallback_home = tmp_path / "static"
        fallback_home.mkdir()
        (fallback_home / "systemprompt.md").write_text("Static Vera.")

        with patch.object(utils_module, "PROMPTS_DIR", prompt_home), \
             patch.object(utils_module, "STATIC_DIR", fallback_home):
            assert utils_module.load_system_prompt() == "Static Vera."

    def test_returns_empty_when_not_found(self, tmp_path):
        """An empty string is returned when no systemprompt.md exists anywhere."""
        import app.utils as utils_module

        with patch.object(utils_module, "PROMPTS_DIR", tmp_path / "nope"), \
             patch.object(utils_module, "STATIC_DIR", tmp_path / "also_nope"):
            assert utils_module.load_system_prompt() == ""
|
||||||
|
|
||||||
|
|
||||||
class TestFilterMemoriesByTime:
|
class TestFilterMemoriesByTime:
|
||||||
"""Tests for filter_memories_by_time function."""
|
"""Tests for filter_memories_by_time function."""
|
||||||
|
|
||||||
def test_includes_recent_memory(self):
|
def test_includes_recent_memory(self):
|
||||||
"""Memory with timestamp in the last 24h should be included."""
|
"""Memory with timestamp in the last 24h should be included."""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from app.utils import filter_memories_by_time
|
from app.utils import filter_memories_by_time
|
||||||
|
|
||||||
ts = (datetime.utcnow() - timedelta(hours=1)).isoformat()
|
ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)).isoformat()
|
||||||
memories = [{"timestamp": ts, "text": "recent"}]
|
memories = [{"timestamp": ts, "text": "recent"}]
|
||||||
result = filter_memories_by_time(memories, hours=24)
|
result = filter_memories_by_time(memories, hours=24)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
|
||||||
def test_excludes_old_memory(self):
|
def test_excludes_old_memory(self):
|
||||||
"""Memory older than cutoff should be excluded."""
|
"""Memory older than cutoff should be excluded."""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from app.utils import filter_memories_by_time
|
from app.utils import filter_memories_by_time
|
||||||
|
|
||||||
ts = (datetime.utcnow() - timedelta(hours=48)).isoformat()
|
ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=48)).isoformat()
|
||||||
memories = [{"timestamp": ts, "text": "old"}]
|
memories = [{"timestamp": ts, "text": "old"}]
|
||||||
result = filter_memories_by_time(memories, hours=24)
|
result = filter_memories_by_time(memories, hours=24)
|
||||||
assert len(result) == 0
|
assert len(result) == 0
|
||||||
@@ -124,6 +195,16 @@ class TestFilterMemoriesByTime:
|
|||||||
result = filter_memories_by_time(memories, hours=24)
|
result = filter_memories_by_time(memories, hours=24)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_z_suffix_old_timestamp_excluded(self):
|
||||||
|
"""Regression: chained .replace() was not properly handling Z suffix on old timestamps."""
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from app.utils import filter_memories_by_time
|
||||||
|
|
||||||
|
old_ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=48)).isoformat() + "Z"
|
||||||
|
memories = [{"timestamp": old_ts, "text": "old with Z"}]
|
||||||
|
result = filter_memories_by_time(memories, hours=24)
|
||||||
|
assert len(result) == 0, f"Old Z-suffixed timestamp should be excluded but wasn't: {old_ts}"
|
||||||
|
|
||||||
def test_empty_list(self):
|
def test_empty_list(self):
|
||||||
"""Empty input returns empty list."""
|
"""Empty input returns empty list."""
|
||||||
from app.utils import filter_memories_by_time
|
from app.utils import filter_memories_by_time
|
||||||
@@ -132,10 +213,10 @@ class TestFilterMemoriesByTime:
|
|||||||
|
|
||||||
def test_z_suffix_timestamp(self):
|
def test_z_suffix_timestamp(self):
|
||||||
"""ISO timestamp with Z suffix should be handled correctly."""
|
"""ISO timestamp with Z suffix should be handled correctly."""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from app.utils import filter_memories_by_time
|
from app.utils import filter_memories_by_time
|
||||||
|
|
||||||
ts = (datetime.utcnow() - timedelta(hours=1)).isoformat() + "Z"
|
ts = (datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)).isoformat() + "Z"
|
||||||
memories = [{"timestamp": ts, "text": "recent with Z"}]
|
memories = [{"timestamp": ts, "text": "recent with Z"}]
|
||||||
result = filter_memories_by_time(memories, hours=24)
|
result = filter_memories_by_time(memories, hours=24)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
@@ -190,37 +271,6 @@ class TestMergeMemories:
|
|||||||
assert len(result["ids"]) == 2
|
assert len(result["ids"]) == 2
|
||||||
|
|
||||||
|
|
||||||
class TestCalculateTokenBudget:
|
|
||||||
"""Tests for calculate_token_budget function."""
|
|
||||||
|
|
||||||
def test_default_ratios_sum(self):
|
|
||||||
"""Default ratios should sum to 1.0 (system+semantic+context)."""
|
|
||||||
from app.utils import calculate_token_budget
|
|
||||||
|
|
||||||
result = calculate_token_budget(1000)
|
|
||||||
assert result["system"] + result["semantic"] + result["context"] == 1000
|
|
||||||
|
|
||||||
def test_custom_ratios(self):
|
|
||||||
"""Custom ratios should produce correct proportional budgets."""
|
|
||||||
from app.utils import calculate_token_budget
|
|
||||||
|
|
||||||
result = calculate_token_budget(
|
|
||||||
100, system_ratio=0.1, semantic_ratio=0.6, context_ratio=0.3
|
|
||||||
)
|
|
||||||
assert result["system"] == 10
|
|
||||||
assert result["semantic"] == 60
|
|
||||||
assert result["context"] == 30
|
|
||||||
|
|
||||||
def test_zero_budget(self):
|
|
||||||
"""Zero total budget yields all zeros."""
|
|
||||||
from app.utils import calculate_token_budget
|
|
||||||
|
|
||||||
result = calculate_token_budget(0)
|
|
||||||
assert result["system"] == 0
|
|
||||||
assert result["semantic"] == 0
|
|
||||||
assert result["context"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
class TestBuildAugmentedMessages:
|
class TestBuildAugmentedMessages:
|
||||||
"""Tests for build_augmented_messages function (mocked I/O)."""
|
"""Tests for build_augmented_messages function (mocked I/O)."""
|
||||||
|
|
||||||
@@ -316,4 +366,72 @@ class TestBuildAugmentedMessages:
|
|||||||
)
|
)
|
||||||
|
|
||||||
contents = [m["content"] for m in result]
|
contents = [m["content"] for m in result]
|
||||||
assert any("Old question" in c or "Old answer" in c for c in contents)
|
assert any("Old question" in c or "Old answer" in c for c in contents)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_system_prompt_appends_to_caller_system(self):
|
||||||
|
"""systemprompt.md content appends to caller's system message."""
|
||||||
|
import app.utils as utils_module
|
||||||
|
|
||||||
|
mock_qdrant = self._make_qdrant_mock()
|
||||||
|
|
||||||
|
with patch.object(utils_module, "load_system_prompt", return_value="Vera memory context"), \
|
||||||
|
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
|
||||||
|
incoming = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
]
|
||||||
|
result = await build_augmented_messages(incoming)
|
||||||
|
|
||||||
|
system_msg = result[0]
|
||||||
|
assert system_msg["role"] == "system"
|
||||||
|
assert system_msg["content"] == "You are a helpful assistant.\n\nVera memory context"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_system_prompt_passthrough(self):
|
||||||
|
"""When systemprompt.md is empty, only caller's system message passes through."""
|
||||||
|
import app.utils as utils_module
|
||||||
|
|
||||||
|
mock_qdrant = self._make_qdrant_mock()
|
||||||
|
|
||||||
|
with patch.object(utils_module, "load_system_prompt", return_value=""), \
|
||||||
|
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
|
||||||
|
incoming = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
]
|
||||||
|
result = await build_augmented_messages(incoming)
|
||||||
|
|
||||||
|
system_msg = result[0]
|
||||||
|
assert system_msg["content"] == "You are a helpful assistant."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_caller_system_with_vera_prompt(self):
|
||||||
|
"""When caller sends no system message but systemprompt.md exists, use vera prompt."""
|
||||||
|
import app.utils as utils_module
|
||||||
|
|
||||||
|
mock_qdrant = self._make_qdrant_mock()
|
||||||
|
|
||||||
|
with patch.object(utils_module, "load_system_prompt", return_value="Vera memory context"), \
|
||||||
|
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
|
||||||
|
incoming = [{"role": "user", "content": "Hello"}]
|
||||||
|
result = await build_augmented_messages(incoming)
|
||||||
|
|
||||||
|
system_msg = result[0]
|
||||||
|
assert system_msg["role"] == "system"
|
||||||
|
assert system_msg["content"] == "Vera memory context"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_system_anywhere(self):
|
||||||
|
"""When neither caller nor systemprompt.md provides system content, no system message."""
|
||||||
|
import app.utils as utils_module
|
||||||
|
|
||||||
|
mock_qdrant = self._make_qdrant_mock()
|
||||||
|
|
||||||
|
with patch.object(utils_module, "load_system_prompt", return_value=""), \
|
||||||
|
patch.object(utils_module, "get_qdrant_service", return_value=mock_qdrant):
|
||||||
|
incoming = [{"role": "user", "content": "Hello"}]
|
||||||
|
result = await build_augmented_messages(incoming)
|
||||||
|
|
||||||
|
# First message should be user, not system
|
||||||
|
assert result[0]["role"] == "user"
|
||||||
Reference in New Issue
Block a user