Skip to content

Commit 14fe748

Browse files
committed
chore: add method get_session_facts for backends to improve performance
1 parent f2d0afa commit 14fe748

File tree

8 files changed

+127
-8
lines changed

8 files changed

+127
-8
lines changed

memstate/backends/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def remove_last_tx(self, count: int) -> None:
4545
"""Removes the last N entries from the transaction log (LIFO)."""
4646
pass
4747

48+
def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Retrieve all facts belonging to a specific session.

    Backends that keep a session index or can filter natively should
    override this method. The default delegates to the generic filtered
    query, so a backend that does not define this method can still be
    instantiated and used.

    Args:
        session_id: Identifier of the session whose facts are requested.

    Returns:
        Whatever ``self.query`` returns for the ``session_id`` filter.
    """
    return self.query(json_filters={"session_id": session_id})
52+
4853
def close(self) -> None:
4954
"""Cleanup resources (optional)."""
5055
pass
@@ -95,6 +100,11 @@ async def remove_last_tx(self, count: int) -> None:
95100
"""Removes the last N entries from the transaction log (LIFO)."""
96101
pass
97102

103+
async def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Retrieve all facts belonging to a specific session.

    Backends that keep a session index or can filter natively should
    override this coroutine. The default awaits the generic filtered
    query, so a backend that does not define this method can still be
    instantiated and used.

    Args:
        session_id: Identifier of the session whose facts are requested.

    Returns:
        Whatever ``self.query`` resolves to for the ``session_id`` filter.
    """
    return await self.query(json_filters={"session_id": session_id})
107+
98108
async def close(self) -> None:
99109
"""Cleanup resources asynchronously."""
100110
pass

memstate/backends/inmemory.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def remove_last_tx(self, count: int) -> None:
7979
else:
8080
self._tx_log = self._tx_log[:-count]
8181

82+
def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Return the stored facts whose ``session_id`` equals ``session_id``.

    Facts without a ``session_id`` key never match. The returned list holds
    the stored dict objects themselves, not copies.
    """

    def in_session(fact: dict[str, Any]) -> bool:
        return fact.get("session_id") == session_id

    return list(filter(in_session, self._store.values()))
84+
8285
def close(self) -> None:
8386
pass
8487

@@ -161,5 +164,8 @@ async def remove_last_tx(self, count: int) -> None:
161164
else:
162165
self._tx_log = self._tx_log[:-count]
163166

167+
async def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Return the stored facts whose ``session_id`` equals ``session_id``.

    Facts without a ``session_id`` key never match. The returned list holds
    the stored dict objects themselves, not copies.
    """

    def in_session(fact: dict[str, Any]) -> bool:
        return fact.get("session_id") == session_id

    return list(filter(in_session, self._store.values()))
169+
164170
async def close(self) -> None:
165171
pass

memstate/backends/postgres.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ def remove_last_tx(self, count: int) -> None:
140140
with self._engine.begin() as conn:
141141
conn.execute(stmt)
142142

143+
def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Select the ``doc`` value of every fact row tagged with ``session_id``.

    The filter compares the text form of the ``session_id`` field inside the
    ``doc`` column with the given id.
    """
    doc_col = self._facts_table.c.doc
    stmt = select(doc_col).where(doc_col["session_id"].astext == session_id)
    with self._engine.connect() as conn:
        # Each row carries exactly one selected column: the document itself.
        return [doc for (doc,) in conn.execute(stmt).all()]
148+
143149
def close(self) -> None:
144150
self._engine.dispose()
145151

@@ -248,5 +254,11 @@ async def remove_last_tx(self, count: int) -> None:
248254
async with self._engine.begin() as conn:
249255
await conn.execute(stmt)
250256

257+
async def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Select the ``doc`` value of every fact row tagged with ``session_id``.

    The filter compares the text form of the ``session_id`` field inside the
    ``doc`` column with the given id.
    """
    doc_col = self._facts_table.c.doc
    stmt = select(doc_col).where(doc_col["session_id"].astext == session_id)
    async with self._engine.connect() as conn:
        fetched = await conn.execute(stmt)
        # Each row carries exactly one selected column: the document itself.
        return [doc for (doc,) in fetched.all()]
