Skip to content

Commit 681eb4b

Browse files
committed
feat: add ChromaDB sync hook
1 parent 438a587 commit 681eb4b

File tree

8 files changed

+1911
-35
lines changed

8 files changed

+1911
-35
lines changed

memstate/integrations/chroma.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Any, Callable
2+
3+
from ..constants import Operation
4+
from ..schemas import Fact
5+
6+
# chromadb is an optional dependency; fail with an actionable hint when absent.
try:
    from chromadb import EmbeddingFunction
    from chromadb.api import ClientAPI, Embeddable
except ImportError as exc:
    # Chain explicitly so the original failure (e.g. a broken transitive
    # import) is reported as the direct cause of this error.
    raise ImportError("pip install chromadb") from exc
11+
12+
TextFormatter = Callable[[dict[str, Any]], str]
13+
MetadataFormatter = Callable[[dict[str, Any]], dict[str, Any]]
14+
15+
16+
class ChromaSyncHook:
    """Callable hook that mirrors facts into a ChromaDB collection.

    Instances are invoked as ``hook(op, fact_id, data)``. Commit-like
    operations upsert one document per fact, ``DELETE`` removes the document
    with the same id, and every other operation is ignored.
    """

    def __init__(
        self,
        client: ClientAPI,
        collection_name: str,
        embedding_fn: EmbeddingFunction[Embeddable] | None = None,
        target_types: set[str] | None = None,
        text_field: str | None = None,
        text_formatter: TextFormatter | None = None,
        metadata_fields: list[str] | None = None,
        metadata_formatter: MetadataFormatter | None = None,
    ):
        """
        :param client: Initialized Chroma client.
        :param collection_name: Name of the collection to get or create.
        :param embedding_fn: Embedding function for the collection. If None,
            the argument is not forwarded, so Chroma applies its own default.
        :param target_types: Fact types to synchronize. Empty or None means
            every type is synchronized.
        :param text_field: Payload key whose value (converted with ``str``)
            becomes the document text. Ignored when ``text_formatter`` is set.
        :param text_formatter: Callable building the document text from the
            payload. Takes precedence over ``text_field``. When neither is
            given, ``str(payload)`` is used.
        :param metadata_fields: Payload keys to copy into the metadata.
            Ignored when ``metadata_formatter`` is set.
        :param metadata_formatter: Callable building the extra metadata dict
            from the payload; its result is used as returned.
        """
        self.client = client

        # Forward embedding_function only when one was supplied: an explicit
        # None would replace the library's default argument instead of
        # selecting it.
        collection_kwargs: dict[str, Any] = {"name": collection_name}
        if embedding_fn is not None:
            collection_kwargs["embedding_function"] = embedding_fn
        self.collection = client.get_or_create_collection(**collection_kwargs)

        self.target_types = target_types or set()

        if text_formatter is not None:
            self._extract_text = text_formatter
        elif text_field:
            self._extract_text = lambda data: str(data.get(text_field, ""))
        else:
            self._extract_text = lambda data: str(data)

        self.metadata_fields = metadata_fields or []
        self.metadata_formatter = metadata_formatter

    def _get_metadata(self, data: dict[str, Any]) -> dict[str, Any]:
        """Build the user-defined part of the metadata from a fact payload.

        :param data: Fact payload.
        :return: Result of ``metadata_formatter`` when one is configured.
            Otherwise the configured ``metadata_fields`` that are present and
            not None, with non-primitive values converted via ``str``.
            Empty dict when nothing is configured.
        """
        if self.metadata_formatter is not None:
            return self.metadata_formatter(data)

        meta: dict[str, Any] = {}
        for field in self.metadata_fields:
            val = data.get(field)
            if val is None:
                # Missing and None-valued fields are left out entirely.
                continue
            meta[field] = val if isinstance(val, (str, int, float, bool)) else str(val)
        return meta

    def __call__(self, op: Operation, fact_id: str, data: Fact | None) -> None:
        """Apply a single store operation to the Chroma collection.

        :param op: Operation performed on the fact.
        :param fact_id: Id of the fact; used as the Chroma document id.
        :param data: The fact involved, or None. Not needed for ``DELETE``.
        """
        if op == Operation.DELETE:
            self.collection.delete(ids=[fact_id])
            return

        # Decide whether the operation is handled before touching the payload,
        # so user formatters never run for operations that are ignored anyway
        # (e.g. DISCARD_SESSION).
        if op not in (Operation.COMMIT, Operation.UPDATE, Operation.COMMIT_EPHEMERAL, Operation.PROMOTE):
            return

        if data is None or (self.target_types and data.type not in self.target_types):
            return

        text = self._extract_text(data.payload)

        # Skip facts without usable text; a formatter returning a non-string
        # (e.g. None) is treated the same way instead of failing on .strip().
        if not isinstance(text, str) or not text.strip():
            return

        meta = {"type": data.type, "source": data.source or "", "ts": str(data.ts)}
        # User-defined metadata is applied last, so it overrides the built-in
        # keys on a name clash.
        meta.update(self._get_metadata(data=data.payload))

        self.collection.upsert(ids=[fact_id], documents=[text], metadatas=[meta])

