"""
Vector store wrapper around Qdrant.

A single collection, "aria_memory", holds ALL memory points. Separation by
type / pinned status is done via payload filters.

Point schema (payload):
  type            — identity | rule | preference | tool | skill | fact |
                    conversation | reminder
  category        — free-form, used for UI grouping
  title           — short heading
  content         — the actual text (this is what gets embedded)
  pinned          — bool, True = hot memory (always in the prompt)
  source          — import | conversation | manual
  tags            — list of strings
  created_at, updated_at — ISO strings
  conversation_id — optional, only for type=conversation
"""

from __future__ import annotations

import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import List, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models as qm

from .embedder import VECTOR_DIM

logger = logging.getLogger(__name__)

# Name of the one Qdrant collection every method in this module operates on.
COLLECTION = "aria_memory"


class MemoryType(str, Enum):
    """Known values for the ``type`` payload field."""

    IDENTITY = "identity"
    RULE = "rule"
    PREFERENCE = "preference"
    TOOL = "tool"
    SKILL = "skill"
    FACT = "fact"
    CONVERSATION = "conversation"
    REMINDER = "reminder"


@dataclass
class MemoryPoint:
    """In-memory representation of one point of the memory collection."""

    id: str
    type: str
    title: str
    content: str
    pinned: bool = False
    category: str = ""
    source: str = "manual"
    tags: List[str] = field(default_factory=list)
    created_at: str = ""
    updated_at: str = ""
    conversation_id: Optional[str] = None
    score: Optional[float] = None  # only set on search results

    def to_payload(self) -> dict:
        """Return the payload dict that is stored alongside the vector.

        ``id`` and ``score`` are not part of the payload.
        ``conversation_id`` is only included when it is truthy.
        """
        payload = {
            "type": self.type,
            "title": self.title,
            "content": self.content,
            "pinned": self.pinned,
            "category": self.category,
            "source": self.source,
            "tags": self.tags,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }
        if self.conversation_id:
            payload["conversation_id"] = self.conversation_id
        return payload

    @classmethod
    def from_qdrant(cls, point) -> "MemoryPoint":
        """Build a ``MemoryPoint`` from a point object returned by the client.

        Missing payload keys fall back to the defaults below. ``score`` is
        read with ``getattr`` because only some result objects carry one.
        """
        payload = point.payload or {}
        return cls(
            id=str(point.id),
            type=payload.get("type", "fact"),
            title=payload.get("title", ""),
            content=payload.get("content", ""),
            pinned=payload.get("pinned", False),
            category=payload.get("category", ""),
            source=payload.get("source", "manual"),
            # `or []` also covers a payload that stores an explicit null,
            # so `tags` is always a list as the annotation promises.
            tags=payload.get("tags") or [],
            created_at=payload.get("created_at", ""),
            updated_at=payload.get("updated_at", ""),
            conversation_id=payload.get("conversation_id"),
            score=getattr(point, "score", None),
        )


def _now() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