262+
251263
async def close(self) -> None:
252264
await self._engine.dispose()

memstate/backends/redis.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,25 @@ def remove_last_tx(self, count: int) -> None:
149149
return
150150
self.r.ltrim(f"{self.prefix}tx_log", count, -1)
151151

152+
def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Return every fact document indexed under ``session_id``.

    The member ids are read from the set stored at
    ``{prefix}session:{session_id}``; the matching documents are then fetched
    with one pipelined batch of GETs and JSON-decoded.

    Args:
        session_id: Identifier of the session whose facts are requested.

    Returns:
        The decoded documents. Ids whose GET yields nothing are skipped, and
        an empty list is returned when the session set is empty or missing.
    """
    fact_ids = self.r.smembers(f"{self.prefix}session:{session_id}")
    if not fact_ids:
        return []

    # Same shape as the async backend: the context manager resets the
    # pipeline even if queuing a command raises.
    # NOTE(review): ids are passed to _key() exactly as smembers returns
    # them -- confirm _key copes with bytes when responses are not decoded.
    with self.r.pipeline() as pipe:
        for fact_id in fact_ids:
            pipe.get(self._key(fact_id))
        raw_docs = pipe.execute()

    decoded = (self._to_str(raw) for raw in raw_docs)
    return [json.loads(doc) for doc in decoded if doc]
170+
152171
def close(self) -> None:
153172
if self._owns_client:
154173
self.r.close()
@@ -270,6 +289,24 @@ async def remove_last_tx(self, count: int) -> None:
270289
return
271290
await self.r.ltrim(f"{self.prefix}tx_log", count, -1)
272291

292+
async def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Return every fact document indexed under ``session_id``.

    The member ids are read from the set stored at
    ``{prefix}session:{session_id}``; the matching documents are then fetched
    with one pipelined batch of GETs and JSON-decoded.

    Args:
        session_id: Identifier of the session whose facts are requested.

    Returns:
        The decoded documents. Ids whose GET yields nothing are skipped, and
        an empty list is returned when the session set is empty or missing.
    """
    fact_ids = await self.r.smembers(f"{self.prefix}session:{session_id}")
    if not fact_ids:
        return []

    async with self.r.pipeline() as pipe:
        for fact_id in fact_ids:
            # Queues the command only; nothing is sent until execute().
            pipe.get(self._key(fact_id))
        raw_docs = await pipe.execute()

    return [json.loads(raw) for raw in raw_docs if raw]
309+
273310
async def close(self):
274311
if self._owns_client:
275312
await self.r.aclose()

memstate/backends/sqlite.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ def remove_last_tx(self, count: int) -> None:
155155
)
156156
self._conn.commit()
157157

