feat: semantic search includes raw turns, deduplicate layers, fix recent turn ordering

- Layer 2 semantic search now queries both curated and raw types,
  closing the blind spot for turns past the 50-turn window pre-curation
- Layer 3 skips turns already returned by Layer 2 to avoid duplicate
  context and wasted token budget
- get_recent_turns uses Qdrant OrderBy for server-side timestamp sort
  with payload index; fallback to client-side sort if unavailable
- Bump version to 2.0.4
This commit is contained in:
Vera-AI
2026-04-01 17:43:47 -05:00
parent de7f3a78ab
commit 346f2c26fe
4 changed files with 61 additions and 31 deletions

View File

@@ -68,7 +68,7 @@ async def lifespan(app: FastAPI):
await qdrant_service.close() await qdrant_service.close()
app = FastAPI(title="Vera-AI", version="2.0.0", lifespan=lifespan) app = FastAPI(title="Vera-AI", version="2.0.4", lifespan=lifespan)
@app.get("/") @app.get("/")

View File

@@ -1,6 +1,6 @@
"""Qdrant service for memory storage - ASYNC VERSION.""" """Qdrant service for memory storage - ASYNC VERSION."""
from qdrant_client import AsyncQdrantClient from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue, PayloadSchemaType
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from datetime import datetime, timezone from datetime import datetime, timezone
import uuid import uuid
@@ -34,6 +34,15 @@ class QdrantService:
vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE) vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
) )
logger.info(f"Created collection {self.collection} with vector size {self.vector_size}") logger.info(f"Created collection {self.collection} with vector size {self.vector_size}")
# Ensure payload index on timestamp for ordered scroll
try:
await self.client.create_payload_index(
collection_name=self.collection,
field_name="timestamp",
field_schema=PayloadSchemaType.KEYWORD
)
except Exception:
pass # Index may already exist
self._collection_ensured = True self._collection_ensured = True
async def get_embedding(self, text: str) -> List[float]: async def get_embedding(self, text: str) -> List[float]:
@@ -105,20 +114,28 @@ class QdrantService:
) )
return point_id return point_id
async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated") -> List[Dict]: async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated", entry_types: Optional[List[str]] = None) -> List[Dict]:
"""Semantic search for relevant turns, filtered by type.""" """Semantic search for relevant turns, filtered by type(s)."""
await self._ensure_collection() await self._ensure_collection()
embedding = await self.get_embedding(query) embedding = await self.get_embedding(query)
if entry_types and len(entry_types) > 1:
type_filter = Filter(
should=[FieldCondition(key="type", match=MatchValue(value=t)) for t in entry_types]
)
else:
filter_type = entry_types[0] if entry_types else entry_type
type_filter = Filter(
must=[FieldCondition(key="type", match=MatchValue(value=filter_type))]
)
results = await self.client.query_points( results = await self.client.query_points(
collection_name=self.collection, collection_name=self.collection,
query=embedding, query=embedding,
limit=limit, limit=limit,
score_threshold=score_threshold, score_threshold=score_threshold,
query_filter=Filter( query_filter=type_filter
must=[FieldCondition(key="type", match=MatchValue(value=entry_type))]
)
) )
return [{"id": str(r.id), "score": r.score, "payload": r.payload} for r in results.points] return [{"id": str(r.id), "score": r.score, "payload": r.payload} for r in results.points]
@@ -126,21 +143,29 @@ class QdrantService:
async def get_recent_turns(self, limit: int = 20) -> List[Dict]: async def get_recent_turns(self, limit: int = 20) -> List[Dict]:
"""Get recent turns from Qdrant (both raw and curated).""" """Get recent turns from Qdrant (both raw and curated)."""
await self._ensure_collection() await self._ensure_collection()
points, _ = await self.client.scroll( try:
collection_name=self.collection, from qdrant_client.models import OrderBy
limit=limit * 2, points, _ = await self.client.scroll(
with_payload=True collection_name=self.collection,
) limit=limit,
with_payload=True,
# Sort by timestamp descending order_by=OrderBy(key="timestamp", direction="desc")
sorted_points = sorted( )
points, except Exception:
key=lambda p: p.payload.get("timestamp", ""), # Fallback: fetch extra points and sort client-side
reverse=True points, _ = await self.client.scroll(
) collection_name=self.collection,
limit=limit * 5,
return [{"id": str(p.id), "payload": p.payload} for p in sorted_points[:limit]] with_payload=True
)
points = sorted(
points,
key=lambda p: p.payload.get("timestamp", ""),
reverse=True
)[:limit]
return [{"id": str(p.id), "payload": p.payload} for p in points]
async def delete_points(self, point_ids: List[str]) -> None: async def delete_points(self, point_ids: List[str]) -> None:
"""Delete points by ID.""" """Delete points by ID."""

View File

@@ -213,23 +213,25 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
messages.append({"role": "system", "content": system_content}) messages.append({"role": "system", "content": system_content})
logger.info(f"Layer 1 (system): {count_tokens(system_content)} tokens") logger.info(f"Layer 1 (system): {count_tokens(system_content)} tokens")
# === LAYER 2: Semantic (curated memories) === # === LAYER 2: Semantic (curated + raw memories) ===
qdrant = get_qdrant_service() qdrant = get_qdrant_service()
semantic_results = await qdrant.semantic_search( semantic_results = await qdrant.semantic_search(
query=search_context if search_context else user_question, query=search_context if search_context else user_question,
limit=20, limit=20,
score_threshold=config.semantic_score_threshold, score_threshold=config.semantic_score_threshold,
entry_type="curated" entry_types=["curated", "raw"]
) )
semantic_messages = [] semantic_messages = []
semantic_tokens_used = 0 semantic_tokens_used = 0
semantic_ids = set()
for result in semantic_results: for result in semantic_results:
semantic_ids.add(result.get("id"))
payload = result.get("payload", {}) payload = result.get("payload", {})
text = payload.get("text", "") text = payload.get("text", "")
if text: if text:
# Parse curated turn into proper user/assistant messages # Parse curated/raw turn into proper user/assistant messages
parsed = parse_curated_turn(text) parsed = parse_curated_turn(text)
for msg in parsed: for msg in parsed:
msg_tokens = count_tokens(msg.get("content", "")) msg_tokens = count_tokens(msg.get("content", ""))
@@ -254,8 +256,10 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
context_messages = [] context_messages = []
context_tokens_used = 0 context_tokens_used = 0
# Process oldest first for chronological order # Process oldest first for chronological order, skip duplicates from Layer 2
for turn in reversed(recent_turns): for turn in reversed(recent_turns):
if turn.get("id") in semantic_ids:
continue
payload = turn.get("payload", {}) payload = turn.get("payload", {})
text = payload.get("text", "") text = payload.get("text", "")
entry_type = payload.get("type", "raw") entry_type = payload.get("type", "raw")

View File

@@ -217,12 +217,13 @@ class TestGetRecentTurns:
mock_point2.id = "new" mock_point2.id = "new"
mock_point2.payload = {"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"} mock_point2.payload = {"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"}
mock_client.scroll = AsyncMock(return_value=([mock_point1, mock_point2], None)) # OrderBy returns server-sorted results (newest first)
mock_client.scroll = AsyncMock(return_value=([mock_point2, mock_point1], None))
results = await svc.get_recent_turns(limit=2) results = await svc.get_recent_turns(limit=2)
assert len(results) == 2 assert len(results) == 2
# Newest first # Newest first (server-sorted via OrderBy)
assert results[0]["id"] == "new" assert results[0]["id"] == "new"
assert results[1]["id"] == "old" assert results[1]["id"] == "old"