Initial commit: Vera-AI v2 with async Qdrant, singleton pattern, monthly curation, and configurable UID/GID/TZ
Features: - AsyncQdrantClient for non-blocking Qdrant operations - Singleton pattern for QdrantService - Monthly full curation (day 1 at 03:00) - Configurable UID/GID for Docker - Timezone support via TZ env var - Configurable log directory (VERA_LOG_DIR) - Volume mounts for config/, prompts/, logs/ - Standard Docker format with .env file Fixes: - Removed unused system_token_budget - Added semantic_score_threshold config - Fixed streaming response handling - Python-based healthcheck (no curl dependency)
This commit is contained in:
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
121
app/config.py
Normal file
121
app/config.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# app/config.py
|
||||
import toml
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
# Embedding model dimensions
# Maps an embedding model name (substring-matched in Config.vector_size)
# to the vector dimensionality used when creating the Qdrant collection.
EMBEDDING_DIMS = {
    "nomic-embed-text": 768,
    "snowflake-arctic-embed2": 1024,
    "mxbai-embed-large": 1024,
}

# Configurable paths (can be overridden via environment)
# Defaults match the container layout; the env vars allow running outside Docker.
CONFIG_DIR = Path(os.environ.get("VERA_CONFIG_DIR", "/app/config"))
PROMPTS_DIR = Path(os.environ.get("VERA_PROMPTS_DIR", "/app/prompts"))
STATIC_DIR = Path(os.environ.get("VERA_STATIC_DIR", "/app/static"))
|
||||
|
||||
@dataclass
class CloudConfig:
    """Settings for routing selected local model names to a cloud API.

    Attributes:
        enabled: master switch for cloud routing.
        api_base: base URL of the cloud endpoint.
        api_key_env: name of the environment variable holding the API key.
        models: mapping of local model name -> cloud model ID.
    """

    enabled: bool = False
    api_base: str = ""
    api_key_env: str = "OPENROUTER_API_KEY"
    models: Dict[str, str] = field(default_factory=dict)

    @property
    def api_key(self) -> Optional[str]:
        """Resolve the API key from the configured environment variable."""
        return os.environ.get(self.api_key_env)

    def get_cloud_model(self, local_name: str) -> Optional[str]:
        """Return the cloud model ID mapped to *local_name*, or None."""
        try:
            return self.models[local_name]
        except KeyError:
            return None

    def is_cloud_model(self, local_name: str) -> bool:
        """Return True when *local_name* should be routed to the cloud."""
        return local_name in self.models
|
||||
|
||||
@dataclass
class Config:
    """Application configuration: defaults here, overrides from a TOML file.

    Sections of the TOML map onto groups of fields:
    [general] -> hosts/collection/model/debug, [layers] -> token budgets,
    [curator] -> schedule/model, [cloud] -> CloudConfig.
    """

    ollama_host: str = "http://10.0.0.10:11434"
    qdrant_host: str = "http://10.0.0.22:6333"
    qdrant_collection: str = "memories"
    embedding_model: str = "snowflake-arctic-embed2"
    # Removed system_token_budget - system prompt is never truncated
    semantic_token_budget: int = 25000
    context_token_budget: int = 22000
    semantic_search_turns: int = 2
    semantic_score_threshold: float = 0.6  # Score threshold for semantic search
    run_time: str = "02:00"  # Daily curator time
    full_run_time: str = "03:00"  # Monthly full curator time
    full_run_day: int = 1  # Day of month for full run (1st)
    curator_model: str = "gpt-oss:120b"
    debug: bool = False
    cloud: CloudConfig = field(default_factory=CloudConfig)

    @property
    def vector_size(self) -> int:
        """Get vector size based on embedding model.

        Substring-matches the configured model name against EMBEDDING_DIMS
        (so tagged names like "snowflake-arctic-embed2:latest" still match);
        falls back to 1024 for unknown models.
        """
        for model_name, dims in EMBEDDING_DIMS.items():
            if model_name in self.embedding_model:
                return dims
        return 1024

    @classmethod
    def load(cls, config_path: Optional[str] = None):
        """Load config from TOML file.

        Search order:
        1. Explicit config_path argument
        2. VERA_CONFIG_DIR/config.toml
        3. /app/config/config.toml
        4. config.toml in app root (backward compatibility)

        A missing file is not an error: dataclass defaults are used, and
        each TOML key individually falls back to the current default.
        """
        if config_path is None:
            # Try config directory first
            config_path = CONFIG_DIR / "config.toml"
            if not config_path.exists():
                # Fall back to app root (backward compatibility)
                config_path = Path(__file__).parent.parent / "config.toml"
        else:
            config_path = Path(config_path)

        # Start from pure defaults; overrides are applied per-key below.
        config = cls()

        if config_path.exists():
            with open(config_path, "r") as f:
                data = toml.load(f)

            if "general" in data:
                config.ollama_host = data["general"].get("ollama_host", config.ollama_host)
                config.qdrant_host = data["general"].get("qdrant_host", config.qdrant_host)
                config.qdrant_collection = data["general"].get("qdrant_collection", config.qdrant_collection)
                config.embedding_model = data["general"].get("embedding_model", config.embedding_model)
                config.debug = data["general"].get("debug", config.debug)

            if "layers" in data:
                # Note: system_token_budget is ignored (system prompt never truncated)
                config.semantic_token_budget = data["layers"].get("semantic_token_budget", config.semantic_token_budget)
                config.context_token_budget = data["layers"].get("context_token_budget", config.context_token_budget)
                config.semantic_search_turns = data["layers"].get("semantic_search_turns", config.semantic_search_turns)
                config.semantic_score_threshold = data["layers"].get("semantic_score_threshold", config.semantic_score_threshold)

            if "curator" in data:
                config.run_time = data["curator"].get("run_time", config.run_time)
                config.full_run_time = data["curator"].get("full_run_time", config.full_run_time)
                config.full_run_day = data["curator"].get("full_run_day", config.full_run_day)
                config.curator_model = data["curator"].get("curator_model", config.curator_model)

            if "cloud" in data:
                cloud_data = data["cloud"]
                # The cloud section is rebuilt wholesale rather than merged per-key.
                config.cloud = CloudConfig(
                    enabled=cloud_data.get("enabled", False),
                    api_base=cloud_data.get("api_base", ""),
                    api_key_env=cloud_data.get("api_key_env", "OPENROUTER_API_KEY"),
                    models=cloud_data.get("models", {})
                )

        return config
