feat: semantic search includes raw turns, deduplicate layers, fix recent turn ordering
- Layer 2 semantic search now queries both curated and raw types, closing the blind spot for turns past the 50-turn window pre-curation
- Layer 3 skips turns already returned by Layer 2 to avoid duplicate context and wasted token budget
- get_recent_turns uses Qdrant OrderBy for server-side timestamp sort with a payload index; falls back to client-side sort if unavailable
- Bump version to 2.0.4
This commit is contained in:
@@ -68,7 +68,7 @@ async def lifespan(app: FastAPI):
|
||||
await qdrant_service.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Vera-AI", version="2.0.0", lifespan=lifespan)
|
||||
app = FastAPI(title="Vera-AI", version="2.0.4", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Qdrant service for memory storage - ASYNC VERSION."""
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
|
||||
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue, PayloadSchemaType
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
@@ -34,6 +34,15 @@ class QdrantService:
|
||||
vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
|
||||
)
|
||||
logger.info(f"Created collection {self.collection} with vector size {self.vector_size}")
|
||||
# Ensure payload index on timestamp for ordered scroll
|
||||
try:
|
||||
await self.client.create_payload_index(
|
||||
collection_name=self.collection,
|
||||
field_name="timestamp",
|
||||
field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
except Exception:
|
||||
pass # Index may already exist
|
||||
self._collection_ensured = True
|
||||
|
||||
async def get_embedding(self, text: str) -> List[float]:
|
||||
@@ -105,20 +114,28 @@ class QdrantService:
|
||||
)
|
||||
return point_id
|
||||
|
||||
async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated") -> List[Dict]:
|
||||
"""Semantic search for relevant turns, filtered by type."""
|
||||
async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated", entry_types: Optional[List[str]] = None) -> List[Dict]:
|
||||
"""Semantic search for relevant turns, filtered by type(s)."""
|
||||
await self._ensure_collection()
|
||||
|
||||
embedding = await self.get_embedding(query)
|
||||
|
||||
if entry_types and len(entry_types) > 1:
|
||||
type_filter = Filter(
|
||||
should=[FieldCondition(key="type", match=MatchValue(value=t)) for t in entry_types]
|
||||
)
|
||||
else:
|
||||
filter_type = entry_types[0] if entry_types else entry_type
|
||||
type_filter = Filter(
|
||||
must=[FieldCondition(key="type", match=MatchValue(value=filter_type))]
|
||||
)
|
||||
|
||||
results = await self.client.query_points(
|
||||
collection_name=self.collection,
|
||||
query=embedding,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="type", match=MatchValue(value=entry_type))]
|
||||
)
|
||||
query_filter=type_filter
|
||||
)
|
||||
|
||||
return [{"id": str(r.id), "score": r.score, "payload": r.payload} for r in results.points]
|
||||
@@ -127,20 +144,28 @@ class QdrantService:
|
||||
"""Get recent turns from Qdrant (both raw and curated)."""
|
||||
await self._ensure_collection()
|
||||
|
||||
try:
|
||||
from qdrant_client.models import OrderBy
|
||||
points, _ = await self.client.scroll(
|
||||
collection_name=self.collection,
|
||||
limit=limit * 2,
|
||||
limit=limit,
|
||||
with_payload=True,
|
||||
order_by=OrderBy(key="timestamp", direction="desc")
|
||||
)
|
||||
except Exception:
|
||||
# Fallback: fetch extra points and sort client-side
|
||||
points, _ = await self.client.scroll(
|
||||
collection_name=self.collection,
|
||||
limit=limit * 5,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# Sort by timestamp descending
|
||||
sorted_points = sorted(
|
||||
points = sorted(
|
||||
points,
|
||||
key=lambda p: p.payload.get("timestamp", ""),
|
||||
reverse=True
|
||||
)
|
||||
)[:limit]
|
||||
|
||||
return [{"id": str(p.id), "payload": p.payload} for p in sorted_points[:limit]]
|
||||
return [{"id": str(p.id), "payload": p.payload} for p in points]
|
||||
|
||||
async def delete_points(self, point_ids: List[str]) -> None:
|
||||
"""Delete points by ID."""
|
||||
|
||||
12
app/utils.py
12
app/utils.py
@@ -213,23 +213,25 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
logger.info(f"Layer 1 (system): {count_tokens(system_content)} tokens")
|
||||
|
||||
# === LAYER 2: Semantic (curated memories) ===
|
||||
# === LAYER 2: Semantic (curated + raw memories) ===
|
||||
qdrant = get_qdrant_service()
|
||||
semantic_results = await qdrant.semantic_search(
|
||||
query=search_context if search_context else user_question,
|
||||
limit=20,
|
||||
score_threshold=config.semantic_score_threshold,
|
||||
entry_type="curated"
|
||||
entry_types=["curated", "raw"]
|
||||
)
|
||||
|
||||
semantic_messages = []
|
||||
semantic_tokens_used = 0
|
||||
semantic_ids = set()
|
||||
|
||||
for result in semantic_results:
|
||||
semantic_ids.add(result.get("id"))
|
||||
payload = result.get("payload", {})
|
||||
text = payload.get("text", "")
|
||||
if text:
|
||||
# Parse curated turn into proper user/assistant messages
|
||||
# Parse curated/raw turn into proper user/assistant messages
|
||||
parsed = parse_curated_turn(text)
|
||||
for msg in parsed:
|
||||
msg_tokens = count_tokens(msg.get("content", ""))
|
||||
@@ -254,8 +256,10 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]:
|
||||
context_messages = []
|
||||
context_tokens_used = 0
|
||||
|
||||
# Process oldest first for chronological order
|
||||
# Process oldest first for chronological order, skip duplicates from Layer 2
|
||||
for turn in reversed(recent_turns):
|
||||
if turn.get("id") in semantic_ids:
|
||||
continue
|
||||
payload = turn.get("payload", {})
|
||||
text = payload.get("text", "")
|
||||
entry_type = payload.get("type", "raw")
|
||||
|
||||
@@ -217,12 +217,13 @@ class TestGetRecentTurns:
|
||||
mock_point2.id = "new"
|
||||
mock_point2.payload = {"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"}
|
||||
|
||||
mock_client.scroll = AsyncMock(return_value=([mock_point1, mock_point2], None))
|
||||
# OrderBy returns server-sorted results (newest first)
|
||||
mock_client.scroll = AsyncMock(return_value=([mock_point2, mock_point1], None))
|
||||
|
||||
results = await svc.get_recent_turns(limit=2)
|
||||
|
||||
assert len(results) == 2
|
||||
# Newest first
|
||||
# Newest first (server-sorted via OrderBy)
|
||||
assert results[0]["id"] == "new"
|
||||
assert results[1]["id"] == "old"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user