memstate/schemas.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import uuid
2+
from datetime import datetime, timezone
3+
from typing import Any
4+
5+
from pydantic import BaseModel, Field
6+
7+
from .constants import Operation
8+
9+
10+
def _new_fact_id() -> str:
    """Return a freshly generated UUID4 rendered as a string."""
    return str(uuid.uuid4())


def _fact_timestamp() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class Fact(BaseModel):
    """A single stored fact: a typed payload plus bookkeeping fields."""

    # Unique identifier; a random UUID4 string unless given explicitly.
    id: str = Field(default_factory=_new_fact_id)
    # Type name of the fact.
    type: str
    # Arbitrary fact content.
    payload: dict[str, Any]
    # Optional origin of the fact.
    source: str | None = None
    # Optional session identifier.
    session_id: str | None = None
    # Timestamp; the current UTC time unless given explicitly.
    ts: datetime = Field(default_factory=_fact_timestamp)
17+
18+
19+
def _new_tx_uuid() -> str:
    """Return a freshly generated UUID4 rendered as a string."""
    return str(uuid.uuid4())


class TxEntry(BaseModel):
    """One transaction-log record describing an operation on a fact."""

    # Unique identifier of the entry; a random UUID4 string by default.
    uuid: str = Field(default_factory=_new_tx_uuid)
    # Sequence number of the entry.
    seq: int
    # Timestamp of the entry.
    ts: datetime
    # Operation that was performed.
    op: Operation
    # Id of the affected fact, or None.
    fact_id: str | None
    # Fact data before and after the operation, when available.
    fact_before: dict[str, Any] | None = None
    fact_after: dict[str, Any] | None = None
    # Optional actor and free-form reason recorded with the operation.
    actor: str | None = None
    reason: str | None = None

memstate/storage.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,16 @@
11
import copy
22
import threading
3-
import uuid
43
from datetime import datetime, timezone
54
from typing import Any, Callable
65

7-
from pydantic import BaseModel, Field, ValidationError
6+
from pydantic import BaseModel, ValidationError
87

98
from .backends.base import StorageBackend
109
from .constants import Operation
1110
from .exceptions import ConflictError, HookError, MemoryStoreError, ValidationFailed
11+
from .schemas import Fact, TxEntry
1212

13-
14-
class Fact(BaseModel):
15-
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
16-
type: str
17-
payload: dict[str, Any]
18-
source: str | None = None
19-
session_id: str | None = None
20-
ts: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
21-
22-
23-
class TxEntry(BaseModel):
24-
uuid: str = Field(default_factory=lambda: str(uuid.uuid4()))
25-
seq: int
26-
ts: datetime
27-
op: Operation
28-
fact_id: str | None
29-
fact_before: dict[str, Any] | None = None
30-
fact_after: dict[str, Any] | None = None
31-
actor: str | None = None
32-
reason: str | None = None
33-
34-
35-
MemoryHook = Callable[[str, str, dict[str, Any] | None], None]
13+
MemoryHook = Callable[[Operation, str, Fact | None], None]
3614

3715

