6549fcbce8
Stefan wollte ne richtige Suche statt nur "klingt aehnlich". Beide Modi sind jetzt verfuegbar, Default ist Volltext: - 📝 Wortlich (Substring, case-insensitive ueber Title + Content + Category + Tags) — neuer Endpoint /memory/search-text. Full-Scan via Qdrant scroll, k=50. Findet "cessna" exakt im Content. Bei kleiner DB (<1000 Eintraege) unkritisch performant. - 🧠 Semantisch (Embedder + score_threshold 0.30) — bestehender /memory/search Endpoint. Findet konzeptuell verwandte Eintraege. Diagnostic UI: Dropdown neben dem Suchfeld zum Modus-Wechsel. Info-Banner zeigt klar welcher Modus aktiv ist. Warum Wortlich Default: bei kleiner DB liefert Semantic gern False Positives mit Score 0.30-0.45 fuer komplett unverwandte Begriffe (z.B. "cessna" matched "Tageslog fuehren" mit 0.43). Wortlich ist deterministisch und vermeidet das Rauschen. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
269 lines
9.3 KiB
Python
269 lines
9.3 KiB
Python
"""
|
|
Vector-Store-Wrapper um Qdrant.
|
|
|
|
Eine Collection "aria_memory" haelt ALLE Memory-Punkte.
|
|
Trennung nach Type/Pinned-Status via Payload-Filter.
|
|
|
|
Punkt-Schema (Payload):
|
|
type — identity | rule | preference | tool | skill | fact | conversation | reminder
|
|
category — frei, fuer UI-Gruppierung
|
|
title — kurze Ueberschrift
|
|
content — eigentlicher Text (wird embedded)
|
|
pinned — bool, True = Hot Memory (immer in Prompt)
|
|
source — import | conversation | manual
|
|
tags — Liste von Strings
|
|
created_at, updated_at — ISO-Strings
|
|
conversation_id — optional, nur fuer type=conversation
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import List, Optional
|
|
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http import models as qm
|
|
|
|
from .embedder import VECTOR_DIM
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
COLLECTION = "aria_memory"
|
|
|
|
|
|
class MemoryType(str, Enum):
    """Closed set of values for the "type" payload field of a memory point."""

    IDENTITY = "identity"
    RULE = "rule"
    PREFERENCE = "preference"
    TOOL = "tool"
    SKILL = "skill"
    FACT = "fact"
    CONVERSATION = "conversation"
    REMINDER = "reminder"
|
|
|
|
|
|
@dataclass
class MemoryPoint:
    """One memory entry: the Qdrant payload fields plus id and search score."""

    id: str
    type: str
    title: str
    content: str
    pinned: bool = False
    category: str = ""
    source: str = "manual"
    tags: List[str] = field(default_factory=list)
    created_at: str = ""
    updated_at: str = ""
    conversation_id: Optional[str] = None
    score: Optional[float] = None  # set only on search results

    def to_payload(self) -> dict:
        """Serialize the fields that belong into the Qdrant payload.

        conversation_id is included only when set (truthy).
        """
        keys = (
            "type", "title", "content", "pinned", "category",
            "source", "tags", "created_at", "updated_at",
        )
        payload = {key: getattr(self, key) for key in keys}
        if self.conversation_id:
            payload["conversation_id"] = self.conversation_id
        return payload

    @classmethod
    def from_qdrant(cls, point) -> "MemoryPoint":
        """Build a MemoryPoint from a Qdrant record or scored search hit.

        Missing payload fields fall back to sensible defaults; `score`
        exists only on search hits, hence the getattr fallback.
        """
        data = point.payload or {}
        return cls(
            id=str(point.id),
            type=data.get("type", "fact"),
            title=data.get("title", ""),
            content=data.get("content", ""),
            pinned=data.get("pinned", False),
            category=data.get("category", ""),
            source=data.get("source", "manual"),
            tags=data.get("tags", []),
            created_at=data.get("created_at", ""),
            updated_at=data.get("updated_at", ""),
            conversation_id=data.get("conversation_id"),
            score=getattr(point, "score", None),
        )
|
|
|
|
|
|
def _now() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    stamp = datetime.now(tz=timezone.utc)
    return stamp.isoformat()
|
|
|
|
|
|
class VectorStore:
    """Thin wrapper around Qdrant for the single "aria_memory" collection.

    All memory points live in one collection; separation by type/pinned
    status is done with payload filters (see module docstring for schema).
    """

    def __init__(self, host: str, port: int = 6333):
        self.client = QdrantClient(host=host, port=port)
        self._ensure_collection()

    def _ensure_collection(self):
        """Create the collection and its payload indexes on first start.

        NOTE(review): indexes are created only together with the collection;
        if the index list below grows later, existing deployments will not
        pick up the new entries automatically — confirm whether a migration
        step is needed.
        """
        existing = [c.name for c in self.client.get_collections().collections]
        if COLLECTION not in existing:
            logger.info("Erstelle Collection %s ...", COLLECTION)
            self.client.create_collection(
                collection_name=COLLECTION,
                vectors_config=qm.VectorParams(size=VECTOR_DIM, distance=qm.Distance.COSINE),
            )
            # Indexes for the fields we typically filter on.
            for field_name in ("type", "pinned", "category", "source", "migration_key"):
                self.client.create_payload_index(
                    collection_name=COLLECTION,
                    field_name=field_name,
                    field_schema=qm.PayloadSchemaType.KEYWORD if field_name != "pinned"
                    else qm.PayloadSchemaType.BOOL,
                )

    # ─── Write operations ────────────────────────────────────────────

    def upsert(self, point: MemoryPoint, vector: List[float]) -> str:
        """Insert or update *point* together with its embedding vector.

        Fills in a UUID id and created_at when missing and always bumps
        updated_at. Returns the (possibly newly generated) point id.
        """
        # One timestamp for both fields so a first insert gets
        # created_at == updated_at (two _now() calls would differ slightly).
        now = _now()
        if not point.id:
            point.id = str(uuid.uuid4())
        if not point.created_at:
            point.created_at = now
        point.updated_at = now

        self.client.upsert(
            collection_name=COLLECTION,
            points=[qm.PointStruct(id=point.id, vector=vector, payload=point.to_payload())],
        )
        return point.id

    def delete(self, point_id: str):
        """Delete the point with the given id from the collection."""
        self.client.delete(
            collection_name=COLLECTION,
            points_selector=qm.PointIdsList(points=[point_id]),
        )

    # ─── Read operations ─────────────────────────────────────────────

    def get(self, point_id: str) -> Optional[MemoryPoint]:
        """Fetch a single point by id, or None when the id is unknown."""
        result = self.client.retrieve(collection_name=COLLECTION, ids=[point_id], with_payload=True)
        if not result:
            return None
        return MemoryPoint.from_qdrant(result[0])

    def list_pinned(self) -> List[MemoryPoint]:
        """All pinned points — hot memory that always goes into the prompt."""
        return self._scroll(scroll_filter=qm.Filter(must=[
            qm.FieldCondition(key="pinned", match=qm.MatchValue(value=True))
        ]))

    def list_by_type(self, type_: str, limit: int = 100) -> List[MemoryPoint]:
        """All points of one memory type, up to *limit*."""
        return self._scroll(
            scroll_filter=qm.Filter(must=[
                qm.FieldCondition(key="type", match=qm.MatchValue(value=type_))
            ]),
            limit=limit,
        )

    def list_all(self, limit: int = 1000) -> List[MemoryPoint]:
        """Every point in the collection, up to *limit*."""
        return self._scroll(scroll_filter=None, limit=limit)

    def _scroll(self, scroll_filter, limit: int = 1000) -> List[MemoryPoint]:
        """Return one scroll page (up to *limit* points) matching the filter.

        Parameter renamed from `filter` to avoid shadowing the builtin.
        """
        points, _ = self.client.scroll(
            collection_name=COLLECTION,
            scroll_filter=scroll_filter,
            limit=limit,
            with_payload=True,
            with_vectors=False,
        )
        return [MemoryPoint.from_qdrant(p) for p in points]

    def search(
        self,
        query_vector: List[float],
        k: int = 5,
        type_filter: Optional[str] = None,
        exclude_pinned: bool = True,
        score_threshold: Optional[float] = None,
    ) -> List[MemoryPoint]:
        """Semantic search. By default pinned points are excluded
        (those reach the prompt separately via list_pinned).

        score_threshold: only return hits with cosine similarity >= the
        threshold; None disables the cut-off. MiniLM-multilingual typically
        yields 0.3-0.6 for relevant hits; below 0.25 is noise.
        """
        must = []
        must_not = []
        if type_filter:
            must.append(qm.FieldCondition(key="type", match=qm.MatchValue(value=type_filter)))
        if exclude_pinned:
            must_not.append(qm.FieldCondition(key="pinned", match=qm.MatchValue(value=True)))

        flt = qm.Filter(must=must or None, must_not=must_not or None)

        results = self.client.search(
            collection_name=COLLECTION,
            query_vector=query_vector,
            query_filter=flt if (must or must_not) else None,
            limit=k,
            with_payload=True,
            score_threshold=score_threshold,
        )
        return [MemoryPoint.from_qdrant(p) for p in results]

    def count(self) -> int:
        """Exact number of points in the collection."""
        return self.client.count(collection_name=COLLECTION, exact=True).count

    def search_text(
        self,
        query: str,
        k: int = 20,
        type_filter: Optional[str] = None,
        exclude_pinned: bool = False,
    ) -> List[MemoryPoint]:
        """Case-insensitive substring search over title + content +
        category + tags. Unlike search() this is NOT a semantic match —
        only exact word/sub-word hits.

        Full scan over all (filtered) points, paged 200 at a time. Fine
        at the expected collection size (< 1000).
        """
        q = (query or "").strip().lower()
        if not q:
            return []
        must = []
        must_not = []
        if type_filter:
            must.append(qm.FieldCondition(key="type", match=qm.MatchValue(value=type_filter)))
        if exclude_pinned:
            must_not.append(qm.FieldCondition(key="pinned", match=qm.MatchValue(value=True)))
        flt = qm.Filter(must=must or None, must_not=must_not or None) if (must or must_not) else None

        matches: List[MemoryPoint] = []
        offset = None
        while True:
            points, offset = self.client.scroll(
                collection_name=COLLECTION,
                scroll_filter=flt,
                limit=200,
                offset=offset,
                with_payload=True,
                with_vectors=False,
            )
            for p in points:
                payload = p.payload or {}
                tags = payload.get("tags")
                # Coerce each tag to str so a stray non-string tag cannot
                # crash the whole scan.
                tags_str = " ".join(str(t) for t in tags) if isinstance(tags, list) else ""
                haystack = " ".join([
                    str(payload.get("title", "")),
                    str(payload.get("content", "")),
                    str(payload.get("category", "")),
                    tags_str,
                ]).lower()
                if q in haystack:
                    matches.append(MemoryPoint.from_qdrant(p))
                    if len(matches) >= k:
                        return matches
            # Compare against None explicitly: a numeric next-page offset of
            # 0 is falsy but still means "more pages available".
            if offset is None:
                break
        return matches
|