Skip to content

Commit 3b67458

Browse files
authored
feat: add non-propagate option for memory patch (#386)
1 parent 163d050 commit 3b67458

File tree

3 files changed

+39
-13
lines changed

3 files changed

+39
-13
lines changed

src/memu/app/crud.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ async def create_memory_item(
283283
memory_content: str,
284284
memory_categories: list[str],
285285
user: dict[str, Any] | None = None,
286+
propagate: bool = True,
286287
) -> dict[str, Any]:
287288
if memory_type not in get_args(MemoryType):
288289
msg = f"Invalid memory type: '{memory_type}', must be one of {get_args(MemoryType)}"
@@ -303,6 +304,7 @@ async def create_memory_item(
303304
"store": store,
304305
"category_ids": list(ctx.category_ids),
305306
"user": user_scope,
307+
"propagate": propagate,
306308
}
307309

308310
result = await self._run_workflow("patch_create", state)
@@ -320,6 +322,7 @@ async def update_memory_item(
320322
memory_content: str | None = None,
321323
memory_categories: list[str] | None = None,
322324
user: dict[str, Any] | None = None,
325+
propagate: bool = True,
323326
) -> dict[str, Any]:
324327
if all((memory_type is None, memory_content is None, memory_categories is None)):
325328
msg = "At least one of memory type, memory content, or memory categories is required for UPDATE operation"
@@ -344,6 +347,7 @@ async def update_memory_item(
344347
"store": store,
345348
"category_ids": list(ctx.category_ids),
346349
"user": user_scope,
350+
"propagate": propagate,
347351
}
348352

349353
result = await self._run_workflow("patch_update", state)
@@ -358,6 +362,7 @@ async def delete_memory_item(
358362
*,
359363
memory_id: str,
360364
user: dict[str, Any] | None = None,
365+
propagate: bool = True,
361366
) -> dict[str, Any]:
362367
ctx = self._get_context()
363368
store = self._get_database()
@@ -370,6 +375,7 @@ async def delete_memory_item(
370375
"store": store,
371376
"category_ids": list(ctx.category_ids),
372377
"user": user_scope,
378+
"propagate": propagate,
373379
}
374380

375381
result = await self._run_workflow("patch_delete", state)
@@ -504,6 +510,7 @@ async def _patch_create_memory_item(self, state: WorkflowState, step_context: An
504510
ctx = state["ctx"]
505511
store = state["store"]
506512
user = state["user"]
513+
propagate = state["propagate"]
507514
category_memory_updates: dict[str, tuple[Any, Any]] = {}
508515

509516
embed_payload = [memory_payload["content"]]
@@ -519,7 +526,8 @@ async def _patch_create_memory_item(self, state: WorkflowState, step_context: An
519526
mapped_cat_ids = self._map_category_names_to_ids(cat_names, ctx)
520527
for cid in mapped_cat_ids:
521528
store.category_item_repo.link_item_category(item.id, cid, user_data=dict(user or {}))
522-
category_memory_updates[cid] = (None, memory_payload["content"])
529+
if propagate:
530+
category_memory_updates[cid] = (None, memory_payload["content"])
523531

524532
state.update({
525533
"memory_item": item,
@@ -533,6 +541,7 @@ async def _patch_update_memory_item(self, state: WorkflowState, step_context: An
533541
ctx = state["ctx"]
534542
store = state["store"]
535543
user = state["user"]
544+
propagate = state["propagate"]
536545
category_memory_updates: dict[str, tuple[Any, Any]] = {}
537546

538547
item = store.memory_item_repo.get_item(memory_id)
@@ -563,12 +572,14 @@ async def _patch_update_memory_item(self, state: WorkflowState, step_context: An
563572
cats_to_add = set(mapped_new_cat_ids) - set(mapped_old_cat_ids)
564573
for cid in cats_to_remove:
565574
store.category_item_repo.unlink_item_category(memory_id, cid)
566-
category_memory_updates[cid] = (old_content, None)
575+
if propagate:
576+
category_memory_updates[cid] = (old_content, None)
567577
for cid in cats_to_add:
568578
store.category_item_repo.link_item_category(memory_id, cid, user_data=dict(user or {}))
569-
category_memory_updates[cid] = (None, item.summary)
579+
if propagate:
580+
category_memory_updates[cid] = (None, item.summary)
570581

571-
if memory_payload["content"]:
582+
if propagate and memory_payload["content"]:
572583
for cid in set(mapped_old_cat_ids) & set(mapped_new_cat_ids):
573584
category_memory_updates[cid] = (old_content, item.summary)
574585

@@ -581,15 +592,17 @@ async def _patch_update_memory_item(self, state: WorkflowState, step_context: An
581592
async def _patch_delete_memory_item(self, state: WorkflowState, step_context: Any) -> WorkflowState:
582593
memory_id = state["memory_id"]
583594
store = state["store"]
595+
propagate = state["propagate"]
584596
category_memory_updates: dict[str, tuple[Any, Any]] = {}
585597

586598
item = store.memory_item_repo.get_item(memory_id)
587599
if not item:
588600
msg = f"Memory item with id {memory_id} not found"
589601
raise ValueError(msg)
590602
item_categories = store.category_item_repo.get_item_categories(memory_id)
591-
for cat in item_categories:
592-
category_memory_updates[cat.category_id] = (item.summary, None)
603+
if propagate:
604+
for cat in item_categories:
605+
category_memory_updates[cat.category_id] = (item.summary, None)
593606
store.memory_item_repo.delete_item(memory_id)
594607

595608
state.update({

src/memu/app/patch.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ async def create_memory_item(
4141
memory_content: str,
4242
memory_categories: list[str],
4343
user: dict[str, Any] | None = None,
44+
propagate: bool = True,
4445
) -> dict[str, Any]:
4546
if memory_type not in get_args(MemoryType):
4647
msg = f"Invalid memory type: '{memory_type}', must be one of {get_args(MemoryType)}"
@@ -61,6 +62,7 @@ async def create_memory_item(
6162
"store": store,
6263
"category_ids": list(ctx.category_ids),
6364
"user": user_scope,
65+
"propagate": propagate,
6466
}
6567

6668
result = await self._run_workflow("patch_create", state)
@@ -78,6 +80,7 @@ async def update_memory_item(
7880
memory_content: str | None = None,
7981
memory_categories: list[str] | None = None,
8082
user: dict[str, Any] | None = None,
83+
propagate: bool = True,
8184
) -> dict[str, Any]:
8285
if all((memory_type is None, memory_content is None, memory_categories is None)):
8386
msg = "At least one of memory type, memory content, or memory categories is required for UPDATE operation"
@@ -102,6 +105,7 @@ async def update_memory_item(
102105
"store": store,
103106
"category_ids": list(ctx.category_ids),
104107
"user": user_scope,
108+
"propagate": propagate,
105109
}
106110

107111
result = await self._run_workflow("patch_update", state)
@@ -116,6 +120,7 @@ async def delete_memory_item(
116120
*,
117121
memory_id: str,
118122
user: dict[str, Any] | None = None,
123+
propagate: bool = True,
119124
) -> dict[str, Any]:
120125
ctx = self._get_context()
121126
store = self._get_database()
@@ -128,6 +133,7 @@ async def delete_memory_item(
128133
"store": store,
129134
"category_ids": list(ctx.category_ids),
130135
"user": user_scope,
136+
"propagate": propagate,
131137
}
132138

133139
result = await self._run_workflow("patch_delete", state)
@@ -257,6 +263,7 @@ async def _patch_create_memory_item(self, state: WorkflowState, step_context: An
257263
ctx = state["ctx"]
258264
store = state["store"]
259265
user = state["user"]
266+
propagate = state["propagate"]
260267
category_memory_updates: dict[str, tuple[Any, Any]] = {}
261268

262269
embed_payload = [memory_payload["content"]]
@@ -272,7 +279,8 @@ async def _patch_create_memory_item(self, state: WorkflowState, step_context: An
272279
mapped_cat_ids = self._map_category_names_to_ids(cat_names, ctx)
273280
for cid in mapped_cat_ids:
274281
store.category_item_repo.link_item_category(item.id, cid, user_data=dict(user or {}))
275-
category_memory_updates[cid] = (None, memory_payload["content"])
282+
if propagate:
283+
category_memory_updates[cid] = (None, memory_payload["content"])
276284

277285
state.update({
278286
"memory_item": item,
@@ -286,6 +294,7 @@ async def _patch_update_memory_item(self, state: WorkflowState, step_context: An
286294
ctx = state["ctx"]
287295
store = state["store"]
288296
user = state["user"]
297+
propagate = state["propagate"]
289298
category_memory_updates: dict[str, tuple[Any, Any]] = {}
290299

291300
item = store.memory_item_repo.get_item(memory_id)
@@ -316,12 +325,14 @@ async def _patch_update_memory_item(self, state: WorkflowState, step_context: An
316325
cats_to_add = set(mapped_new_cat_ids) - set(mapped_old_cat_ids)
317326
for cid in cats_to_remove:
318327
store.category_item_repo.unlink_item_category(memory_id, cid)
319-
category_memory_updates[cid] = (old_content, None)
328+
if propagate:
329+
category_memory_updates[cid] = (old_content, None)
320330
for cid in cats_to_add:
321331
store.category_item_repo.link_item_category(memory_id, cid, user_data=dict(user or {}))
322-
category_memory_updates[cid] = (None, item.summary)
332+
if propagate:
333+
category_memory_updates[cid] = (None, item.summary)
323334

324-
if memory_payload["content"]:
335+
if propagate and memory_payload["content"]:
325336
for cid in set(mapped_old_cat_ids) & set(mapped_new_cat_ids):
326337
category_memory_updates[cid] = (old_content, item.summary)
327338

@@ -334,15 +345,17 @@ async def _patch_update_memory_item(self, state: WorkflowState, step_context: An
334345
async def _patch_delete_memory_item(self, state: WorkflowState, step_context: Any) -> WorkflowState:
335346
memory_id = state["memory_id"]
336347
store = state["store"]
348+
propagate = state["propagate"]
337349
category_memory_updates: dict[str, tuple[Any, Any]] = {}
338350

339351
item = store.memory_item_repo.get_item(memory_id)
340352
if not item:
341353
msg = f"Memory item with id {memory_id} not found"
342354
raise ValueError(msg)
343355
item_categories = store.category_item_repo.get_item_categories(memory_id)
344-
for cat in item_categories:
345-
category_memory_updates[cat.category_id] = (item.summary, None)
356+
if propagate:
357+
for cat in item_categories:
358+
category_memory_updates[cat.category_id] = (item.summary, None)
346359
store.memory_item_repo.delete_item(memory_id)
347360

348361
state.update({

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)