3816
class SchemaRegistry:
@@ -76,7 +54,7 @@ def register_schema(self, typename: str, model: type[BaseModel], constraint: Con
7654
def add_hook(self, hook: MemoryHook):
7755
self._hooks.append(hook)
7856

79-
def _notify_hooks(self, op: Operation, fact_id: str, data: dict[str, Any] | None) -> None:
57+
def _notify_hooks(self, op: Operation, fact_id: str, data: Fact | None) -> None:
8058
for hook in self._hooks:
8159
try:
8260
hook(op, fact_id, data)
@@ -126,7 +104,7 @@ def commit(
126104

127105
self.storage.save(fact.model_dump())
128106
self._log_tx(op, fact.id, existing, fact.model_dump(), actor, reason)
129-
self._notify_hooks(op, fact.id, fact.model_dump())
107+
self._notify_hooks(op, fact.id, fact)
130108
return fact.id
131109

132110
def update(self, fact_id: str, patch: dict[str, Any], actor: str | None = None, reason: str | None = None) -> str:
@@ -142,7 +120,7 @@ def update(self, fact_id: str, patch: dict[str, Any], actor: str | None = None,
142120

143121
self.storage.save(existing)
144122
self._log_tx(Operation.UPDATE, fact_id, before, existing, actor, reason)
145-
self._notify_hooks(Operation.UPDATE, fact_id, existing)
123+
self._notify_hooks(Operation.UPDATE, fact_id, Fact(**existing))
146124
return fact_id
147125

148126
def delete(self, fact_id: str, actor: str | None = None, reason: str | None = None) -> str:
@@ -153,7 +131,7 @@ def delete(self, fact_id: str, actor: str | None = None, reason: str | None = No
153131

154132
self.storage.delete(fact_id)
155133
self._log_tx(Operation.DELETE, fact_id, existing, None, actor, reason)
156-
self._notify_hooks(Operation.DELETE, fact_id, existing)
134+
self._notify_hooks(Operation.DELETE, fact_id, Fact(**existing))
157135
return fact_id
158136

159137
def get(self, fact_id: str) -> dict[str, Any] | None:
@@ -189,7 +167,7 @@ def promote_session(
189167

190168
promoted.append(fact_dict["id"])
191169
self._log_tx(Operation.PROMOTE, fact_dict["id"], before, fact_dict, actor, reason)
192-
self._notify_hooks(Operation.PROMOTE, fact_dict["id"], fact_dict)
170+
self._notify_hooks(Operation.PROMOTE, fact_dict["id"], Fact(**fact_dict))
193171

194172
return promoted
195173

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
[project.optional-dependencies]
3232
redis = ["redis>=7.1.0"]
3333
langgraph = ["langgraph>=1.0.4"]
34+
chromadb = ["chromadb>=1.3.5"]
3435

3536
[dependency-groups]
3637
dev = [

tests/test_chroma_integration.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import pytest
2+
3+
chromadb = pytest.importorskip("chromadb")
4+
5+
from memstate.constants import Operation
6+
from memstate.integrations.chroma import ChromaSyncHook
7+
from memstate.schemas import Fact
8+
9+
10+
@pytest.fixture
def chroma_client():
    """Provide a Chroma client created via ``chromadb.Client()`` with no arguments."""
    client = chromadb.Client()
    return client
13+
14+
15+
@pytest.fixture
def collection_name():
    """Provide a collection name that is unique for each test.

    Every test previously used the same constant name. Clients created with
    ``chromadb.Client()`` in one process appear to share their in-memory state
    (NOTE(review): confirm for the pinned chromadb version), so records such as
    ``fact_1`` written by one test could be visible to the next. A random
    suffix keeps each test isolated either way.
    """
    import uuid

    return f"test_memstate_sync_{uuid.uuid4().hex[:8]}"
18+
19+
20+
def test_initialization_creates_collection(chroma_client, collection_name):
    """Constructing the hook must create the target collection on the client."""
    ChromaSyncHook(client=chroma_client, collection_name=collection_name)

    existing_names = [coll.name for coll in chroma_client.list_collections()]
    assert collection_name in existing_names
24+
25+
26+
def test_commit_upserts_data(chroma_client, collection_name):
    """A COMMIT stores the configured text field and the fact type."""
    sync = ChromaSyncHook(client=chroma_client, collection_name=collection_name, text_field="content")
    fact = Fact(type="memory", payload={"content": "Hello World"})

    sync(op=Operation.COMMIT, fact_id="fact_1", data=fact)

    stored = chroma_client.get_collection(collection_name).get(ids=["fact_1"])
    first_document = stored["documents"][0]
    first_metadata = stored["metadatas"][0]
    assert first_document == "Hello World"
    assert first_metadata["type"] == "memory"
34+
35+
36+
def test_promote_updates_data(chroma_client, collection_name):
    """A PROMOTE overwrites the document and metadata of an existing record."""
    sync = ChromaSyncHook(
        client=chroma_client,
        collection_name=collection_name,
        text_field="text",
        metadata_fields=["status"],
    )
    collection = chroma_client.get_collection(collection_name)

    # Seed a record that the promotion is expected to replace.
    collection.add(ids=["fact_1"], documents=["Old"], metadatas=[{"status": "draft"}])

    promoted_fact = Fact(type="memory", payload={"text": "New", "status": "committed"})
    sync(op=Operation.PROMOTE, fact_id="fact_1", data=promoted_fact)

    stored = collection.get(ids=["fact_1"])
    assert stored["documents"][0] == "New"
    assert stored["metadatas"][0]["status"] == "committed"
53+
54+
55+
def test_delete_removes_data(chroma_client, collection_name):
    """A DELETE removes the record with the given id from the collection."""
    sync = ChromaSyncHook(client=chroma_client, collection_name=collection_name)
    collection = chroma_client.get_collection(collection_name)
    collection.add(ids=["del_1"], documents=["To delete"])

    sync(op=Operation.DELETE, fact_id="del_1", data=None)

    remaining_ids = collection.get(ids=["del_1"])["ids"]
    assert len(remaining_ids) == 0
65+
66+
67+
def test_discard_session_is_ignored(chroma_client, collection_name):
    """A DISCARD_SESSION leaves existing records untouched."""
    sync = ChromaSyncHook(client=chroma_client, collection_name=collection_name)
    collection = chroma_client.get_collection(collection_name)
    collection.add(ids=["safe_1"], documents=["Stay"])

    sync(op=Operation.DISCARD_SESSION, fact_id="safe_1", data=None)

    surviving_ids = collection.get(ids=["safe_1"])["ids"]
    assert len(surviving_ids) == 1
77+
78+
79+
def test_text_formatter_strategy(chroma_client, collection_name):
    """A custom text formatter decides the stored document text."""

    def render(d):
        return f"{d['key']}: {d['val']}"

    sync = ChromaSyncHook(client=chroma_client, collection_name=collection_name, text_formatter=render)
    fact = Fact(type="memory", payload={"key": "A", "val": "B"})

    sync(op=Operation.COMMIT, fact_id="fmt_1", data=fact)

    stored = chroma_client.get_collection(collection_name).get(ids=["fmt_1"])
    assert stored["documents"][0] == "A: B"
87+
88+
89+
def test_fallback_missing_text_skips_upsert(chroma_client, collection_name):
    """No record is written when the configured text field is absent from the payload."""
    sync = ChromaSyncHook(client=chroma_client, collection_name=collection_name, text_field="missing_field")
    fact = Fact(type="memory", payload={"other": "stuff"})

    sync(op=Operation.COMMIT, fact_id="bad_1", data=fact)

    found_ids = chroma_client.get_collection(collection_name).get(ids=["bad_1"])["ids"]
    assert len(found_ids) == 0

tests/test_e2e_sync.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
2+
3+
chromadb = pytest.importorskip("chromadb")
4+
5+
from memstate.backends.inmemory import InMemoryStorage
6+
from memstate.integrations.chroma import ChromaSyncHook
7+
from memstate.schemas import Fact
8+
from memstate.storage import MemoryStore
9+
10+
11+
def test_e2e_memory_store_syncs_to_chroma():
    """Committing through MemoryStore reaches Chroma via the registered hook."""
    collection_name = "e2e_test"
    chroma_client = chromadb.Client()

    sync = ChromaSyncHook(
        client=chroma_client,
        collection_name=collection_name,
        text_field="content",
        metadata_fields=["role"],
    )

    store = MemoryStore(InMemoryStorage())
    store.add_hook(hook=sync)

    fact = Fact(type="test", payload={"content": "Integration works!", "role": "system"})
    store.commit(fact=fact)

    stored = chroma_client.get_collection(collection_name).get()

    assert len(stored["ids"]) == 1
    assert stored["documents"][0] == "Integration works!"
    assert stored["metadatas"][0]["role"] == "system"

tests/test_storage.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from memstate.backends.inmemory import InMemoryStorage
77
from memstate.exceptions import ConflictError, HookError, MemoryStoreError, ValidationFailed
8-
from memstate.storage import Constraint, Fact, MemoryStore
8+
from memstate.schemas import Fact
9+
from memstate.storage import Constraint, MemoryStore
910

1011

1112
class User(BaseModel):

0 commit comments

Comments
 (0)