diff --git a/app/main.py b/app/main.py index 1334765..462b4ae 100644 --- a/app/main.py +++ b/app/main.py @@ -68,7 +68,7 @@ async def lifespan(app: FastAPI): await qdrant_service.close() -app = FastAPI(title="Vera-AI", version="2.0.0", lifespan=lifespan) +app = FastAPI(title="Vera-AI", version="2.0.4", lifespan=lifespan) @app.get("/") diff --git a/app/qdrant_service.py b/app/qdrant_service.py index 90890f4..4b98d63 100644 --- a/app/qdrant_service.py +++ b/app/qdrant_service.py @@ -1,6 +1,6 @@ """Qdrant service for memory storage - ASYNC VERSION.""" from qdrant_client import AsyncQdrantClient -from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue +from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue, PayloadSchemaType from typing import List, Dict, Any, Optional from datetime import datetime, timezone import uuid @@ -34,6 +34,15 @@ class QdrantService: vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE) ) logger.info(f"Created collection {self.collection} with vector size {self.vector_size}") + # Ensure payload index on timestamp for ordered scroll + try: + await self.client.create_payload_index( + collection_name=self.collection, + field_name="timestamp", + field_schema=PayloadSchemaType.KEYWORD + ) + except Exception: + pass # Index may already exist self._collection_ensured = True async def get_embedding(self, text: str) -> List[float]: @@ -105,20 +114,28 @@ class QdrantService: ) return point_id - async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated") -> List[Dict]: - """Semantic search for relevant turns, filtered by type.""" + async def semantic_search(self, query: str, limit: int = 10, score_threshold: float = 0.6, entry_type: str = "curated", entry_types: Optional[List[str]] = None) -> List[Dict]: + """Semantic search for relevant turns, filtered by type(s).""" await self._ensure_collection() - + embedding = await self.get_embedding(query) - + + if entry_types and len(entry_types) > 1: + type_filter = Filter( + should=[FieldCondition(key="type", match=MatchValue(value=t)) for t in entry_types] + ) + else: + filter_type = entry_types[0] if entry_types else entry_type + type_filter = Filter( + must=[FieldCondition(key="type", match=MatchValue(value=filter_type))] + ) + results = await self.client.query_points( collection_name=self.collection, query=embedding, limit=limit, score_threshold=score_threshold, - query_filter=Filter( - must=[FieldCondition(key="type", match=MatchValue(value=entry_type))] - ) + query_filter=type_filter ) return [{"id": str(r.id), "score": r.score, "payload": r.payload} for r in results.points] @@ -126,21 +143,29 @@ class QdrantService: async def get_recent_turns(self, limit: int = 20) -> List[Dict]: """Get recent turns from Qdrant (both raw and curated).""" await self._ensure_collection() - - points, _ = await self.client.scroll( - collection_name=self.collection, - limit=limit * 2, - with_payload=True - ) - - # Sort by timestamp descending - sorted_points = sorted( - points, - key=lambda p: p.payload.get("timestamp", ""), - reverse=True - ) - - return [{"id": str(p.id), "payload": p.payload} for p in sorted_points[:limit]] + + try: + from qdrant_client.models import OrderBy + points, _ = await self.client.scroll( + collection_name=self.collection, + limit=limit, + with_payload=True, + order_by=OrderBy(key="timestamp", direction="desc") + ) + except Exception: + # Fallback: fetch extra points and sort client-side + points, _ = await self.client.scroll( + collection_name=self.collection, + limit=limit * 5, + with_payload=True + ) + points = sorted( + points, + key=lambda p: p.payload.get("timestamp", ""), + reverse=True + )[:limit] + + return [{"id": str(p.id), "payload": p.payload} for p in points] async def delete_points(self, point_ids: List[str]) -> None: """Delete points by ID.""" diff --git a/app/utils.py b/app/utils.py index dfbf3f7..8882542 100644 --- a/app/utils.py +++ b/app/utils.py @@ -213,23 +213,25 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]: messages.append({"role": "system", "content": system_content}) logger.info(f"Layer 1 (system): {count_tokens(system_content)} tokens") - # === LAYER 2: Semantic (curated memories) === + # === LAYER 2: Semantic (curated + raw memories) === qdrant = get_qdrant_service() semantic_results = await qdrant.semantic_search( query=search_context if search_context else user_question, limit=20, score_threshold=config.semantic_score_threshold, - entry_type="curated" + entry_types=["curated", "raw"] ) semantic_messages = [] semantic_tokens_used = 0 - + semantic_ids = set() + for result in semantic_results: + semantic_ids.add(result.get("id")) payload = result.get("payload", {}) text = payload.get("text", "") if text: - # Parse curated turn into proper user/assistant messages + # Parse curated/raw turn into proper user/assistant messages parsed = parse_curated_turn(text) for msg in parsed: msg_tokens = count_tokens(msg.get("content", "")) @@ -254,8 +256,10 @@ async def build_augmented_messages(incoming_messages: List[Dict]) -> List[Dict]: context_messages = [] context_tokens_used = 0 - # Process oldest first for chronological order + # Process oldest first for chronological order, skip duplicates from Layer 2 for turn in reversed(recent_turns): + if turn.get("id") in semantic_ids: + continue payload = turn.get("payload", {}) text = payload.get("text", "") entry_type = payload.get("type", "raw") diff --git a/tests/test_qdrant_service.py b/tests/test_qdrant_service.py index afbe963..80076b4 100644 --- a/tests/test_qdrant_service.py +++ b/tests/test_qdrant_service.py @@ -217,12 +217,13 @@ class TestGetRecentTurns: mock_point2.id = "new" mock_point2.payload = {"timestamp": "2026-03-01T00:00:00Z", "text": "new turn"} - mock_client.scroll = AsyncMock(return_value=([mock_point1, mock_point2], None)) + # OrderBy returns server-sorted results (newest first) + mock_client.scroll = AsyncMock(return_value=([mock_point2, mock_point1], None)) results = await svc.get_recent_turns(limit=2) assert len(results) == 2 - # Newest first + # Newest first (server-sorted via OrderBy) assert results[0]["id"] == "new" assert results[1]["id"] == "old"