158+
def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Return the decoded ``data`` document of every fact in a session.

    Rows are matched in SQL by comparing ``json_extract(data,
    '$.session_id')`` with ``session_id``.

    Args:
        session_id: Identifier of the session whose facts are requested.

    Returns:
        One decoded document per matching row; empty when nothing matches.
    """
    # The lock is held only for the database read; JSON decoding works on the
    # already-fetched rows and does not need it.
    with self._lock:
        rows = self._conn.execute(
            "SELECT data FROM facts WHERE json_extract(data, '$.session_id') = ?", (session_id,)
        ).fetchall()
    # Rows are indexed by column name, so the connection must use a row
    # factory that supports it.
    return [json.loads(row["data"]) for row in rows]
163+
158164
def close(self):
159165
if self._owns_connection:
160166
self._conn.close()
@@ -315,6 +321,14 @@ async def remove_last_tx(self, count: int) -> None:
315321
)
316322
await self._db.commit()
317323

324+
async def get_session_facts(self, session_id: str) -> list[dict[str, Any]]:
    """Return the decoded ``data`` document of every fact in a session.

    Rows are matched in SQL by comparing ``json_extract(data,
    '$.session_id')`` with ``session_id``.

    Args:
        session_id: Identifier of the session whose facts are requested.

    Returns:
        One decoded document per matching row; empty when nothing matches.
    """
    sql = "SELECT data FROM facts WHERE json_extract(data, '$.session_id') = ?"
    # The lock is held only for the database read; JSON decoding works on the
    # already-fetched rows and does not need it.
    async with self._lock:
        async with self._db.execute(sql, (session_id,)) as cursor:
            rows = await cursor.fetchall()
    return [json.loads(row["data"]) for row in rows]
331+
318332
async def close(self) -> None:
319333
if self._db:
320334
await self._db.close()

memstate/storage.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,7 @@ def promote_session(
242242
reason: str | None = None,
243243
) -> list[str]:
244244
with self._lock:
245-
# Find all session facts (via query or backend index)
246-
# RedisStorage and InMemory can search by session_id if we add it to the query.
247-
# For MVP, we use a query with a filter (slow on large volumes, fast with indexes)
248-
249-
# In RedisStorage, it is better to create a separate get_session_facts method, but we use query
250-
# We assume that storage stores session_id in json data
251-
candidates = self.storage.query(json_filters={"session_id": session_id})
245+
candidates = self.storage.get_session_facts(session_id)
252246

253247
promoted = []
254248
for fact_dict in candidates:
@@ -492,7 +486,7 @@ async def promote_session(
492486
reason: str | None = None,
493487
) -> list[str]:
494488
async with self._lock:
495-
candidates = await self.storage.query(json_filters={"session_id": session_id})
489+
candidates = await self.storage.get_session_facts(session_id)
496490

497491
promoted = []
498492
for fact_dict in candidates:

tests/async/test_backends_async.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,26 @@ async def test_remove_last_tx(storage):
165165

166166
logs = await storage.get_tx_log(limit=10)
167167
assert len(logs) == 0
168+
169+
170+
async def test_get_session_facts(storage):
    """get_session_facts returns exactly the facts saved under a session id."""
    now = datetime.now(timezone.utc).isoformat()
    seeded = [
        {"id": "a1", "type": "msg", "session_id": "session_A", "payload": {"val": 1}, "ts": now},
        {"id": "a2", "type": "msg", "session_id": "session_A", "payload": {"val": 2}, "ts": now},
        {"id": "b1", "type": "msg", "session_id": "session_B", "payload": {"val": 3}, "ts": now},
        # A fact with no session must never show up in a session lookup.
        {"id": "g1", "type": "config", "payload": {"val": 0}, "ts": now},
    ]
    for fact in seeded:
        await storage.save(fact)

    session_a = await storage.get_session_facts("session_A")
    assert len(session_a) == 2
    # Backends give no ordering guarantee, so compare the sorted ids.
    assert sorted(fact["id"] for fact in session_a) == ["a1", "a2"]

    session_b = await storage.get_session_facts("session_B")
    assert len(session_b) == 1
    assert session_b[0]["id"] == "b1"

    unknown = await storage.get_session_facts("ghost_session")
    assert unknown == []

tests/sync/test_backends.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,26 @@ def test_remove_last_tx(storage):
157157

158158
logs = storage.get_tx_log(limit=10)
159159
assert len(logs) == 0
160+
161+
162+
def test_get_session_facts(storage):
    """get_session_facts returns exactly the facts saved under a session id."""
    now = datetime.now(timezone.utc).isoformat()
    seeded = [
        {"id": "a1", "type": "msg", "session_id": "session_A", "payload": {"val": 1}, "ts": now},
        {"id": "a2", "type": "msg", "session_id": "session_A", "payload": {"val": 2}, "ts": now},
        {"id": "b1", "type": "msg", "session_id": "session_B", "payload": {"val": 3}, "ts": now},
        # A fact with no session must never show up in a session lookup.
        {"id": "g1", "type": "config", "payload": {"val": 0}, "ts": now},
    ]
    for fact in seeded:
        storage.save(fact)

    session_a = storage.get_session_facts("session_A")
    assert len(session_a) == 2
    # Backends give no ordering guarantee, so compare the sorted ids.
    assert sorted(fact["id"] for fact in session_a) == ["a1", "a2"]

    session_b = storage.get_session_facts("session_B")
    assert len(session_b) == 1
    assert session_b[0]["id"] == "b1"

    unknown = storage.get_session_facts("ghost_session")
    assert unknown == []

0 commit comments

Comments
 (0)