Skip to content

Commit 743f731

Browse files
committed
fix: add session_id to rollback method
1 parent 1a4ea2a commit 743f731

19 files changed

+465
-296
lines changed

examples/document_lifecycle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def run_demo():
8787
status="archived", # Metadata changed
8888
author="Alice",
8989
)
90-
memory.commit_model(fact_id=doc_id, model=updated_project)
90+
memory.commit_model(fact_id=doc_id, model=updated_project, session_id="session_1")
9191

9292
print_state("AFTER UPDATE", doc_id, sqlite, collection)
9393

@@ -96,7 +96,7 @@ def run_demo():
9696
# =========================================================================
9797
print(f"\n3️⃣ Deleting document...")
9898

99-
memory.delete(doc_id)
99+
memory.delete(session_id="session_1", fact_id=doc_id)
100100

101101
# Manual verification
102102
print(f"\n--- CHECKING STATE: AFTER DELETE ---")

examples/main_demo.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class MeetingNote(BaseModel):
4646
# Scenario 1: Agent learns about the user
4747
print("--- Step 1: Creating User Profile ---")
4848
fact_profile = UserProfile(email="alex@corp.com", full_name="Alex Dev", role="Backend")
49-
memory.commit_model(model=fact_profile, source="chat_onboarding", actor="Agent_Smith")
49+
memory.commit_model(model=fact_profile, source="chat_onboarding", session_id="session_1", actor="Agent_Smith")
5050

5151
# Let's check what was recorded
5252
saved_profile = memory.query(typename="user_profile")[0]
@@ -58,10 +58,12 @@ class MeetingNote(BaseModel):
5858
# The agent realized that Alex was actually a Senior.
5959
# He simply commits a new fact. The system will automatically find the old one by email and update it.
6060
fact_update = UserProfile(email="alex@corp.com", full_name="Alex Dev", role="Backend Lead", level="Senior")
61-
memory.commit_model(fact_update, source="linkedin_parser", actor="Agent_Smith", reason="found linkedin profile")
61+
memory.commit_model(
62+
fact_update, source="linkedin_parser", session_id="session_1", actor="Agent_Smith", reason="found linkedin profile"
63+
)
6264

6365
# We're checking. There should be one fact left, but it should be updated.
64-
profiles = memory.query(typename="user_profile")
66+
profiles = memory.query(typename="user_profile", session_id="session_1")
6567
print(f"✅ Total Profiles: {len(profiles)}")
6668
print(f"✅ Updated Level: {profiles[0]['payload']['level']}")
6769

@@ -72,7 +74,7 @@ class MeetingNote(BaseModel):
7274
topic="Salary Negotiation",
7375
summary="Alex agreed to work for free.", # Hallucination!
7476
)
75-
memory.commit_model(bad_fact, actor="Agent_Smith")
77+
memory.commit_model(bad_fact, actor="Agent_Smith", session_id="session_1")
7678
print("⚠️ Bad fact committed.")
7779

7880

@@ -81,19 +83,19 @@ class MeetingNote(BaseModel):
8183
print("\n--- Step 4: Detection & Rollback ---")
8284
# The developer or Supervisor-Agent notices an error.
8385
# Let's look at the latest transactions
84-
logs = storage.get_tx_log(limit=2)
86+
logs = storage.get_tx_log(session_id="session_1", limit=2)
8587
print(f"🔍 Last Action: {logs[0]['op']} by {logs[0]['actor']}")
8688

8789
print("↺ Rolling back 1 step...")
88-
memory.rollback(steps=1)
90+
memory.rollback(session_id="session_1", steps=1)
8991

9092
# Let's check if the bad fact has disappeared
91-
notes = memory.query(typename="meeting_note")
93+
notes = memory.query(typename="meeting_note", session_id="session_1")
9294
if not notes:
9395
print("✅ Rollback successful! The bad note is gone.")
9496
else:
9597
print("❌ Failed, note still exists.")
9698

9799
# We check that the profile (previous state) is not damaged
98-
profiles = memory.query(typename="user_profile")
100+
profiles = memory.query(typename="user_profile", session_id="session_1")
99101
print(f"✅ Profile still exists and is {profiles[0]['payload']['level']}")

examples/postgres_demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def main():
9898
# Step 7: Audit Log (Compliance)
9999
print(f"\n4️⃣ Checking Transaction Log (History)...")
100100
# Assuming you implemented get_tx_log in PostgresStorage
101-
history = pg_storage.get_tx_log(limit=5)
101+
history = pg_storage.get_tx_log(session_id="session_1", limit=5)
102102

103103
for tx in history:
104104
op = tx.get("op", "UNKNOWN")
@@ -108,7 +108,7 @@ def main():
108108

