Skip to content

Commit 7da36da

Browse files
authored
feat: clear memory (#239)
1 parent ae2ffbb commit 7da36da

14 files changed

Lines changed: 359 additions & 5 deletions

File tree

src/memu/app/crud.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,27 @@ async def list_memory_categories(
7676
raise RuntimeError(msg)
7777
return response
7878

79+
async def clear_memory(
80+
self,
81+
where: dict[str, Any] | None = None,
82+
) -> dict[str, Any]:
83+
ctx = self._get_context()
84+
store = self._get_database()
85+
where_filters = self._normalize_where(where)
86+
87+
state: WorkflowState = {
88+
"ctx": ctx,
89+
"store": store,
90+
"where": where_filters,
91+
}
92+
93+
result = await self._run_workflow("crud_clear_memory", state)
94+
response = cast(dict[str, Any] | None, result.get("response"))
95+
if response is None:
96+
msg = "Clear memory workflow failed to produce a response"
97+
raise RuntimeError(msg)
98+
return response
99+
79100
def _build_list_memory_items_workflow(self) -> list[WorkflowStep]:
80101
steps = [
81102
WorkflowStep(
@@ -98,7 +119,7 @@ def _build_list_memory_items_workflow(self) -> list[WorkflowStep]:
98119
return steps
99120

100121
@staticmethod
101-
def _list_list_memory_items_initial_keys() -> set[str]:
122+
def _list_list_memories_initial_keys() -> set[str]:
102123
return {
103124
"ctx",
104125
"store",
@@ -126,6 +147,51 @@ def _build_list_memory_categories_workflow(self) -> list[WorkflowStep]:
126147
]
127148
return steps
128149

150+
def _build_clear_memory_workflow(self) -> list[WorkflowStep]:
151+
steps = [
152+
WorkflowStep(
153+
step_id="clear_memory_categories",
154+
role="delete_memories",
155+
handler=self._crud_clear_memory_categories,
156+
requires={"ctx", "store", "where"},
157+
produces={"deleted_categories"},
158+
capabilities={"db"},
159+
),
160+
WorkflowStep(
161+
step_id="clear_memory_items",
162+
role="delete_memories",
163+
handler=self._crud_clear_memory_items,
164+
requires={"ctx", "store", "where"},
165+
produces={"deleted_items"},
166+
capabilities={"db"},
167+
),
168+
WorkflowStep(
169+
step_id="clear_memory_resources",
170+
role="delete_memories",
171+
handler=self._crud_clear_memory_resources,
172+
requires={"ctx", "store", "where"},
173+
produces={"deleted_resources"},
174+
capabilities={"db"},
175+
),
176+
WorkflowStep(
177+
step_id="build_response",
178+
role="emit",
179+
handler=self._crud_build_clear_memory_response,
180+
requires={"ctx", "store", "deleted_categories", "deleted_items", "deleted_resources"},
181+
produces={"response"},
182+
capabilities=set(),
183+
),
184+
]
185+
return steps
186+
187+
@staticmethod
188+
def _list_clear_memories_initial_keys() -> set[str]:
189+
return {
190+
"ctx",
191+
"store",
192+
"where",
193+
}
194+
129195
def _normalize_where(self, where: Mapping[str, Any] | None) -> dict[str, Any]:
130196
"""Validate and clean the `where` scope filters against the configured user model."""
131197
if not where:
@@ -177,6 +243,39 @@ def _crud_build_list_categories_response(self, state: WorkflowState, step_contex
177243
state["response"] = response
178244
return state
179245

246+
def _crud_clear_memory_categories(self, state: WorkflowState, step_context: Any) -> WorkflowState:
247+
where_filters = state.get("where") or {}
248+
store = state["store"]
249+
deleted = store.memory_category_repo.clear_categories(where_filters)
250+
state["deleted_categories"] = deleted
251+
return state
252+
253+
def _crud_clear_memory_items(self, state: WorkflowState, step_context: Any) -> WorkflowState:
254+
where_filters = state.get("where") or {}
255+
store = state["store"]
256+
deleted = store.memory_item_repo.clear_items(where_filters)
257+
state["deleted_items"] = deleted
258+
return state
259+
260+
def _crud_clear_memory_resources(self, state: WorkflowState, step_context: Any) -> WorkflowState:
261+
where_filters = state.get("where") or {}
262+
store = state["store"]
263+
deleted = store.resource_repo.clear_resources(where_filters)
264+
state["deleted_resources"] = deleted
265+
return state
266+
267+
def _crud_build_clear_memory_response(self, state: WorkflowState, step_context: Any) -> WorkflowState:
268+
deleted_categories = state.get("deleted_categories", {})
269+
deleted_items = state.get("deleted_items", {})
270+
deleted_resources = state.get("deleted_resources", {})
271+
response = {
272+
"deleted_categories": [self._model_dump_without_embeddings(cat) for cat in deleted_categories.values()],
273+
"deleted_items": [self._model_dump_without_embeddings(item) for item in deleted_items.values()],
274+
"deleted_resources": [self._model_dump_without_embeddings(res) for res in deleted_resources.values()],
275+
}
276+
state["response"] = response
277+
return state
278+
180279
async def create_memory_item(
181280
self,
182281
*,

src/memu/app/service.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def _register_pipelines(self) -> None:
277277
patch_delete_initial_keys = CRUDMixin._list_delete_memory_item_initial_keys()
278278
self._pipelines.register("patch_delete", patch_delete_workflow, initial_state_keys=patch_delete_initial_keys)
279279
crud_list_items_workflow = self._build_list_memory_items_workflow()
280-
crud_list_memories_initial_keys = CRUDMixin._list_list_memory_items_initial_keys()
280+
crud_list_memories_initial_keys = CRUDMixin._list_list_memories_initial_keys()
281281
self._pipelines.register(
282282
"crud_list_memory_items", crud_list_items_workflow, initial_state_keys=crud_list_memories_initial_keys
283283
)
@@ -287,6 +287,11 @@ def _register_pipelines(self) -> None:
287287
crud_list_categories_workflow,
288288
initial_state_keys=crud_list_memories_initial_keys,
289289
)
290+
crud_clear_memory_workflow = self._build_clear_memory_workflow()
291+
crud_clear_memory_initial_keys = CRUDMixin._list_clear_memories_initial_keys()
292+
self._pipelines.register(
293+
"crud_clear_memory", crud_clear_memory_workflow, initial_state_keys=crud_clear_memory_initial_keys
294+
)
290295

291296
async def _run_workflow(self, workflow_name: str, initial_state: WorkflowState) -> WorkflowState:
292297
"""Execute a workflow through the configured runner backend."""

src/memu/database/inmemory/repositories/memory_category_repo.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ def list_categories(self, where: Mapping[str, Any] | None = None) -> dict[str, M
2323
return dict(self.categories)
2424
return {cid: cat for cid, cat in self.categories.items() if matches_where(cat, where)}
2525

26+
def clear_categories(self, where: Mapping[str, Any] | None = None) -> dict[str, MemoryCategory]:
27+
if not where:
28+
matches = self.categories.copy()
29+
self.categories.clear()
30+
return matches
31+
matches = {cid: cat for cid, cat in self.categories.items() if matches_where(cat, where)}
32+
self.categories = {cid: cat for cid, cat in self.categories.items() if cid not in matches}
33+
return matches
34+
2635
def get_or_create_category(
2736
self, *, name: str, description: str, embedding: list[float], user_data: dict[str, Any]
2837
) -> MemoryCategory:

src/memu/database/inmemory/repositories/memory_item_repo.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ def list_items(self, where: Mapping[str, Any] | None = None) -> dict[str, Memory
2222
return dict(self.items)
2323
return {mid: item for mid, item in self.items.items() if matches_where(item, where)}
2424

25+
def clear_items(self, where: Mapping[str, Any] | None = None) -> dict[str, MemoryItem]:
26+
if not where:
27+
matches = self.items.copy()
28+
self.items.clear()
29+
return matches
30+
matches = {mid: item for mid, item in self.items.items() if matches_where(item, where)}
31+
self.items = {mid: item for mid, item in self.items.items() if mid not in matches}
32+
return matches
33+
2534
def create_item(
2635
self,
2736
*,

src/memu/database/inmemory/repositories/resource_repo.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ def list_resources(self, where: Mapping[str, Any] | None = None) -> dict[str, Re
2121
return dict(self.resources)
2222
return {rid: res for rid, res in self.resources.items() if matches_where(res, where)}
2323

24+
def clear_resources(self, where: Mapping[str, Any] | None = None) -> dict[str, Resource]:
25+
if not where:
26+
matches = self.resources.copy()
27+
self.resources.clear()
28+
return matches
29+
matches = {rid: res for rid, res in self.resources.items() if matches_where(res, where)}
30+
self.resources = {rid: res for rid, res in self.resources.items() if rid not in matches}
31+
return matches
32+
2433
def create_resource(
2534
self,
2635
*,

src/memu/database/postgres/repositories/memory_category_repo.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,31 @@ def list_categories(self, where: Mapping[str, Any] | None = None) -> dict[str, M
3737
result[cat.id] = cat
3838
return result
3939

40+
def clear_categories(self, where: Mapping[str, Any] | None = None) -> dict[str, MemoryCategory]:
41+
from sqlmodel import delete, select
42+
43+
filters = self._build_filters(self._sqla_models.MemoryCategory, where)
44+
with self._sessions.session() as session:
45+
# First get the objects to delete
46+
rows = session.scalars(select(self._sqla_models.MemoryCategory).where(*filters)).all()
47+
deleted: dict[str, MemoryCategory] = {}
48+
for row in rows:
49+
row.embedding = self._normalize_embedding(row.embedding)
50+
deleted[row.id] = row
51+
52+
if not deleted:
53+
return {}
54+
55+
# Delete from database
56+
session.exec(delete(self._sqla_models.MemoryCategory).where(*filters))
57+
session.commit()
58+
59+
# Clean up cache
60+
for cat_id in deleted:
61+
self.categories.pop(cat_id, None)
62+
63+
return deleted
64+
4065
def get_or_create_category(
4166
self,
4267
*,

src/memu/database/postgres/repositories/memory_item_repo.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,31 @@ def list_items(self, where: Mapping[str, Any] | None = None) -> dict[str, Memory
5151
result[item.id] = item
5252
return result
5353

54+
def clear_items(self, where: Mapping[str, Any] | None = None) -> dict[str, MemoryItem]:
55+
from sqlmodel import delete, select
56+
57+
filters = self._build_filters(self._sqla_models.MemoryItem, where)
58+
with self._sessions.session() as session:
59+
# First get the objects to delete
60+
rows = session.scalars(select(self._sqla_models.MemoryItem).where(*filters)).all()
61+
deleted: dict[str, MemoryItem] = {}
62+
for row in rows:
63+
row.embedding = self._normalize_embedding(row.embedding)
64+
deleted[row.id] = row
65+
66+
if not deleted:
67+
return {}
68+
69+
# Delete from database
70+
session.exec(delete(self._sqla_models.MemoryItem).where(*filters))
71+
session.commit()
72+
73+
# Clean up cache
74+
for item_id in deleted:
75+
self.items.pop(item_id, None)
76+
77+
return deleted
78+
5479
def create_item(
5580
self,
5681
*,

src/memu/database/postgres/repositories/resource_repo.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,31 @@ def list_resources(self, where: Mapping[str, Any] | None = None) -> dict[str, Re
3737
result[res.id] = res
3838
return result
3939

40+
def clear_resources(self, where: Mapping[str, Any] | None = None) -> dict[str, Resource]:
41+
from sqlmodel import delete, select
42+
43+
filters = self._build_filters(self._sqla_models.Resource, where)
44+
with self._sessions.session() as session:
45+
# First get the objects to delete
46+
rows = session.scalars(select(self._sqla_models.Resource).where(*filters)).all()
47+
deleted: dict[str, Resource] = {}
48+
for row in rows:
49+
row.embedding = self._normalize_embedding(row.embedding)
50+
deleted[row.id] = row
51+
52+
if not deleted:
53+
return {}
54+
55+
# Delete from database
56+
session.exec(delete(self._sqla_models.Resource).where(*filters))
57+
session.commit()
58+
59+
# Clean up cache
60+
for res_id in deleted:
61+
self.resources.pop(res_id, None)
62+
63+
return deleted
64+
4065
def create_resource(
4166
self,
4267
*,

src/memu/database/repositories/memory_category.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class MemoryCategoryRepo(Protocol):
1414

1515
def list_categories(self, where: Mapping[str, Any] | None = None) -> dict[str, MemoryCategory]: ...
1616

17+
def clear_categories(self, where: Mapping[str, Any] | None = None) -> dict[str, MemoryCategory]: ...
18+
1719
def get_or_create_category(
1820
self, *, name: str, description: str, embedding: list[float], user_data: dict[str, Any]
1921
) -> MemoryCategory: ...

src/memu/database/repositories/memory_item.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def get_item(self, item_id: str) -> MemoryItem | None: ...
1616

1717
def list_items(self, where: Mapping[str, Any] | None = None) -> dict[str, MemoryItem]: ...
1818

19+
def clear_items(self, where: Mapping[str, Any] | None = None) -> dict[str, MemoryItem]: ...
20+
1921
def create_item(
2022
self,
2123
*,

0 commit comments

Comments
 (0)