class VectorStore:
    """Wrapper around a ``QdrantClient`` bound to the ``COLLECTION`` collection."""

    def __init__(self, host: str, port: int = 6333):
        """Connect to Qdrant at ``host``:``port`` and ensure the collection exists."""
        self.client = QdrantClient(host=host, port=port)
        self._ensure_collection()

    def _ensure_collection(self) -> None:
        """Create the collection and its payload indexes if it does not exist."""
        existing = [c.name for c in self.client.get_collections().collections]
        if COLLECTION not in existing:
            logger.info("Erstelle Collection %s ...", COLLECTION)
            self.client.create_collection(
                collection_name=COLLECTION,
                vectors_config=qm.VectorParams(
                    size=VECTOR_DIM, distance=qm.Distance.COSINE
                ),
            )
            # Indexes for the fields that are typically filtered on.
            # NOTE(review): indexes are only created together with the
            # collection; a field added to this tuple later is not indexed
            # on a pre-existing collection — confirm this is intended.
            for field_name in ("type", "pinned", "category", "source", "migration_key"):
                self.client.create_payload_index(
                    collection_name=COLLECTION,
                    field_name=field_name,
                    field_schema=(
                        qm.PayloadSchemaType.KEYWORD
                        if field_name != "pinned"
                        else qm.PayloadSchemaType.BOOL
                    ),
                )

    # ─── Write operations ────────────────────────────────────────────

    def upsert(self, point: MemoryPoint, vector: List[float]) -> str:
        """Insert or overwrite ``point`` with ``vector`` and return its id.

        Mutates ``point``: a missing ``id`` is replaced by a fresh UUID4, a
        missing ``created_at`` is set, and ``updated_at`` is always refreshed.
        """
        # One timestamp for both fields, so a brand-new point has
        # created_at == updated_at.
        now = _now()
        if not point.id:
            point.id = str(uuid.uuid4())
        if not point.created_at:
            point.created_at = now
        point.updated_at = now
        self.client.upsert(
            collection_name=COLLECTION,
            points=[
                qm.PointStruct(id=point.id, vector=vector, payload=point.to_payload())
            ],
        )
        return point.id

    def delete(self, point_id: str) -> None:
        """Delete the point with the given id from the collection."""
        self.client.delete(
            collection_name=COLLECTION,
            points_selector=qm.PointIdsList(points=[point_id]),
        )

    # ─── Read operations ─────────────────────────────────────────────

    def get(self, point_id: str) -> Optional[MemoryPoint]:
        """Return the point with the given id, or ``None`` if nothing is found."""
        result = self.client.retrieve(
            collection_name=COLLECTION, ids=[point_id], with_payload=True
        )
        if not result:
            return None
        return MemoryPoint.from_qdrant(result[0])

    def list_pinned(self) -> List[MemoryPoint]:
        """Return all pinned points — the hot memory."""
        return self._scroll(
            filter=qm.Filter(
                must=[qm.FieldCondition(key="pinned", match=qm.MatchValue(value=True))]
            )
        )

    def list_by_type(self, type_: str, limit: int = 100) -> List[MemoryPoint]:
        """Return up to ``limit`` points whose ``type`` payload equals ``type_``."""
        return self._scroll(
            filter=qm.Filter(
                must=[qm.FieldCondition(key="type", match=qm.MatchValue(value=type_))]
            ),
            limit=limit,
        )

    def list_all(self, limit: int = 1000) -> List[MemoryPoint]:
        """Return up to ``limit`` points without any filter."""
        return self._scroll(filter=None, limit=limit)

    def _scroll(self, filter, limit: int = 1000) -> List[MemoryPoint]:
        """Run one scroll request and convert the returned points.

        Only a single page is requested; the next-page offset returned by
        the client is ignored, so at most ``limit`` points come back.
        """
        points, _ = self.client.scroll(
            collection_name=COLLECTION,
            scroll_filter=filter,
            limit=limit,
            with_payload=True,
            with_vectors=False,
        )
        return [MemoryPoint.from_qdrant(p) for p in points]

    def search(
        self,
        query_vector: List[float],
        k: int = 5,
        type_filter: Optional[str] = None,
        exclude_pinned: bool = True,
        score_threshold: Optional[float] = None,
    ) -> List[MemoryPoint]:
        """Semantic search.

        By default pinned points are excluded (they are put into the prompt
        separately via ``list_pinned``).

        score_threshold: only return hits with cosine similarity >= threshold.
        None = no filtering. MiniLM-multilingual typically yields 0.3-0.6 for
        relevant hits; <0.25 is noise.
        """
        must = []
        must_not = []
        if type_filter:
            must.append(
                qm.FieldCondition(key="type", match=qm.MatchValue(value=type_filter))
            )
        if exclude_pinned:
            must_not.append(
                qm.FieldCondition(key="pinned", match=qm.MatchValue(value=True))
            )

        # Only build a filter object when there is at least one condition.
        flt = None
        if must or must_not:
            flt = qm.Filter(must=must or None, must_not=must_not or None)

        results = self.client.search(
            collection_name=COLLECTION,
            query_vector=query_vector,
            query_filter=flt,
            limit=k,
            with_payload=True,
            score_threshold=score_threshold,
        )
        return [MemoryPoint.from_qdrant(p) for p in results]

    def count(self) -> int:
        """Return the exact number of points in the collection."""
        return self.client.count(collection_name=COLLECTION, exact=True).count