Commit 269a3ef

feat: add qdrant hook
1 parent 87fdcd8 commit 269a3ef

File tree

5 files changed: +1106 -407 lines changed


memstate/integrations/qdrant.py

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
from typing import Any, Callable

from memstate.constants import Operation
from memstate.schemas import Fact

try:
    from qdrant_client import QdrantClient, models
except ImportError:
    raise ImportError("To use QdrantSyncHook, run: pip install qdrant-client")

TextFormatter = Callable[[dict[str, Any]], str]
MetadataFormatter = Callable[[dict[str, Any]], dict[str, Any]]
EmbeddingFunction = Callable[[str], list[float]]


class FastEmbedEncoder:
    """
    Default embedding implementation using FastEmbed.
    Used if no custom embedding_fn is provided.
    """

    def __init__(
        self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", options: dict[str, Any] | None = None
    ):
        try:
            from fastembed import TextEmbedding
        except ImportError:
            raise ImportError(
                "FastEmbed is not installed. Install it via `pip install fastembed` or pass a custom `embedding_fn`."
            )
        self.model = TextEmbedding(model_name, **(options or {}))

    def __call__(self, text: str) -> list[float]:
        return list(self.model.embed(text))[0].tolist()


class QdrantSyncHook:
    """
    Hook that mirrors facts into a Qdrant collection.

    Example with a custom FastEmbed model:

        encoder = FastEmbedEncoder(
            model_name="BAAI/bge-small-en-v1.5",
            options={"cuda": True},
        )
        hook = QdrantSyncHook(client, "memory", embedding_fn=encoder)

    Example with an OpenAI embedding function:

        def openai_embedder(text: str) -> list[float]:
            resp = openai.embeddings.create(input=text, model="text-embedding-3-small")
            return resp.data[0].embedding

        hook = QdrantSyncHook(client, "memory", embedding_fn=openai_embedder)
    """

    def __init__(
        self,
        client: QdrantClient,
        collection_name: str,
        embedding_fn: EmbeddingFunction | None = None,
        target_types: set[str] | None = None,
        text_field: str | None = None,
        text_formatter: TextFormatter | None = None,
        metadata_fields: list[str] | None = None,
        metadata_formatter: MetadataFormatter | None = None,
        distance: models.Distance = models.Distance.COSINE,
    ) -> None:
        self.client = client
        self.collection_name = collection_name

        self.embedding_fn = embedding_fn or FastEmbedEncoder()

        self.target_types = target_types or set()
        self.distance = distance

        if text_formatter is not None:
            self._extract_text = text_formatter
        elif text_field:
            self._extract_text = lambda data: str(data.get(text_field, ""))
        else:
            self._extract_text = lambda data: str(data)

        self.metadata_fields = metadata_fields or []
        self.metadata_formatter = metadata_formatter

        self._ensure_collection()

    def _ensure_collection(self) -> None:
        """
        Auto-detects vector size by running a dummy embedding
        and ensures the collection exists.
        """
        try:
            dummy_vec = self.embedding_fn("test")
            vector_size = len(dummy_vec)
        except Exception as e:
            raise RuntimeError(f"Failed to initialize embedding function: {e}")

        if not self.client.collection_exists(self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(size=vector_size, distance=self.distance),
            )
        else:
            coll_info = self.client.get_collection(self.collection_name)
            config = coll_info.config.params.vectors

            existing_size = None
            if isinstance(config, models.VectorParams):
                existing_size = config.size
            elif isinstance(config, dict) and "" in config:  # Default unnamed vector
                existing_size = config[""].size

            if existing_size and existing_size != vector_size:
                raise ValueError(
                    f"Collection '{self.collection_name}' expects vector size {existing_size}, "
                    f"but your embedding function produces {vector_size}."
                )

    def _get_metadata(self, data: dict[str, Any]) -> dict[str, Any]:
        if self.metadata_formatter is not None:
            return self.metadata_formatter(data)

        if self.metadata_fields:
            meta = {}
            for field in self.metadata_fields:
                val = data.get(field)
                if val is not None:
                    if isinstance(val, (str, int, float, bool, list)):
                        meta[field] = val
                    else:
                        meta[field] = str(val)
            return meta

        return {}

    def __call__(self, op: Operation, fact_id: str, data: Fact | None) -> None:
        # Deletions remove the corresponding point from the collection.
        if op == Operation.DELETE:
            self.client.delete(collection_name=self.collection_name, points_selector=[fact_id])
            return

        # Discarding a session leaves the collection untouched.
        if op == Operation.DISCARD_SESSION:
            return

        if not data or (self.target_types and data.type not in self.target_types):
            return

        if op in (Operation.COMMIT, Operation.UPDATE, Operation.COMMIT_EPHEMERAL, Operation.PROMOTE):
            text = self._extract_text(data.payload)
            if not text.strip():
                return

            vector = self.embedding_fn(text)

            meta = {"type": data.type, "source": data.source or "", "ts": str(data.ts), "document": text}
            user_meta = self._get_metadata(data=data.payload)
            meta.update(user_meta)

            self.client.upsert(
                collection_name=self.collection_name,
                points=[models.PointStruct(id=fact_id, vector=vector, payload=meta)],
            )
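
A minimal wiring sketch (not part of this commit), based on the e2e test added below: the Qdrant client runs in-memory, no embedding_fn is passed so the default FastEmbedEncoder is used, and the collection name "memory" and fact type "note" are illustrative.

from qdrant_client import QdrantClient

from memstate import Fact, InMemoryStorage, MemoryStore
from memstate.integrations.qdrant import QdrantSyncHook

client = QdrantClient(":memory:")
hook = QdrantSyncHook(client, "memory", text_field="content", metadata_fields=["role"])

store = MemoryStore(InMemoryStorage())
store.add_hook(hook=hook)

# Every committed fact is embedded and upserted into the "memory" collection.
fact_id = store.commit(fact=Fact(type="note", payload={"content": "Qdrant sync works", "role": "system"}))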

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -48,13 +48,15 @@ dependencies = [
 redis = ["redis>=7.1.0"]
 langgraph = ["langgraph>=1.0.4"]
 chromadb = ["chromadb>=1.3.5"]
+qdrant = ["qdrant-client>=1.16.2"]
 postgres = ["sqlalchemy>=2.0.0", "psycopg[binary]>=3.3.2"]
 
 [dependency-groups]
 dev = [
     "bandit>=1.9.1",
     "black>=25.11.0",
     "fakeredis>=2.32.1",
+    "fastembed>=0.7.4",
     "isort>=7.0.0",
     "mypy>=1.18.2",
     "pre-commit>=4.4.0",

tests/test_e2e_sync.py

Lines changed: 24 additions & 0 deletions
@@ -1,9 +1,11 @@
 import pytest
 
 chromadb = pytest.importorskip("chromadb")
+qdrant_client = pytest.importorskip("qdrant_client")
 
 from memstate import Fact, InMemoryStorage, MemoryStore
 from memstate.integrations.chroma import ChromaSyncHook
+from memstate.integrations.qdrant import QdrantSyncHook
 
 
 def test_e2e_memory_store_syncs_to_chroma():
@@ -25,3 +27,25 @@ def test_e2e_memory_store_syncs_to_chroma():
     assert len(results["ids"]) == 1
     assert results["documents"][0] == "Integration works!"
     assert results["metadatas"][0]["role"] == "system"
+
+
+def test_e2e_memory_store_syncs_to_qdrant():
+    client = qdrant_client.QdrantClient(":memory:")
+    collection_name = "e2e_test"
+
+    hook = QdrantSyncHook(
+        client=client, collection_name=collection_name, text_field="content", metadata_fields=["role"]
+    )
+
+    store = MemoryStore(InMemoryStorage())
+    store.add_hook(hook=hook)
+
+    fact_id = store.commit(fact=Fact(type="test", payload={"content": "Integration works!", "role": "system"}))
+
+    points, _ = client.scroll(collection_name=collection_name, limit=10)
+
+    assert len(points) == 1
+    point = points[0]
+    assert point.id == fact_id
+    assert point.payload["document"] == "Integration works!"
+    assert point.payload["role"] == "system"

tests/test_qdrant_integration.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
import uuid

import pytest

qdrant_client = pytest.importorskip("qdrant_client")

from memstate import Fact, Operation
from memstate.integrations.qdrant import QdrantSyncHook


@pytest.fixture
def client():
    return qdrant_client.QdrantClient(":memory:")


@pytest.fixture
def collection_name():
    return "test_memstate_sync"


@pytest.fixture
def fact_id():
    return str(uuid.uuid4())


def test_initialization_creates_collection(client, collection_name):
    QdrantSyncHook(client=client, collection_name=collection_name)
    collections = client.get_collections()
    assert any(c.name == collection_name for c in collections.collections)


def test_commit_upserts_data(client, collection_name, fact_id):
    hook = QdrantSyncHook(client=client, collection_name=collection_name, text_field="content")
    hook(op=Operation.COMMIT, fact_id=fact_id, data=Fact(type="memory", payload={"content": "Hello World"}))

    points, _ = client.scroll(collection_name=collection_name, limit=10)
    point = points[0]
    assert point.id == fact_id
    assert point.payload["document"] == "Hello World"
    assert point.payload["type"] == "memory"


def test_promote_updates_data(client, collection_name, fact_id):
    hook = QdrantSyncHook(client=client, collection_name=collection_name, text_field="text", metadata_fields=["status"])

    # Pre-seed
    client.upsert(
        collection_name=collection_name,
        points=[
            qdrant_client.models.PointStruct(
                id=fact_id,
                vector=qdrant_client.models.Document(
                    text="Old",
                    model="sentence-transformers/all-MiniLM-L6-v2",
                ),
                payload={"status": "draft"},
            )
        ],
    )

    # Promote
    hook(
        op=Operation.PROMOTE, fact_id=fact_id, data=Fact(type="memory", payload={"text": "New", "status": "committed"})
    )

    points, _ = client.scroll(collection_name=collection_name, limit=10)
    point = points[0]
    assert point.id == fact_id
    assert point.payload["document"] == "New"
    assert point.payload["status"] == "committed"


def test_delete_removes_data(client, collection_name, fact_id):
    hook = QdrantSyncHook(client=client, collection_name=collection_name)

    client.upsert(
        collection_name=collection_name,
        points=[
            qdrant_client.models.PointStruct(
                id=fact_id,
                vector=qdrant_client.models.Document(
                    text="To delete",
                    model="sentence-transformers/all-MiniLM-L6-v2",
                ),
            )
        ],
    )

    hook(op=Operation.DELETE, fact_id=fact_id, data=None)

    points, _ = client.scroll(collection_name=collection_name, limit=10)
    assert len(points) == 0


def test_discard_session_is_ignored(client, collection_name, fact_id):
    hook = QdrantSyncHook(client=client, collection_name=collection_name)

    client.upsert(
        collection_name=collection_name,
        points=[
            qdrant_client.models.PointStruct(
                id=fact_id,
                vector=qdrant_client.models.Document(
                    text="Stay",
                    model="sentence-transformers/all-MiniLM-L6-v2",
                ),
            )
        ],
    )

    hook(op=Operation.DISCARD_SESSION, fact_id=fact_id, data=None)

    points, _ = client.scroll(collection_name=collection_name, limit=10)
    assert len(points) == 1


def test_text_formatter_strategy(client, collection_name, fact_id):
    hook = QdrantSyncHook(
        client=client, collection_name=collection_name, text_formatter=lambda d: f"{d['key']}: {d['val']}"
    )
    hook(op=Operation.COMMIT, fact_id=fact_id, data=Fact(type="memory", payload={"key": "A", "val": "B"}))

    points, _ = client.scroll(collection_name=collection_name, limit=10)
    point = points[0]
    assert point.id == fact_id
    assert point.payload["document"] == "A: B"


def test_fallback_missing_text_skips_upsert(client, collection_name, fact_id):
    hook = QdrantSyncHook(client=client, collection_name=collection_name, text_field="missing_field")
    hook(op=Operation.COMMIT, fact_id=fact_id, data=Fact(type="memory", payload={"other": "stuff"}))

    points, _ = client.scroll(collection_name=collection_name, limit=10)
    assert len(points) == 0
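
Once facts have been synced, the collection can be queried back. A sketch (not part of this commit) that reuses the client and collection_name fixtures above, assuming qdrant-client's query_points API and the hook's default FastEmbedEncoder.

from memstate.integrations.qdrant import FastEmbedEncoder

encoder = FastEmbedEncoder()  # same default model the hook falls back to
query_vector = encoder("hello")

# Nearest-neighbour search over the synced facts.
hits = client.query_points(
    collection_name=collection_name,
    query=query_vector,
    limit=3,
).points

for hit in hits:
    # "document" carries the text the hook extracted at commit time.
    print(hit.score, hit.payload["document"])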