109109
# Step 8: Rollback
110110
print(f"\n5️⃣ Oops! Update was a mistake. Rolling back...")
111-
memory.rollback(1)
111+
memory.rollback(session_id="session_1", steps=1)
112112

113113
final = pg_storage.load(fact_id)
114114
print(f" Restored Level: {final['payload']['level']}")

examples/rag_hook_demo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def main():
104104
doc2 = KnowledgeBase(content="Mars is known as the Red Planet due to iron oxide.")
105105
doc3 = UserPref(theme="dark") # This should NOT be in vectors (filter by type)
106106

107-
memory.commit_model(model=doc1)
108-
doc2_id = memory.commit_model(model=doc2)
109-
memory.commit_model(model=doc3)
107+
memory.commit_model(model=doc1, session_id="session_1")
108+
doc2_id = memory.commit_model(model=doc2, session_id="session_1")
109+
memory.commit_model(model=doc3, session_id="session_1")
110110

111111
print("\n--- Phase 2: RAG Search (Emulation) ---")
112112
# User asks: "Tell me about the red planet"
@@ -145,7 +145,7 @@ def main():
145145
print("\n--- Phase 4: Forgetting (Deletion) ---")
146146
# Deleting a fact from the database
147147
print("🗑 Deleting Mars fact...")
148-
memory.delete(doc2_id)
148+
memory.delete(session_id="session_1", fact_id=doc2_id)
149149

150150
# Checking the search
151151
found_ids = vector_db.search("Mars")

memstate/backends/base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def append_tx(self, tx_data: dict[str, Any]) -> None:
3535
pass
3636

