Skip to content

Commit d422653

Browse files
committed
fix: atomic consistency in commit method
1 parent 7a2c943 commit d422653

File tree

3 files changed

+34
-21
lines changed

3 files changed

+34
-21
lines changed

examples/rag_hook_demo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import sys
3-
from typing import Any, Dict, List
3+
from typing import Dict, List
44

55
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
66

@@ -45,11 +45,11 @@ def __init__(self, vector_db: MockVectorDB):
4545
# We are only interested in facts of the "knowledge_base" type.
4646
self.target_types = {"knowledge_base", "chat_log"}
4747

48-
def __call__(self, op: str, fact_id: str, data: Dict[str, Any] | None):
48+
def __call__(self, op: str, fact_id: str, data: Fact | None):
4949
# data is the state of the fact. If DELETE, this is the state BEFORE deletion.
5050

5151
# Type checking (do not vectorize system data)
52-
if not data or data.get("type") not in self.target_types:
52+
if not data or data.type not in self.target_types:
5353
return
5454

5555
# Processing of deletion
@@ -58,7 +58,7 @@ def __call__(self, op: str, fact_id: str, data: Dict[str, Any] | None):
5858
return
5959

6060
# Text extraction for vectorization
61-
payload = data.get("payload", {})
61+
payload = data.payload
6262
# Trying to find a text field
6363
text_content = payload.get("content") or payload.get("message") or payload.get("summary")
6464

memstate/storage.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,17 @@ def commit(
6868
ephemeral: bool = False,
6969
actor: str | None = None,
7070
reason: str | None = None,
71-
):
71+
) -> str:
7272
with self._lock:
7373
validated_payload = self._schema_registry.validate(fact.type, fact.payload)
7474
fact.payload = validated_payload
7575

7676
if session_id:
7777
fact.session_id = session_id
7878

79+
previous_state = None
80+
op = Operation.COMMIT
81+
7982
constraint = self._constraints.get(fact.type)
8083

8184
if constraint and constraint.singleton_key:
@@ -86,26 +89,37 @@ def commit(
8689

8790
if matches:
8891
existing_raw = matches[0]
89-
existing_id = existing_raw["id"]
90-
9192
if constraint.immutable:
9293
raise ConflictError(f"Immutable constraint violation: {fact.type}:{key_val}")
9394

94-
before = copy.deepcopy(existing_raw)
95-
fact.id = existing_id
96-
after = fact.model_dump()
95+
# We found a duplicate, so this is an UPDATE of an existing one
96+
previous_state = copy.deepcopy(existing_raw)
97+
fact.id = existing_raw["id"] # We replace the ID of the new fact with the old one
98+
op = Operation.UPDATE
99+
100+
if op != Operation.UPDATE:
101+
existing = self.storage.load(fact.id)
102+
if existing:
103+
previous_state = copy.deepcopy(existing)
104+
op = Operation.UPDATE
105+
else:
106+
op = Operation.COMMIT_EPHEMERAL if ephemeral else Operation.COMMIT
107+
108+
try:
109+
new_state = fact.model_dump()
110+
self.storage.save(new_state)
111+
self._log_tx(op, fact.id, previous_state, new_state, actor, reason)
112+
self._notify_hooks(op, fact.id, fact)
97113

98-
self.storage.save(after)
99-
self._log_tx(Operation.UPDATE, existing_id, before, after, actor, reason)
100-
return existing_id
114+
return fact.id
101115

102-
existing = self.storage.load(fact.id)
103-
op = Operation.UPDATE if existing else (Operation.COMMIT_EPHEMERAL if ephemeral else Operation.COMMIT)
116+
except HookError as e:
117+
if op == Operation.UPDATE and previous_state:
118+
self.storage.save(previous_state)
119+
else:
120+
self.storage.delete(fact.id)
104121

105-
self.storage.save(fact.model_dump())
106-
self._log_tx(op, fact.id, existing, fact.model_dump(), actor, reason)
107-
self._notify_hooks(op, fact.id, fact)
108-
return fact.id
122+
raise e
109123

110124
def update(self, fact_id: str, patch: dict[str, Any], actor: str | None = None, reason: str | None = None) -> str:
111125
with self._lock:

tests/test_storage.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ def crashing_hook(op, fid, data):
109109
memory.commit(Fact(type="user", payload={"name": "Survivor", "age": 50}))
110110

111111
facts = memory.query(filters={"payload.name": "Survivor"})
112-
assert len(facts) == 1
113-
assert facts[0]["payload"]["name"] == "Survivor"
112+
assert len(facts) == 0
114113

115114

116115
def test_ephemeral_session_discard(memory):

0 commit comments

Comments
 (0)