@@ -283,6 +283,7 @@ async def create_memory_item(
283283 memory_content : str ,
284284 memory_categories : list [str ],
285285 user : dict [str , Any ] | None = None ,
286+ propagate : bool = True ,
286287 ) -> dict [str , Any ]:
287288 if memory_type not in get_args (MemoryType ):
288289 msg = f"Invalid memory type: '{ memory_type } ', must be one of { get_args (MemoryType )} "
@@ -303,6 +304,7 @@ async def create_memory_item(
303304 "store" : store ,
304305 "category_ids" : list (ctx .category_ids ),
305306 "user" : user_scope ,
307+ "propagate" : propagate ,
306308 }
307309
308310 result = await self ._run_workflow ("patch_create" , state )
@@ -320,6 +322,7 @@ async def update_memory_item(
320322 memory_content : str | None = None ,
321323 memory_categories : list [str ] | None = None ,
322324 user : dict [str , Any ] | None = None ,
325+ propagate : bool = True ,
323326 ) -> dict [str , Any ]:
324327 if all ((memory_type is None , memory_content is None , memory_categories is None )):
325328 msg = "At least one of memory type, memory content, or memory categories is required for UPDATE operation"
@@ -344,6 +347,7 @@ async def update_memory_item(
344347 "store" : store ,
345348 "category_ids" : list (ctx .category_ids ),
346349 "user" : user_scope ,
350+ "propagate" : propagate ,
347351 }
348352
349353 result = await self ._run_workflow ("patch_update" , state )
@@ -358,6 +362,7 @@ async def delete_memory_item(
358362 * ,
359363 memory_id : str ,
360364 user : dict [str , Any ] | None = None ,
365+ propagate : bool = True ,
361366 ) -> dict [str , Any ]:
362367 ctx = self ._get_context ()
363368 store = self ._get_database ()
@@ -370,6 +375,7 @@ async def delete_memory_item(
370375 "store" : store ,
371376 "category_ids" : list (ctx .category_ids ),
372377 "user" : user_scope ,
378+ "propagate" : propagate ,
373379 }
374380
375381 result = await self ._run_workflow ("patch_delete" , state )
@@ -504,6 +510,7 @@ async def _patch_create_memory_item(self, state: WorkflowState, step_context: An
504510 ctx = state ["ctx" ]
505511 store = state ["store" ]
506512 user = state ["user" ]
513+ propagate = state ["propagate" ]
507514 category_memory_updates : dict [str , tuple [Any , Any ]] = {}
508515
509516 embed_payload = [memory_payload ["content" ]]
@@ -519,7 +526,8 @@ async def _patch_create_memory_item(self, state: WorkflowState, step_context: An
519526 mapped_cat_ids = self ._map_category_names_to_ids (cat_names , ctx )
520527 for cid in mapped_cat_ids :
521528 store .category_item_repo .link_item_category (item .id , cid , user_data = dict (user or {}))
522- category_memory_updates [cid ] = (None , memory_payload ["content" ])
529+ if propagate :
530+ category_memory_updates [cid ] = (None , memory_payload ["content" ])
523531
524532 state .update ({
525533 "memory_item" : item ,
@@ -533,6 +541,7 @@ async def _patch_update_memory_item(self, state: WorkflowState, step_context: An
533541 ctx = state ["ctx" ]
534542 store = state ["store" ]
535543 user = state ["user" ]
544+ propagate = state ["propagate" ]
536545 category_memory_updates : dict [str , tuple [Any , Any ]] = {}
537546
538547 item = store .memory_item_repo .get_item (memory_id )
@@ -563,12 +572,14 @@ async def _patch_update_memory_item(self, state: WorkflowState, step_context: An
563572 cats_to_add = set (mapped_new_cat_ids ) - set (mapped_old_cat_ids )
564573 for cid in cats_to_remove :
565574 store .category_item_repo .unlink_item_category (memory_id , cid )
566- category_memory_updates [cid ] = (old_content , None )
575+ if propagate :
576+ category_memory_updates [cid ] = (old_content , None )
567577 for cid in cats_to_add :
568578 store .category_item_repo .link_item_category (memory_id , cid , user_data = dict (user or {}))
569- category_memory_updates [cid ] = (None , item .summary )
579+ if propagate :
580+ category_memory_updates [cid ] = (None , item .summary )
570581
571- if memory_payload ["content" ]:
582+ if propagate and memory_payload ["content" ]:
572583 for cid in set (mapped_old_cat_ids ) & set (mapped_new_cat_ids ):
573584 category_memory_updates [cid ] = (old_content , item .summary )
574585
@@ -581,15 +592,17 @@ async def _patch_update_memory_item(self, state: WorkflowState, step_context: An
581592 async def _patch_delete_memory_item (self , state : WorkflowState , step_context : Any ) -> WorkflowState :
582593 memory_id = state ["memory_id" ]
583594 store = state ["store" ]
595+ propagate = state ["propagate" ]
584596 category_memory_updates : dict [str , tuple [Any , Any ]] = {}
585597
586598 item = store .memory_item_repo .get_item (memory_id )
587599 if not item :
588600 msg = f"Memory item with id { memory_id } not found"
589601 raise ValueError (msg )
590602 item_categories = store .category_item_repo .get_item_categories (memory_id )
591- for cat in item_categories :
592- category_memory_updates [cat .category_id ] = (item .summary , None )
603+ if propagate :
604+ for cat in item_categories :
605+ category_memory_updates [cat .category_id ] = (item .summary , None )
593606 store .memory_item_repo .delete_item (memory_id )
594607
595608 state .update ({
0 commit comments