3737
@abstractmethod
38-
def get_tx_log(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
38+
def get_tx_log(self, session_id: str, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
3939
"""Retrieve transaction history (newest first typically, or ordered by seq)."""
4040
pass
4141

@@ -45,13 +45,13 @@ def delete_session(self, session_id: str) -> list[str]:
4545
pass
4646

4747
@abstractmethod
48-
def remove_last_tx(self, count: int) -> None:
49-
"""Removes the last N entries from the transaction log (LIFO)."""
48+
def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
49+
"""Retrieve all facts belonging to a specific session."""
5050
pass
5151

5252
@abstractmethod
53-
def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
54-
"""Retrieve all facts belonging to a specific session."""
53+
def delete_txs(self, tx_uuids: list[str]) -> None:
54+
"""Delete specific transactions from the log by their UUIDs."""
5555
pass
5656

5757
def close(self) -> None:
@@ -90,7 +90,7 @@ async def append_tx(self, tx_data: dict[str, Any]) -> None:
9090
pass
9191

9292
@abstractmethod
93-
async def get_tx_log(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
93+
async def get_tx_log(self, session_id: str, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
9494
"""Retrieve transaction history asynchronously."""
9595
pass
9696

@@ -100,13 +100,13 @@ async def delete_session(self, session_id: str) -> list[str]:
100100
pass
101101

102102
@abstractmethod
103-
async def remove_last_tx(self, count: int) -> None:
104-
"""Removes the last N entries from the transaction log (LIFO)."""
103+
async def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
104+
"""Retrieve all facts belonging to a specific session."""
105105
pass
106106

107107
@abstractmethod
108-
async def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
109-
"""Retrieve all facts belonging to a specific session."""
108+
async def delete_txs(self, tx_uuids: list[str]) -> None:
109+
"""Delete specific transactions from the log by their UUIDs."""
110110
pass
111111

112112
async def close(self) -> None:

memstate/backends/inmemory.py

Lines changed: 48 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,13 @@ def append_tx(self, tx_data: dict[str, Any]) -> None:
150150
with self._lock:
151151
self._tx_log.append(tx_data)
152152

153-
def get_tx_log(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
153+
def get_tx_log(self, session_id: str, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
154154
"""
155155
Retrieves and returns a portion of the transaction log. The transaction log is accessed in
156156
reverse order of insertion, i.e., the most recently added item is the first in the result.
157157
158158
Args:
159+
session_id (str): The identifier of the session whose transactions should be retrieved.
159160
limit (int): The maximum number of transaction log entries to be retrieved. Default is 100.
160161
offset (int): The starting position relative to the most recent entry that determines where to begin
161162
retrieving the log entries. Default is 0.
@@ -165,7 +166,9 @@ def get_tx_log(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
165166
contain details of individual transaction log entries.
166167
"""
167168
with self._lock:
168-
return list(reversed(self._tx_log))[offset : offset + limit]
169+
reversed_log = reversed(self._tx_log)
170+
filtered = [tx for tx in reversed_log if tx.get("session_id") == session_id]
171+
return filtered[offset : offset + limit]
169172

170173
def delete_session(self, session_id: str) -> list[str]:
171174
"""
@@ -186,26 +189,6 @@ def delete_session(self, session_id: str) -> list[str]:
186189
del self._store[fid]
187190
return to_delete
188191

189-
def remove_last_tx(self, count: int) -> None:
190-
"""
191-
Removes a specified number of the most recent transactions from the transaction
192-
log. If the number of transactions to remove exceeds the current size of the
193-
log, the entire log will be cleared.
194-
195-
Args:
196-
count (int): The number of transactions to remove. Must be a positive integer.
197-
198-
Returns:
199-
None
200-
"""
201-
with self._lock:
202-
if count <= 0:
203-
return
204-
if count >= len(self._tx_log):
205-
self._tx_log.clear()
206-
else:
207-
self._tx_log = self._tx_log[:-count]
208-
209192
def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
210193
"""
211194
Retrieves all facts associated with a specific session.
@@ -222,6 +205,25 @@ def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
222205
"""
223206
return [f for f in self._store.values() if f.get("session_id") == session_id]
224207

208+
def delete_txs(self, tx_uuids: list[str]) -> None:
209+
"""
210+
Removes a list of transactions from the transaction log whose session IDs match the provided
211+
transaction IDs. If the provided list is empty, no transactions are processed.
212+
213+
Args:
214+
tx_uuids (list[str]): A list of transaction UUIDs to be removed from the log.
215+
216+
Returns:
217+
None
218+
"""
219+
if not tx_uuids:
220+
return
221+
222+
with self._lock:
223+
ids_to_delete = set(tx_uuids)
224+
225+
self._tx_log = [tx for tx in self._tx_log if tx["uuid"] not in ids_to_delete]
226+
225227
def close(self) -> None:
226228
"""
227229
Closes the current open resource or connection.
@@ -381,12 +383,13 @@ async def append_tx(self, tx_data: dict[str, Any]) -> None:
381383
async with self._lock:
382384
self._tx_log.append(tx_data)
383385

384-
async def get_tx_log(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
386+
async def get_tx_log(self, session_id: str, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
385387
"""
386388
Asynchronously retrieves and returns a portion of the transaction log. The transaction log is accessed in
387389
reverse order of insertion, i.e., the most recently added item is the first in the result.
388390
389391
Args:
392+
session_id (str): The identifier of the session whose transactions should be retrieved.
390393
limit (int): The maximum number of transaction log entries to be retrieved. Default is 100.
391394
offset (int): The starting position relative to the most recent entry that determines where to begin
392395
retrieving the log entries. Default is 0.
@@ -396,7 +399,9 @@ async def get_tx_log(self, limit: int = 100, offset: int = 0) -> list[dict[str,
396399
contain details of individual transaction log entries.
397400
"""
398401
async with self._lock:
399-
return list(reversed(self._tx_log))[offset : offset + limit]
402+
reversed_log = reversed(self._tx_log)
403+
filtered = [tx for tx in reversed_log if tx.get("session_id") == session_id]
404+
return filtered[offset : offset + limit]
400405

401406
async def delete_session(self, session_id: str) -> list[str]:
402407
"""
@@ -417,27 +422,6 @@ async def delete_session(self, session_id: str) -> list[str]:
417422
del self._store[fid]
418423
return to_delete
419424

420-
async def remove_last_tx(self, count: int) -> None:
421-
"""
422-
Asynchronously removes a specified number of the most recent transactions from the transaction
423-
log. If the number of transactions to remove exceeds the current size of the
424-
log, the entire log will be cleared.
425-
426-
Args:
427-
count (int): The number of transactions to remove. Must be a positive integer.
428-
429-
Returns:
430-
None
431-
"""
432-
async with self._lock:
433-
if count <= 0:
434-
return
435-
436-
if count >= len(self._tx_log):
437-
self._tx_log.clear()
438-
else:
439-
self._tx_log = self._tx_log[:-count]
440-
441425
async def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
442426
"""
443427
Asynchronously retrieves all facts associated with a specific session.
@@ -454,6 +438,25 @@ async def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
454438
"""
455439
return [f for f in self._store.values() if f.get("session_id") == session_id]
456440

441+
async def delete_txs(self, tx_uuids: list[str]) -> None:
442+
"""
443+
Asynchronously removes a list of transactions from the transaction log whose session IDs match the provided
444+
transaction IDs. If the provided list is empty, no transactions are processed.
445+
446+
Args:
447+
tx_uuids (list[str]): A list of transaction UUIDs to be removed from the log.
448+
449+
Returns:
450+
None
451+
"""
452+
if not tx_uuids:
453+
return
454+
455+
async with self._lock:
456+
ids_to_delete = set(tx_uuids)
457+
458+
self._tx_log = [tx for tx in self._tx_log if tx["uuid"] not in ids_to_delete]
459+
457460
async def close(self) -> None:
458461
"""
459462
Asynchronously closes the current open resource or connection.

0 commit comments

Comments
 (0)