|
||||
|
||||
# Module-level singleton Config instance, loaded once at import time.
config = Config.load()
|
||||
208
app/context_handler.py
Normal file
208
app/context_handler.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Context handler - builds 4-layer context for every request."""
|
||||
import httpx
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from .config import Config
|
||||
from .qdrant_service import QdrantService
|
||||
from .utils import count_tokens, truncate_by_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextHandler:
    """Builds the 4-layer context (system / semantic memories / recent turns /
    current question) for a chat request and forwards it to Ollama."""

    def __init__(self, config: Config):
        # NOTE(review): this creates its own QdrantService while main.py uses
        # a singleton via get_qdrant_service() — confirm the duplicate client
        # is intentional.
        self.config = config
        self.qdrant = QdrantService(
            host=config.qdrant_host,
            collection=config.qdrant_collection,
            embedding_model=config.embedding_model,
            ollama_host=config.ollama_host
        )
        # Vera's own prompt, appended after any incoming system prompt.
        self.system_prompt = self._load_system_prompt()

    def _load_system_prompt(self) -> str:
        """Load system prompt from static/systemprompt.md.

        Raises:
            FileNotFoundError: the file is mandatory.

        NOTE(review): resolves the path relative to this file and does not
        honor the VERA_STATIC_DIR override defined in config.py — confirm
        whether it should use STATIC_DIR instead.
        """
        try:
            path = Path(__file__).parent.parent / "static" / "systemprompt.md"
            return path.read_text().strip()
        except FileNotFoundError:
            logger.error("systemprompt.md not found - required file")
            raise

    async def process(self, messages: List[Dict], model: str, stream: bool = False) -> Dict:
        """Process chat request through 4-layer context.

        NOTE(review): response.json() assumes a single JSON body; with
        stream=True Ollama emits NDJSON chunks, so this method appears to
        support stream=False only — confirm callers never pass True.
        """
        # Get user question (last user message)
        user_question = ""
        for msg in reversed(messages):
            if msg.get("role") == "user":
                user_question = msg.get("content", "")
                break

        # Get messages for semantic search (last N turns)
        search_messages = []
        for msg in messages[-self.config.semantic_search_turns:]:
            if msg.get("role") in ("user", "assistant"):
                search_messages.append(msg.get("content", ""))

        # Build the 4-layer context messages
        context_messages = await self.build_context_messages(
            incoming_system=next((m for m in messages if m.get("role") == "system"), None),
            user_question=user_question,
            search_context=" ".join(search_messages)
        )

        # Forward to Ollama
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                f"{self.config.ollama_host}/api/chat",
                json={"model": model, "messages": context_messages, "stream": stream}
            )
            result = response.json()

        # Store the Q&A turn in Qdrant
        assistant_msg = result.get("message", {}).get("content", "")
        await self.qdrant.store_qa_turn(user_question, assistant_msg)

        return result

    def _parse_curated_turn(self, text: str) -> List[Dict]:
        """Parse a curated turn into alternating user/assistant messages.

        Input format:
        User: [question]
        Assistant: [answer]
        Timestamp: ISO datetime

        Returns list of message dicts with role and content.  Lines before
        the first "User:"/"Assistant:" marker are dropped (current_role is
        still None); "Timestamp:" lines are discarded.
        """
        messages = []
        lines = text.strip().split("\n")

        current_role = None
        current_content = []

        for line in lines:
            line = line.strip()
            if line.startswith("User:"):
                # Save previous content if exists
                if current_role and current_content:
                    messages.append({
                        "role": current_role,
                        "content": "\n".join(current_content).strip()
                    })
                current_role = "user"
                current_content = [line[5:].strip()]  # Remove "User:" prefix
            elif line.startswith("Assistant:"):
                # Save previous content if exists
                if current_role and current_content:
                    messages.append({
                        "role": current_role,
                        "content": "\n".join(current_content).strip()
                    })
                current_role = "assistant"
                current_content = [line[10:].strip()]  # Remove "Assistant:" prefix
            elif line.startswith("Timestamp:"):
                # Ignore timestamp line
                continue
            elif current_role:
                # Continuation of current message
                current_content.append(line)

        # Save last message
        if current_role and current_content:
            messages.append({
                "role": current_role,
                "content": "\n".join(current_content).strip()
            })

        return messages

    async def build_context_messages(self, incoming_system: Optional[Dict], user_question: str, search_context: str) -> List[Dict]:
        """Build 4-layer context messages array.

        Layers, in order: (1) system prompt (incoming + Vera's), never
        truncated; (2) curated semantic memories within semantic_token_budget;
        (3) recent raw turns within context_token_budget, oldest first;
        (4) the current user question.
        """
        messages = []
        token_budget = {
            "semantic": self.config.semantic_token_budget,
            "context": self.config.context_token_budget
        }

        # === LAYER 1: System Prompt (pass through unchanged) ===
        # DO NOT truncate - preserve OpenClaw's system prompt entirely
        system_content = ""
        if incoming_system:
            system_content = incoming_system.get("content", "")
            logger.info(f"System layer: preserved incoming system {len(system_content)} chars, {count_tokens(system_content)} tokens")

        # Add Vera context info if present (small, just metadata)
        if self.system_prompt.strip():
            system_content += "\n\n" + self.system_prompt
            logger.info(f"System layer: added vera context {len(self.system_prompt)} chars")

        messages.append({"role": "system", "content": system_content})

        # === LAYER 2: Semantic Layer (curated memories) ===
        # Search for curated blocks only
        semantic_results = await self.qdrant.semantic_search(
            query=search_context if search_context else user_question,
            limit=20,
            score_threshold=self.config.semantic_score_threshold,
            entry_type="curated"
        )

        # Parse curated turns into alternating user/assistant messages
        semantic_messages = []
        semantic_tokens_used = 0

        for result in semantic_results:
            payload = result.get("payload", {})
            text = payload.get("text", "")
            if text:
                parsed = self._parse_curated_turn(text)
                for msg in parsed:
                    msg_tokens = count_tokens(msg.get("content", ""))
                    if semantic_tokens_used + msg_tokens <= token_budget["semantic"]:
                        semantic_messages.append(msg)
                        semantic_tokens_used += msg_tokens
                    else:
                        # NOTE(review): breaks only the inner loop, so later
                        # (smaller) results may still be added after the
                        # budget is first hit — confirm this is intended.
                        break

        # Add parsed messages to context
        for msg in semantic_messages:
            messages.append(msg)

        if semantic_messages:
            logger.info(f"Semantic layer: {len(semantic_messages)} messages, ~{semantic_tokens_used} tokens")

        # === LAYER 3: Context Layer (recent turns) ===
        recent_turns = await self.qdrant.get_recent_turns(limit=50)

        context_messages_parsed = []
        context_tokens_used = 0

        for turn in reversed(recent_turns):  # Oldest first
            payload = turn.get("payload", {})
            text = payload.get("text", "")
            # entry_type is read but not currently used to filter turns.
            entry_type = payload.get("type", "raw")

            if text:
                # Parse turn into messages
                parsed = self._parse_curated_turn(text)

                for msg in parsed:
                    msg_tokens = count_tokens(msg.get("content", ""))
                    if context_tokens_used + msg_tokens <= token_budget["context"]:
                        context_messages_parsed.append(msg)
                        context_tokens_used += msg_tokens
                    else:
                        break

        for msg in context_messages_parsed:
            messages.append(msg)

        if context_messages_parsed:
            logger.info(f"Context layer: {len(context_messages_parsed)} messages, ~{context_tokens_used} tokens")

        # === LAYER 4: Current Question ===
        messages.append({"role": "user", "content": user_question})

        return messages
|
||||
266
app/curator.py
Normal file
266
app/curator.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Memory curator - runs daily (recent 24h) and monthly (full DB) to clean and maintain memory database.
|
||||
|
||||
Creates INDIVIDUAL cleaned turns (one per raw turn), not merged summaries.
|
||||
Parses JSON response from curator_prompt.md format.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
import json
|
||||
import re
|
||||
|
||||
from .qdrant_service import QdrantService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configurable prompts directory (can be overridden via environment)
# Mirrors the same env-var defaults declared in app/config.py.
PROMPTS_DIR = Path(os.environ.get("VERA_PROMPTS_DIR", "/app/prompts"))
STATIC_DIR = Path(os.environ.get("VERA_STATIC_DIR", "/app/static"))
|
||||
|
||||
|
||||
def load_curator_prompt() -> str:
    """Load the curator prompt, preferring PROMPTS_DIR over STATIC_DIR.

    Raises:
        FileNotFoundError: when curator_prompt.md exists in neither location.
    """
    candidates = (
        PROMPTS_DIR / "curator_prompt.md",   # preferred location
        STATIC_DIR / "curator_prompt.md",    # backward compatibility
    )
    for candidate in candidates:
        if candidate.exists():
            return candidate.read_text().strip()
    raise FileNotFoundError(f"curator_prompt.md not found in {PROMPTS_DIR} or {STATIC_DIR}")
|
||||
|
||||
|
||||
class Curator:
    """LLM-driven memory curator.

    Reads raw Q&A turns from Qdrant, asks an Ollama model to produce cleaned
    curated turns / permanent rules / deletions (as JSON), applies them, and
    deletes the processed raw turns.  Daily runs handle the last 24h; monthly
    full runs handle the entire raw set.
    """

    def __init__(self, qdrant_service: QdrantService, model: str = "gpt-oss:120b", ollama_host: str = "http://10.0.0.10:11434"):
        self.qdrant = qdrant_service
        self.model = model
        self.ollama_host = ollama_host
        # Loaded once; raises FileNotFoundError if the prompt file is missing.
        self.curator_prompt = load_curator_prompt()

    async def run(self, full: bool = False):
        """Run the curation process.

        Args:
            full: If True, process ALL raw memories (monthly full run).
                  If False, process only recent 24h (daily run).

        Raises:
            Exception: re-raised after logging so schedulers can observe failure.
        """
        logger.info(f"Starting memory curation (full={full})...")
        try:
            # datetime.utcnow() is deprecated in 3.12+; kept for behavioral parity.
            current_date = datetime.utcnow().strftime("%Y-%m-%d")

            # Get all memories (async); payloads only, vectors are not needed.
            points, _ = await self.qdrant.client.scroll(
                collection_name=self.qdrant.collection,
                limit=10000,
                with_payload=True,
                with_vectors=False
            )

            memories = []
            for point in points:
                payload = point.payload or {}
                memories.append({
                    "id": str(point.id),
                    "text": payload.get("text", ""),
                    "type": payload.get("type", "raw"),
                    "timestamp": payload.get("timestamp", ""),
                    "payload": payload
                })

            raw_memories = [m for m in memories if m["type"] == "raw"]
            curated_memories = [m for m in memories if m["type"] == "curated"]

            logger.info(f"Found {len(raw_memories)} raw, {len(curated_memories)} curated")

            # Filter by time for daily runs, process all for full runs
            if full:
                # Monthly full run: process ALL raw memories
                recent_raw = raw_memories
                logger.info(f"FULL RUN: Processing all {len(recent_raw)} raw memories")
            else:
                # Daily run: process only recent 24h
                recent_raw = [m for m in raw_memories if self._is_recent(m, hours=24)]
                logger.info(f"DAILY RUN: Processing {len(recent_raw)} recent raw memories")

            # Give the LLM a sample of existing curated memories for context.
            existing_sample = curated_memories[-50:] if len(curated_memories) > 50 else curated_memories

            if not recent_raw:
                logger.info("No raw memories to process")
                return

            raw_turns_text = self._format_raw_turns(recent_raw)
            existing_text = self._format_existing_memories(existing_sample)

            prompt = self.curator_prompt.replace("{CURRENT_DATE}", current_date)
            full_prompt = f"""{prompt}

## {'All' if full else 'Recent'} Raw Turns ({'full database' if full else 'last 24 hours'}):
{raw_turns_text}

## Existing Memories (sample):
{existing_text}

Remember: Respond with ONLY valid JSON. No markdown, no explanations, just the JSON object."""

            logger.info(f"Sending {len(recent_raw)} raw turns to LLM...")
            response_text = await self._call_llm(full_prompt)

            result = self._parse_json_response(response_text)

            if not result:
                logger.error("Failed to parse JSON response from LLM")
                return

            new_turns = result.get("new_curated_turns", [])
            permanent_rules = result.get("permanent_rules", [])
            deletions = result.get("deletions", [])
            summary = result.get("summary", "")

            logger.info(f"Parsed: {len(new_turns)} turns, {len(permanent_rules)} rules, {len(deletions)} deletions")
            logger.info(f"Summary: {summary}")

            # Store each cleaned turn as a curated memory.
            for turn in new_turns:
                content = turn.get("content", "")
                if content:
                    await self.qdrant.store_turn(
                        role="curated",
                        content=content,
                        entry_type="curated"
                    )
                    logger.info(f"Stored curated turn: {content[:100]}...")

            # Append any permanent rules to their target prompt file.
            for rule in permanent_rules:
                rule_text = rule.get("rule", "")
                target_file = rule.get("target_file", "systemprompt.md")
                if rule_text:
                    await self._append_rule_to_file(target_file, rule_text)
                    logger.info(f"Appended rule to {target_file}: {rule_text[:80]}...")

            if deletions:
                # Build the known-ID set once: O(1) membership per deletion
                # instead of rebuilding a list per item (was O(n*m)).
                known_ids = {m["id"] for m in memories}
                valid_deletions = [d for d in deletions if d in known_ids]
                if valid_deletions:
                    await self.qdrant.delete_points(valid_deletions)
                    logger.info(f"Deleted {len(valid_deletions)} points")

            # Processed raw turns are always removed, replaced by curated ones.
            raw_ids_to_delete = [m["id"] for m in recent_raw]
            if raw_ids_to_delete:
                await self.qdrant.delete_points(raw_ids_to_delete)
                logger.info(f"Deleted {len(raw_ids_to_delete)} processed raw memories")

            logger.info(f"Memory curation completed successfully (full={full})")

        except Exception as e:
            logger.error(f"Error during curation: {e}")
            raise

    async def run_full(self):
        """Run full curation (all raw memories). Convenience method."""
        await self.run(full=True)

    async def run_daily(self):
        """Run daily curation (recent 24h only). Convenience method."""
        await self.run(full=False)

    def _is_recent(self, memory: Dict, hours: int = 24) -> bool:
        """Check if memory is within the specified hours.

        Memories with a missing or unparseable timestamp are treated as
        recent so they are never silently dropped from a daily run.
        """
        timestamp = memory.get("timestamp", "")
        if not timestamp:
            return True
        try:
            mem_time = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
            cutoff = datetime.utcnow() - timedelta(hours=hours)
            return mem_time.replace(tzinfo=None) > cutoff
        except (ValueError, TypeError, AttributeError):
            # Narrowed from a bare except: only timestamp-parse failures
            # should fall through to "treat as recent".
            return True

    def _format_raw_turns(self, turns: List[Dict]) -> str:
        """Format raw turns for the LLM prompt, tagging each with its point ID."""
        formatted = []
        for i, turn in enumerate(turns, 1):
            text = turn.get("text", "")
            formatted.append(f"--- RAW TURN {i} (ID: {turn.get('id', 'unknown')}) ---\n{text}\n")
        return "\n".join(formatted)

    def _format_existing_memories(self, memories: List[Dict]) -> str:
        """Format existing curated memories for the LLM prompt (last 20 only)."""
        if not memories:
            return "No existing curated memories."
        # Only the most recent 20 are included to bound prompt size.
        # (Dropped an unused enumerate index from the original loop.)
        formatted = []
        for mem in memories[-20:]:
            text = mem.get("text", "")
            formatted.append(f"{text}\n")
        return "\n".join(formatted)

    async def _call_llm(self, prompt: str) -> str:
        """Call Ollama LLM with the prompt.

        Returns the model's text response, or "" on any failure (the caller
        treats an empty response as a parse failure).
        """
        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                response = await client.post(
                    f"{self.ollama_host}/api/generate",
                    json={
                        "model": self.model,
                        "prompt": prompt,
                        "stream": False,
                        "options": {
                            "temperature": 0.3,
                            "num_predict": 8192
                        }
                    }
                )
                result = response.json()
                return result.get("response", "")
        except Exception as e:
            logger.error(f"Failed to call LLM: {e}")
            return ""

    def _parse_json_response(self, response: str) -> Optional[Dict]:
        """Parse JSON from LLM response.

        Tries, in order: the whole response, the outermost {...} span, and a
        fenced ```json code block. Returns None when nothing parses.
        """
        if not response:
            return None

        try:
            return json.loads(response)
        except json.JSONDecodeError:
            pass

        try:
            start = response.find("{")
            end = response.rfind("}") + 1
            if start >= 0 and end > start:
                return json.loads(response[start:end])
        except json.JSONDecodeError:
            pass

        json_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', response)
        if json_match:
            try:
                return json.loads(json_match.group(1).strip())
            except json.JSONDecodeError:
                pass

        logger.error(f"Could not parse JSON: {response[:500]}...")
        return None

    async def _append_rule_to_file(self, filename: str, rule: str):
        """Append a permanent rule to a prompts file."""
        # Try prompts directory first, then static for backward compatibility
        prompts_path = PROMPTS_DIR / filename
        static_path = STATIC_DIR / filename

        # Use whichever directory is writable
        target_path = prompts_path if prompts_path.parent.exists() else static_path

        try:
            # Ensure parent directory exists
            target_path.parent.mkdir(parents=True, exist_ok=True)

            with open(target_path, "a") as f:
                f.write(f"\n{rule}\n")
            logger.info(f"Appended rule to {target_path}")
        except Exception as e:
            # Fixed: the original message logged the literal "(unknown)"
            # instead of the actual target path.
            logger.error(f"Failed to append rule to {target_path}: {e}")
|
||||
156
app/main.py
Normal file
156
app/main.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# app/main.py
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import httpx
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from .config import config
|
||||
from .singleton import get_qdrant_service
|
||||
from .proxy_handler import handle_chat, forward_to_ollama, handle_chat_non_streaming
|
||||
from .curator import Curator
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level scheduler driving the daily and monthly curator cron jobs.
scheduler = AsyncIOScheduler()
# Module-level Curator instance; created during app startup in lifespan().
curator = None
|
||||
|
||||
|
||||
async def run_curator():
    """Scheduled daily curator job (recent 24h).

    Swallows and logs any exception so a failed run never kills the
    scheduler. `curator` is the module-level instance set up in lifespan().
    """
    logger.info("Starting daily memory curation...")
    try:
        await curator.run_daily()
    except Exception as e:
        logger.error(f"Daily memory curation failed: {e}")
    else:
        logger.info("Daily memory curation completed successfully")
|
||||
|
||||
|
||||
async def run_curator_full():
    """Scheduled monthly curator job (full database).

    Swallows and logs any exception so a failed run never kills the
    scheduler. `curator` is the module-level instance set up in lifespan().
    """
    logger.info("Starting monthly full memory curation...")
    try:
        await curator.run_full()
    except Exception as e:
        logger.error(f"Monthly full memory curation failed: {e}")
    else:
        logger.info("Monthly full memory curation completed successfully")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan - startup and shutdown.

    Startup: initializes the singleton QdrantService, creates the Curator,
    registers the daily and monthly cron jobs, and starts the scheduler.
    Shutdown: stops the scheduler and closes the Qdrant client.
    """
    global curator

    logger.info("Starting Vera-AI...")

    # Initialize singleton QdrantService
    qdrant_service = get_qdrant_service()
    # NOTE(review): calls a private method of QdrantService — consider
    # exposing a public ensure_collection() instead.
    await qdrant_service._ensure_collection()

    # Initialize curator with singleton
    curator = Curator(
        qdrant_service=qdrant_service,
        model=config.curator_model,
        ollama_host=config.ollama_host
    )

    # Schedule daily curator (recent 24h); run_time is "HH:MM".
    hour, minute = map(int, config.run_time.split(":"))
    scheduler.add_job(run_curator, "cron", hour=hour, minute=minute, id="daily_curator")
    logger.info(f"Daily curator scheduled at {config.run_time}")

    # Schedule monthly full curator (all raw memories)
    full_hour, full_minute = map(int, config.full_run_time.split(":"))
    scheduler.add_job(
        run_curator_full,
        "cron",
        day=config.full_run_day,
        hour=full_hour,
        minute=full_minute,
        id="monthly_curator"
    )
    logger.info(f"Monthly full curator scheduled on day {config.full_run_day} at {config.full_run_time}")

    scheduler.start()

    # Application serves requests while suspended here.
    yield

    logger.info("Shutting down Vera-AI...")
    scheduler.shutdown()
    await qdrant_service.close()
|
||||
|
||||
|
||||
# FastAPI application; lifespan wires up Qdrant, the curator, and the scheduler.
app = FastAPI(title="Vera-AI", version="2.0.0", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
ollama_status = "unreachable"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
resp = await client.get(f"{config.ollama_host}/api/tags")
|
||||
if resp.status_code == 200:
|
||||
ollama_status = "reachable"
|
||||
except: pass
|
||||
return {"status": "ok", "ollama": ollama_status}
|
||||
|
||||
|
||||
@app.get("/api/tags")
|
||||
async def api_tags():
|
||||
"""Proxy to Ollama /api/tags with cloud model injection."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(f"{config.ollama_host}/api/tags")
|
||||
data = resp.json()
|
||||
|
||||
if config.cloud.enabled and config.cloud.models:
|
||||
for name in config.cloud.models.keys():
|
||||
data["models"].append({
|
||||
"name": name,
|
||||
"modified_at": "2026-03-25T00:00:00Z",
|
||||
"size": 0,
|
||||
"digest": "cloud",
|
||||
"details": {"family": "cloud"}
|
||||
})
|
||||
return JSONResponse(content=data)
|
||||
|
||||
|
||||
@app.api_route("/api/{path:path}", methods=["GET", "POST", "DELETE"])
async def proxy_all(request: Request, path: str):
    """Catch-all proxy for the Ollama API.

    /api/chat is intercepted and routed through the memory-augmented
    handlers (streaming vs non-streaming); every other path is streamed
    through to Ollama verbatim.
    """
    if path == "chat":
        body = await request.json()
        # Ollama clients default to streaming when "stream" is omitted.
        is_stream = body.get("stream", True)

        if is_stream:
            return await handle_chat(request)
        else:
            return await handle_chat_non_streaming(body)
    else:
        resp = await forward_to_ollama(request, f"/api/{path}")
        return StreamingResponse(
            resp.aiter_bytes(),
            status_code=resp.status_code,
            headers=dict(resp.headers),
            media_type=resp.headers.get("content-type")
        )
|
||||
|
||||
|
||||
@app.post("/curator/run")
|
||||
async def trigger_curator(full: bool = False):
|
||||
"""Manually trigger curator.
|
||||
|
||||
Args:
|
||||
full: If True, run full curation (all raw memories).
|
||||
If False (default), run daily curation (recent 24h).
|
||||
"""
|
||||
if full:
|
||||
await run_curator_full()
|
||||
return {"status": "full curation completed"}
|
||||
else:
|
||||
await run_curator()
|
||||
return {"status": "daily curation completed"}
|
||||
215
app/proxy_handler.py
Normal file
215
app/proxy_handler.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# app/proxy_handler.py
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
import httpx
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from .config import config
|
||||
from .singleton import get_qdrant_service
|
||||
from .utils import count_tokens, build_augmented_messages
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Debug log directory (configurable via environment)
# Logs are written to VERA_LOG_DIR or /app/logs by default
# The directory is created lazily by debug_log() on first write.
DEBUG_LOG_DIR = Path(os.environ.get("VERA_LOG_DIR", "/app/logs"))
|
||||
|
||||
|
||||
def clean_message_content(content: str) -> str:
    """Strip [Memory context] wrapper and extract actual user message.

    Handles two input shapes, in priority order:
    1. an OpenJarvis/OpenClaw "[Memory context] ... - user_msg: ..." wrapper,
       from which only the user_msg payload is returned;
    2. a leading "[YYYY-MM-DD ...]" timestamp prefix, which is dropped.
    Anything else (including empty input) is returned unchanged.
    """
    if not content:
        return content

    # Case 1: OpenJarvis/OpenClaw wrapper — pull out the user_msg payload.
    wrapped = re.search(
        r'\[Memory context\].*?- user_msg:\s*(.+?)(?:\n\n|\Z)',
        content, re.DOTALL
    )
    if wrapped is not None:
        return wrapped.group(1).strip()

    # Case 2: leading timestamp prefix — return the remainder.
    stamped = re.match(r'\[\d{4}-\d{2}-\d{2}[^\]]*\]\s*', content)
    if stamped is not None:
        return content[stamped.end():].strip()

    return content
|
||||
|
||||
|
||||
def debug_log(category: str, message: str, data: dict = None):
    """Append a debug entry to the daily debug log if debug mode is enabled.

    Logs are written to VERA_LOG_DIR/debug_YYYY-MM-DD.log
    This ensures logs are persisted and easily accessible.
    Each entry is one JSON line: timestamp, category, message, optional data.
    """
    if not config.debug:
        return

    from datetime import datetime

    # One log file per UTC day under the configured log directory.
    DEBUG_LOG_DIR.mkdir(parents=True, exist_ok=True)
    log_path = DEBUG_LOG_DIR / f"debug_{datetime.utcnow().strftime('%Y-%m-%d')}.log"

    entry = {
        "timestamp": datetime.utcnow().isoformat() + "Z",
        "category": category,
        "message": message,
    }
    if data:
        entry["data"] = data

    with open(log_path, "a") as f:
        f.write(json.dumps(entry) + "\n")
|
||||
|
||||
|
||||
async def handle_chat_non_streaming(body: dict):
    """Handle non-streaming chat request.

    Cleans incoming user messages, augments them with memory context,
    forwards the request to Ollama, stores the resulting Q&A turn in
    Qdrant (best-effort), and returns the full JSON response.
    """
    incoming_messages = body.get("messages", [])
    model = body.get("model", "")

    debug_log("INPUT", "Non-streaming chat request", {"messages": incoming_messages})

    # Clean messages (only user messages carry the wrapper to strip).
    cleaned_messages = []
    for msg in incoming_messages:
        if msg.get("role") == "user":
            cleaned_content = clean_message_content(msg.get("content", ""))
            cleaned_messages.append({"role": "user", "content": cleaned_content})
        else:
            cleaned_messages.append(msg)

    # Build augmented messages
    augmented_messages = await build_augmented_messages(cleaned_messages)

    debug_log("THOUGHT", "Built augmented messages", {"augmented_count": len(augmented_messages)})

    # Forward to Ollama; shallow copy keeps the caller's other options intact.
    forwarded_body = body.copy()
    forwarded_body["messages"] = augmented_messages

    async with httpx.AsyncClient(timeout=300.0) as client:
        resp = await client.post(f"{config.ollama_host}/api/chat", json=forwarded_body)
        result = resp.json()

    debug_log("OUTPUT", "LLM non-streaming response", {"response": result})

    # Store the Q&A turn (last user message + assistant answer).
    user_question = ""
    for msg in reversed(incoming_messages):
        if msg.get("role") == "user":
            user_question = clean_message_content(msg.get("content", ""))
            break

    assistant_answer = result.get("message", {}).get("content", "")

    if user_question and assistant_answer:
        qdrant_service = get_qdrant_service()
        try:
            # Best-effort storage: failures are logged, never surfaced to the
            # client. result_id is assigned but not used further.
            result_id = await qdrant_service.store_qa_turn(user_question, assistant_answer)
            debug_log("STORAGE", "Non-streaming Q&A stored", {"question": user_question, "answer": assistant_answer})
        except Exception as e:
            logger.error(f"[STORE] FAILED: {e}")

    # Echo the requested model name back in the response.
    result["model"] = model
    return JSONResponse(content=result)
|
||||
|
||||
|
||||
async def handle_chat(request: Request):
    """Handle a streaming chat request.

    Cleans user messages, augments them with memory context, streams the
    Ollama response back to the client chunk-by-chunk, and stores the
    completed Q&A turn in Qdrant once the stream finishes.
    """
    body = await request.json()
    incoming_messages = body.get("messages", [])

    debug_log("INPUT", "Streaming chat request", {"messages": incoming_messages})

    # Strip injected timestamp/context markers from user messages only.
    cleaned_messages = []
    for msg in incoming_messages:
        if msg.get("role") == "user":
            cleaned_content = clean_message_content(msg.get("content", ""))
            cleaned_messages.append({"role": "user", "content": cleaned_content})
        else:
            cleaned_messages.append(msg)

    # Build the 4-layer augmented message list.
    augmented_messages = await build_augmented_messages(cleaned_messages)

    debug_log("THOUGHT", "Built augmented messages for streaming", {
        "original_count": len(incoming_messages),
        "augmented_count": len(augmented_messages)
    })

    forwarded_body = body.copy()
    forwarded_body["messages"] = augmented_messages

    # Drop content headers: the body we forward differs from the client's,
    # and httpx recomputes length/type for the JSON payload.
    headers = dict(request.headers)
    for name in ("content-length", "Content-Length", "content-type", "Content-Type"):
        headers.pop(name, None)

    def _extract_content(raw: bytes) -> str:
        """Pull message.content out of one NDJSON line; '' if absent/invalid."""
        if not raw.strip():
            return ""
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            return ""
        if "message" in data and "content" in data["message"]:
            return data["message"]["content"]
        return ""

    async def stream_response():
        full_content = ""
        pending = b""  # partial NDJSON line carried across chunk boundaries
        async with httpx.AsyncClient(timeout=300.0) as client:
            # client.stream() yields bytes as Ollama produces them; a plain
            # .post() would download the entire response before aiter_bytes()
            # runs, which defeats streaming to the client.
            async with client.stream(
                "POST",
                f"{config.ollama_host}/api/chat",
                json=forwarded_body,
                headers=headers
            ) as resp:
                async for chunk in resp.aiter_bytes():
                    yield chunk

                    # Accumulate the assistant text for storage. Buffering raw
                    # bytes means JSON lines (and multi-byte UTF-8 sequences)
                    # split across chunks are still parsed correctly.
                    pending += chunk
                    while b"\n" in pending:
                        line, pending = pending.split(b"\n", 1)
                        full_content += _extract_content(line)

        # Flush a trailing line that had no newline terminator.
        full_content += _extract_content(pending)

        debug_log("OUTPUT", "LLM streaming response complete", {
            "content_length": len(full_content)
        })

        # Persist the question/answer pair for future memory retrieval.
        user_question = ""
        for msg in reversed(incoming_messages):
            if msg.get("role") == "user":
                user_question = clean_message_content(msg.get("content", ""))
                break

        if user_question and full_content:
            qdrant_service = get_qdrant_service()
            try:
                result_id = await qdrant_service.store_qa_turn(user_question, full_content)
                logger.info(f"[STORE] Success! ID={result_id[:8]}, Q={len(user_question)} chars")
            except Exception as e:
                logger.error(f"[STORE] FAILED: {type(e).__name__}: {e}")

    return StreamingResponse(stream_response(), media_type="application/x-ndjson")
|
||||
|
||||
|
||||
async def forward_to_ollama(request: Request, path: str):
    """Forward request to Ollama transparently."""
    payload = await request.body()
    # Drop content-length (case-insensitively); httpx sets it for the
    # forwarded body itself.
    fwd_headers = {
        key: value
        for key, value in dict(request.headers).items()
        if key.lower() != "content-length"
    }

    async with httpx.AsyncClient(timeout=300.0) as client:
        response = await client.request(
            method=request.method,
            url=f"{config.ollama_host}{path}",
            content=payload,
            headers=fwd_headers,
        )
    return response
|
||||
156
app/qdrant_service.py
Normal file
156
app/qdrant_service.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Qdrant service for memory storage - ASYNC VERSION."""
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import logging
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QdrantService:
    """Async Qdrant-backed memory store.

    Text is embedded via an Ollama embedding endpoint and stored as points
    whose payload carries the raw text plus type/role/timestamp metadata.
    All Qdrant calls go through AsyncQdrantClient so they never block the
    event loop.
    """

    def __init__(self, host: str, collection: str, embedding_model: str, vector_size: int = 1024, ollama_host: str = "http://10.0.0.10:11434"):
        self.host = host
        self.collection = collection
        self.embedding_model = embedding_model
        self.vector_size = vector_size
        self.ollama_host = ollama_host
        # Use the ASYNC client so Qdrant calls are non-blocking.
        self.client = AsyncQdrantClient(url=host)
        # Collection creation is deferred to first use; see _ensure_collection.
        self._collection_ensured = False

    @staticmethod
    def _utc_timestamp() -> str:
        """Current UTC time as a naive ISO-8601 string with a trailing 'Z'.

        Replaces the deprecated datetime.utcnow() while preserving the exact
        timestamp format already present in stored payloads.
        """
        from datetime import timezone
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    async def _ensure_collection(self):
        """Ensure collection exists - lazy initialization (runs once)."""
        if self._collection_ensured:
            return
        try:
            await self.client.get_collection(self.collection)
            logger.info(f"Collection {self.collection} exists")
        except Exception:
            # NOTE(review): any failure here (including connectivity errors)
            # is treated as "collection missing"; if Qdrant is actually
            # unreachable, create_collection below surfaces the real error.
            await self.client.create_collection(
                collection_name=self.collection,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
            )
            logger.info(f"Created collection {self.collection} with vector size {self.vector_size}")
        self._collection_ensured = True

    async def get_embedding(self, text: str) -> List[float]:
        """Get an embedding vector for *text* from Ollama.

        Raises:
            httpx.HTTPStatusError: if the embedding endpoint returns an error
                status (instead of an opaque KeyError from the error body).
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.ollama_host}/api/embeddings",
                json={"model": self.embedding_model, "prompt": text},
                timeout=30.0
            )
            # Fail loudly on HTTP errors rather than raising
            # KeyError("embedding") on the error body below.
            response.raise_for_status()
            result = response.json()
            return result["embedding"]

    async def store_turn(self, role: str, content: str, entry_type: str = "raw", topic: Optional[str] = None, metadata: Optional[Dict] = None) -> str:
        """Store a turn in Qdrant with the standard payload format.

        Args:
            role: "user" / "assistant" turns get a speaker prefix in the
                stored text; any other role (e.g. "curated") is stored verbatim.
            content: raw turn text; this is what gets embedded.
            entry_type: payload "type" field ("raw" or "curated").
            topic: optional topic tag added to the payload.
            metadata: optional extra payload fields (merged last, so they can
                override the standard keys).

        Returns:
            The UUID of the stored point.
        """
        await self._ensure_collection()

        point_id = str(uuid.uuid4())
        embedding = await self.get_embedding(content)

        timestamp = self._utc_timestamp()
        # Prefix the stored text with the speaker for better retrieval
        # context. (The original also had an explicit "curated" branch that
        # set text = content — identical to the default, so it is folded in.)
        if role == "user":
            text = f"User: {content}"
        elif role == "assistant":
            text = f"Assistant: {content}"
        else:
            text = content

        payload = {
            "type": entry_type,
            "text": text,
            "timestamp": timestamp,
            "role": role,
            "content": content
        }
        if topic:
            payload["topic"] = topic
        if metadata:
            payload.update(metadata)

        await self.client.upsert(
            collection_name=self.collection,
            points=[PointStruct(id=point_id, vector=embedding, payload=payload)]
        )
        return point_id

    async def store_qa_turn(self, user_question: str, assistant_answer: str) -> str:
        """Store a complete Q&A turn as one document.

        Question and answer are embedded together so semantic search retrieves
        the whole exchange rather than isolated halves.

        Returns:
            The UUID of the stored point.
        """
        await self._ensure_collection()

        timestamp = self._utc_timestamp()
        text = f"User: {user_question}\nAssistant: {assistant_answer}\nTimestamp: {timestamp}"

        point_id = str(uuid.uuid4())
        embedding = await self.get_embedding(text)

        payload = {
            "type": "raw",
            "text": text,
            "timestamp": timestamp,
            "role": "qa",
            "content": text
        }

        await self.client.upsert(
            collection_name=self.collection,
            points=[PointStruct(id=point_id, vector=embedding, payload=payload)]
        )
        return point_id

    async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated") -> List[Dict]:
        """Semantic search for relevant turns, filtered by payload type.

        Returns:
            A list of {"id", "score", "payload"} dicts, best matches first;
            hits scoring below *score_threshold* are excluded by Qdrant.
        """
        await self._ensure_collection()

        embedding = await self.get_embedding(query)

        results = await self.client.query_points(
            collection_name=self.collection,
            query=embedding,
            limit=limit,
            score_threshold=score_threshold,
            query_filter=Filter(
                must=[FieldCondition(key="type", match=MatchValue(value=entry_type))]
            )
        )

        return [{"id": str(r.id), "score": r.score, "payload": r.payload} for r in results.points]

    async def get_recent_turns(self, limit: int = 20) -> List[Dict]:
        """Get recent turns from Qdrant (both raw and curated).

        NOTE(review): scroll() returns points in storage order, so this
        over-fetches (limit * 2) and sorts by payload timestamp; very recent
        turns outside the fetched page can still be missed on large
        collections.
        """
        await self._ensure_collection()

        points, _ = await self.client.scroll(
            collection_name=self.collection,
            limit=limit * 2,
            with_payload=True
        )

        # Sort by timestamp descending (newest first).
        sorted_points = sorted(
            points,
            key=lambda p: p.payload.get("timestamp", ""),
            reverse=True
        )

        return [{"id": str(p.id), "payload": p.payload} for p in sorted_points[:limit]]

    async def delete_points(self, point_ids: List[str]) -> None:
        """Delete points by ID."""
        from qdrant_client.models import PointIdsList
        await self.client.delete(
            collection_name=self.collection,
            points_selector=PointIdsList(points=point_ids)
        )
        logger.info(f"Deleted {len(point_ids)} points")

    async def close(self):
        """Close the underlying async client."""
        await self.client.close()
|
||||
19
app/singleton.py
Normal file
19
app/singleton.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Global singleton instances for Vera-AI."""
|
||||
from .qdrant_service import QdrantService
|
||||
from .config import config
|
||||
|
||||
# Module-level singleton. The annotation is a string so it needs no typing
# import and the union syntax is version-agnostic; the previous annotation
# (`QdrantService`) was type-incorrect for the None initializer.
_qdrant_service: "QdrantService | None" = None


def get_qdrant_service() -> QdrantService:
    """Get or create the global QdrantService singleton.

    Lazily constructs the service from config on first call; every later
    call returns the same instance.
    """
    global _qdrant_service
    if _qdrant_service is None:
        _qdrant_service = QdrantService(
            host=config.qdrant_host,
            collection=config.qdrant_collection,
            embedding_model=config.embedding_model,
            vector_size=config.vector_size,
            ollama_host=config.ollama_host,
        )
    return _qdrant_service
|
||||
203
app/utils.py
Normal file
203
app/utils.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Utility functions for vera-ai."""
|
||||
from .config import config
|
||||
import tiktoken
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
# Use cl100k_base encoding (GPT-4 compatible)
|
||||
ENCODING = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
# Configurable paths (can be overridden via environment)
|
||||
PROMPTS_DIR = Path(os.environ.get("VERA_PROMPTS_DIR", "/app/prompts"))
|
||||
STATIC_DIR = Path(os.environ.get("VERA_STATIC_DIR", "/app/static"))
|
||||
|
||||
# Global qdrant_service instance for utils
|
||||
_qdrant_service = None
|
||||
|
||||
def get_qdrant_service():
    """Get or create the QdrantService singleton used by this module."""
    global _qdrant_service
    # Fast path: singleton already built.
    if _qdrant_service is not None:
        return _qdrant_service

    # Import lazily to avoid import cycles at module load time.
    from .config import config
    from .qdrant_service import QdrantService

    _qdrant_service = QdrantService(
        host=config.qdrant_host,
        collection=config.qdrant_collection,
        embedding_model=config.embedding_model,
        vector_size=config.vector_size,
        ollama_host=config.ollama_host,
    )
    return _qdrant_service
|
||||
|
||||
def count_tokens(text: str) -> int:
    """Return the number of cl100k_base tokens in *text* (0 for empty)."""
    return len(ENCODING.encode(text)) if text else 0
|
||||
|
||||
def count_messages_tokens(messages: List[dict]) -> int:
    """Sum the token counts of the 'content' field of every message."""
    return sum(
        count_tokens(msg["content"])
        for msg in messages
        if "content" in msg
    )
|
||||
|
||||
def truncate_by_tokens(text: str, max_tokens: int) -> str:
    """Truncate *text* so its token count does not exceed *max_tokens*."""
    if not text:
        return text
    encoded = ENCODING.encode(text)
    # Short texts pass through untouched; longer ones are cut at the token
    # boundary and decoded back to a string.
    return text if len(encoded) <= max_tokens else ENCODING.decode(encoded[:max_tokens])
|
||||
|
||||
def filter_memories_by_time(memories: List[Dict], hours: int = 24) -> List[Dict]:
    """Filter memories to those from the last *hours* hours.

    Memories with a missing or unparseable timestamp are kept ("fail open"),
    since silently dropping them would lose data.

    The previous parsing chain (`replace("Z", "+00:00").replace("+00:00", "")`)
    only handled the "Z"/"+00:00" suffixes and threw on any other UTC offset,
    so such memories were always kept via the except fallback. This version
    parses any ISO-8601 offset and normalizes aware datetimes to naive UTC.
    """
    from datetime import timezone  # local: module header imports only datetime/timedelta

    # Naive-UTC cutoff (datetime.utcnow() is deprecated).
    cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=hours)
    filtered: List[Dict] = []
    for mem in memories:
        ts = mem.get("timestamp")
        if not ts:
            # No timestamp: keep the memory rather than silently dropping it.
            filtered.append(mem)
            continue
        try:
            if isinstance(ts, str):
                # Stored timestamps end in "Z"; fromisoformat (pre-3.11)
                # rejects "Z", so strip it. Other offsets parse as aware.
                mem_time = datetime.fromisoformat(ts.replace("Z", ""))
            else:
                mem_time = ts
            if mem_time.tzinfo is not None:
                # Normalize aware datetimes to naive UTC for comparison.
                mem_time = mem_time.astimezone(timezone.utc).replace(tzinfo=None)
            if mem_time > cutoff:
                filtered.append(mem)
        except Exception:
            # Unparseable timestamp: fail open and keep the memory.
            filtered.append(mem)
    return filtered
|
||||
|
||||
def merge_memories(memories: List[Dict]) -> Dict:
    """Merge memories into one combined text block plus the source ids.

    Entries with neither 'text' nor 'content' are skipped (and their ids
    omitted). Entries with a role are rendered as "[role]: text".
    """
    if not memories:
        return {"text": "", "ids": []}

    rendered = []
    source_ids = []
    for mem in memories:
        body = mem.get("text", "") or mem.get("content", "")
        if not body:
            continue
        speaker = mem.get("role", "")
        rendered.append(f"[{speaker}]: {body}" if speaker else body)
        source_ids.append(mem.get("id"))

    return {
        "text": "\n\n".join(rendered),
        "ids": source_ids,
    }
|
||||
|
||||
def calculate_token_budget(total_budget: int, system_ratio: float = 0.2,
                           semantic_ratio: float = 0.5, context_ratio: float = 0.3) -> Dict[str, int]:
    """Split *total_budget* tokens into per-layer budgets.

    The return annotation is corrected to Dict[str, int] — keys are the layer
    names, not ints. Each share is truncated to an int, so the parts may sum
    to slightly less than the total.
    """
    return {
        "system": int(total_budget * system_ratio),
        "semantic": int(total_budget * semantic_ratio),
        "context": int(total_budget * context_ratio),
    }
|
||||
|
||||
def load_system_prompt() -> str:
    """Load the system prompt, preferring prompts/ over the legacy static/ dir."""
    import logging
    logger = logging.getLogger(__name__)

    # prompts/ is the current location; static/ is kept for backward
    # compatibility. First existing file wins.
    for candidate in (PROMPTS_DIR / "systemprompt.md", STATIC_DIR / "systemprompt.md"):
        if candidate.exists():
            return candidate.read_text().strip()

    logger.warning("systemprompt.md not found")
    return ""
|
||||
|
||||
|
||||
async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
    """Build 4-layer augmented messages from incoming messages.

    This is a standalone version that can be used by proxy_handler.py.

    Layers, in output order:
      1. System: the client's system message plus the file-based prompt.
      2. Semantic: curated memories retrieved by similarity search.
      3. Context: most recent stored turns, oldest first.
      4. The incoming conversation itself (minus its system message).
    """
    import logging

    # NOTE(review): this logger is currently unused in the function body.
    logger = logging.getLogger(__name__)

    # Load system prompt from the prompts/ (or legacy static/) directory.
    system_prompt = load_system_prompt()

    # Get user question (last user message)
    user_question = ""
    for msg in reversed(incoming_messages):
        if msg.get("role") == "user":
            user_question = msg.get("content", "")
            break

    # Get search context (last few turns). Used as the semantic-search query
    # so retrieval reflects the recent conversation, not just the last line.
    search_context = ""
    for msg in incoming_messages[-6:]:
        if msg.get("role") in ("user", "assistant"):
            search_context += msg.get("content", "") + " "

    messages = []

    # === LAYER 1: System Prompt ===
    # Start from the client-supplied system message (first one wins)...
    system_content = ""
    for msg in incoming_messages:
        if msg.get("role") == "system":
            system_content = msg.get("content", "")
            break

    # ...and append the file-based prompt after it.
    if system_prompt:
        system_content += "\n\n" + system_prompt

    if system_content:
        messages.append({"role": "system", "content": system_content})

    # === LAYER 2: Semantic (curated memories) ===
    qdrant = get_qdrant_service()
    semantic_results = await qdrant.semantic_search(
        query=search_context if search_context else user_question,
        limit=20,
        score_threshold=config.semantic_score_threshold,
        entry_type="curated"
    )

    # NOTE(review): the budget is checked before appending, so the last
    # appended memory can push the running total slightly over the budget.
    semantic_tokens = 0
    for result in semantic_results:
        payload = result.get("payload", {})
        text = payload.get("text", "")
        if text and semantic_tokens < config.semantic_token_budget:
            messages.append({"role": "user", "content": text})  # Add as context
            semantic_tokens += count_tokens(text)

    # === LAYER 3: Context (recent turns) ===
    recent_turns = await qdrant.get_recent_turns(limit=20)

    # Reversed so the injected history reads oldest-to-newest; same soft
    # budget behavior as layer 2.
    context_tokens = 0
    for turn in reversed(recent_turns):
        payload = turn.get("payload", {})
        text = payload.get("text", "")
        if text and context_tokens < config.context_token_budget:
            messages.append({"role": "user", "content": text})  # Add as context
            context_tokens += count_tokens(text)

    # === LAYER 4: Current messages (passed through) ===
    for msg in incoming_messages:
        if msg.get("role") != "system":  # Do not duplicate system
            messages.append(msg)

    return messages
|
||||
Reference in New Issue
Block a user