@@ -384,7 +384,24 @@ async def aput(
         metadata: CheckpointMetadata,
         new_versions: ChannelVersions,
     ) -> RunnableConfig:
-        """Store a checkpoint to Redis."""
+        """Store a checkpoint to Redis with proper transaction handling.
+
+        This method ensures that all Redis operations are performed atomically
+        using Redis transactions. In case of interruption (asyncio.CancelledError),
+        the transaction will be aborted, ensuring consistency.
+
+        Args:
+            config: The config to associate with the checkpoint
+            checkpoint: The checkpoint data to store
+            metadata: Additional metadata to save with the checkpoint
+            new_versions: New channel versions as of this write
+
+        Returns:
+            Updated configuration after storing the checkpoint
+
+        Raises:
+            asyncio.CancelledError: If the operation is cancelled/interrupted
+        """
         configurable = config["configurable"].copy()

         thread_id = configurable.pop("thread_id")
@@ -410,46 +427,63 @@ async def aput(
             }
         }

-        # Store checkpoint data
-        checkpoint_data = {
-            "thread_id": storage_safe_thread_id,
-            "checkpoint_ns": storage_safe_checkpoint_ns,
-            "checkpoint_id": storage_safe_checkpoint_id,
-            "parent_checkpoint_id": storage_safe_checkpoint_id,
-            "checkpoint": self._dump_checkpoint(copy),
-            "metadata": self._dump_metadata(metadata),
-        }
-
-        # store at top-level for filters in list()
-        if all(key in metadata for key in ["source", "step"]):
-            checkpoint_data["source"] = metadata["source"]
-            checkpoint_data["step"] = metadata["step"]  # type: ignore
-
-        await self.checkpoints_index.load(
-            [checkpoint_data],
-            keys=[
-                BaseRedisSaver._make_redis_checkpoint_key(
-                    storage_safe_thread_id,
-                    storage_safe_checkpoint_ns,
-                    storage_safe_checkpoint_id,
-                )
-            ],
-        )
-
-        # Store blob values
-        blobs = self._dump_blobs(
-            storage_safe_thread_id,
-            storage_safe_checkpoint_ns,
-            copy.get("channel_values", {}),
-            new_versions,
-        )
-
-        if blobs:
-            # Unzip the list of tuples into separate lists for keys and data
-            keys, data = zip(*blobs)
-            await self.checkpoint_blobs_index.load(list(data), keys=list(keys))
-
-        return next_config
+        # Store checkpoint data with transaction handling
+        try:
+            # Create a pipeline with transaction=True for atomicity
+            pipeline = self._redis.pipeline(transaction=True)
+
+            # Store checkpoint data
+            checkpoint_data = {
+                "thread_id": storage_safe_thread_id,
+                "checkpoint_ns": storage_safe_checkpoint_ns,
+                "checkpoint_id": storage_safe_checkpoint_id,
+                "parent_checkpoint_id": storage_safe_checkpoint_id,
+                "checkpoint": self._dump_checkpoint(copy),
+                "metadata": self._dump_metadata(metadata),
+            }
+
+            # store at top-level for filters in list()
+            if all(key in metadata for key in ["source", "step"]):
+                checkpoint_data["source"] = metadata["source"]
+                checkpoint_data["step"] = metadata["step"]  # type: ignore
+
+            # Prepare checkpoint key
+            checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
+                storage_safe_thread_id,
+                storage_safe_checkpoint_ns,
+                storage_safe_checkpoint_id,
+            )
+
+            # Queue the checkpoint write on the pipeline
+            await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
+
+            # Store blob values
+            blobs = self._dump_blobs(
+                storage_safe_thread_id,
+                storage_safe_checkpoint_ns,
+                copy.get("channel_values", {}),
+                new_versions,
+            )
+
+            if blobs:
+                # Add all blob operations to the pipeline
+                for key, data in blobs:
+                    await pipeline.json().set(key, "$", data)
+
+            # Execute all operations atomically
+            await pipeline.execute()
+
+            return next_config
+
+        except asyncio.CancelledError:
+            # Handle cancellation/interruption
+            # Pipeline will be automatically discarded
+            # Either all operations succeed or none do
+            raise
+
+        except Exception as e:
+            # Re-raise other exceptions
+            raise e

     async def aput_writes(
         self,
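
For readers unfamiliar with redis-py's transactional pipelines: queued commands are buffered client-side and wrapped in MULTI/EXEC when execute() is called, which is what makes the all-or-nothing claim in the docstring hold. Below is a minimal sketch of the pattern in isolation, assuming redis-py's asyncio client with RedisJSON available; the key names and payloads are illustrative only, not taken from this commit.

    import asyncio

    from redis.asyncio import Redis

    async def save_atomically(r: Redis) -> None:
        pipe = r.pipeline(transaction=True)  # queued commands run inside MULTI/EXEC
        # These calls only buffer client-side; no I/O happens yet.
        pipe.json().set("checkpoint:demo", "$", {"step": 1})
        pipe.json().set("blob:demo", "$", {"channel": "values"})
        try:
            await pipe.execute()  # one round trip: both writes land or neither does
        except asyncio.CancelledError:
            # Cancelled before EXEC was sent: the server saw no partial state.
            raise

    asyncio.run(save_atomically(Redis()))
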
@@ -458,14 +492,23 @@ async def aput_writes(
         task_id: str,
         task_path: str = "",
     ) -> None:
-        """Store intermediate writes linked to a checkpoint using Redis JSON.
+        """Store intermediate writes linked to a checkpoint using Redis JSON with transaction handling.
+
+        This method uses a Redis pipeline with transaction=True to ensure atomicity of
+        all write operations. In case of interruption, all operations will be aborted.

         Args:
             config (RunnableConfig): Configuration of the related checkpoint.
             writes (List[Tuple[str, Any]]): List of writes to store.
             task_id (str): Identifier for the task creating the writes.
             task_path (str): Path of the task creating the writes.
+
+        Raises:
+            asyncio.CancelledError: If the operation is cancelled/interrupted
         """
+        if not writes:
+            return
+
         thread_id = config["configurable"]["thread_id"]
         checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
         checkpoint_id = config["configurable"]["checkpoint_id"]
@@ -487,7 +530,14 @@ async def aput_writes(
             }
             writes_objects.append(write_obj)

+        try:
+            # Use a transaction pipeline for atomicity
+            pipeline = self._redis.pipeline(transaction=True)
+
+            # Determine if this is an upsert case
             upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
+
+            # Add all write operations to the pipeline
             for write_obj in writes_objects:
                 key = self._make_redis_checkpoint_writes_key(
                     thread_id,
@@ -496,10 +546,36 @@ async def aput_writes(
                     task_id,
                     write_obj["idx"],  # type: ignore[arg-type]
                 )
-            tx = partial(
-                _write_obj_tx, key=key, write_obj=write_obj, upsert_case=upsert_case
-            )
-            await self._redis.transaction(tx, key)
+
+                if upsert_case:
+                    # For the upsert case, check whether the key exists and update accordingly
+                    exists = await self._redis.exists(key)
+                    if exists:
+                        # Update existing key
+                        await pipeline.json().set(key, "$.channel", write_obj["channel"])
+                        await pipeline.json().set(key, "$.type", write_obj["type"])
+                        await pipeline.json().set(key, "$.blob", write_obj["blob"])
+                    else:
+                        # Create new key
+                        await pipeline.json().set(key, "$", write_obj)
+                else:
+                    # For the non-upsert case, only set if the key doesn't exist
+                    exists = await self._redis.exists(key)
+                    if not exists:
+                        await pipeline.json().set(key, "$", write_obj)
+
+            # Execute all operations atomically
+            await pipeline.execute()
+
+        except asyncio.CancelledError:
+            # Handle cancellation/interruption
+            # Pipeline will be automatically discarded
+            # Either all operations succeed or none do
+            raise
+
+        except Exception as e:
+            # Re-raise other exceptions
+            raise e

     def put_writes(
         self,
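
One observation on the non-upsert branch above: the EXISTS check is issued outside the pipeline, so it is not itself part of the transaction. Since JSON.SET accepts an NX option (exposed as nx=True in redis-py), a possible simplification would be to let the server enforce the create-only condition inside MULTI/EXEC. The following is a sketch of that alternative, not the commit's implementation; put_writes_nx and its parameters are hypothetical names.

    from typing import Any

    from redis.asyncio import Redis

    async def put_writes_nx(r: Redis, items: list[tuple[str, dict[str, Any]]]) -> None:
        """Create-only JSON writes: keys that already exist are left untouched."""
        pipe = r.pipeline(transaction=True)
        for key, write_obj in items:
            # nx=True makes JSON.SET a no-op when the key already exists,
            # checked server-side inside the transaction rather than via a
            # separate EXISTS round trip.
            pipe.json().set(key, "$", write_obj, nx=True)
        await pipe.execute()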