@@ -385,20 +385,20 @@ async def aput(
         new_versions: ChannelVersions,
     ) -> RunnableConfig:
         """Store a checkpoint to Redis with proper transaction handling.
-
+
         This method ensures that all Redis operations are performed atomically
         using Redis transactions. In case of interruption (asyncio.CancelledError),
         the transaction will be aborted, ensuring consistency.
-
+
         Args:
             config: The config to associate with the checkpoint
             checkpoint: The checkpoint data to store
             metadata: Additional metadata to save with the checkpoint
             new_versions: New channel versions as of this write
-
+
         Returns:
             Updated configuration after storing the checkpoint
-
+
         Raises:
             asyncio.CancelledError: If the operation is cancelled/interrupted
         """
@@ -431,7 +431,7 @@ async def aput(
         try:
             # Create a pipeline with transaction=True for atomicity
             pipeline = self._redis.pipeline(transaction=True)
-
+
             # Store checkpoint data
             checkpoint_data = {
                 "thread_id": storage_safe_thread_id,
@@ -441,46 +441,46 @@ async def aput(
441441 "checkpoint" : self ._dump_checkpoint (copy ),
442442 "metadata" : self ._dump_metadata (metadata ),
443443 }
444-
444+
445445 # store at top-level for filters in list()
446446 if all (key in metadata for key in ["source" , "step" ]):
447447 checkpoint_data ["source" ] = metadata ["source" ]
448448 checkpoint_data ["step" ] = metadata ["step" ] # type: ignore
449-
449+
450450 # Prepare checkpoint key
451451 checkpoint_key = BaseRedisSaver ._make_redis_checkpoint_key (
452452 storage_safe_thread_id ,
453453 storage_safe_checkpoint_ns ,
454454 storage_safe_checkpoint_id ,
455455 )
456-
456+
457457 # Add checkpoint data to Redis
458458 await pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
459-
459+
460460 # Store blob values
461461 blobs = self ._dump_blobs (
462462 storage_safe_thread_id ,
463463 storage_safe_checkpoint_ns ,
464464 copy .get ("channel_values" , {}),
465465 new_versions ,
466466 )
467-
467+
468468 if blobs :
469469 # Add all blob operations to the pipeline
470470 for key , data in blobs :
471471 await pipeline .json ().set (key , "$" , data )
472-
472+
473473 # Execute all operations atomically
474474 await pipeline .execute ()
475-
475+
476476 return next_config
477-
477+
478478 except asyncio .CancelledError :
479479 # Handle cancellation/interruption
480480 # Pipeline will be automatically discarded
481481 # Either all operations succeed or none do
482482 raise
483-
483+
484484 except Exception as e :
485485 # Re-raise other exceptions
486486 raise e
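The `asyncio.CancelledError` branch works because a cancelled `aput` never reaches `pipeline.execute()`, so the queued commands are simply discarded on the client. A hypothetical standalone demonstration of that property (key name and sleep are illustrative):

```python
import asyncio

from redis.asyncio import Redis


async def write_atomically(r: Redis) -> None:
    pipe = r.pipeline(transaction=True)
    pipe.set("checkpoint:demo", "v1")  # queued locally, no I/O yet
    await asyncio.sleep(1)             # cancellation lands at this await
    await pipe.execute()               # never reached if cancelled above


async def main() -> None:
    r = Redis()
    task = asyncio.create_task(write_atomically(r))
    await asyncio.sleep(0)  # let the task run up to its first await
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    print(await r.exists("checkpoint:demo"))  # 0 -- nothing was written
    await r.aclose()


asyncio.run(main())
```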
@@ -502,13 +502,13 @@ async def aput_writes(
             writes (List[Tuple[str, Any]]): List of writes to store.
             task_id (str): Identifier for the task creating the writes.
             task_path (str): Path of the task creating the writes.
-
+
         Raises:
             asyncio.CancelledError: If the operation is cancelled/interrupted
         """
         if not writes:
             return
-
+
         thread_id = config["configurable"]["thread_id"]
         checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
         checkpoint_id = config["configurable"]["checkpoint_id"]
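These three lookups pin down the `RunnableConfig` shape `aput_writes` expects; a plausible example value (the IDs are made up):

```python
config = {
    "configurable": {
        "thread_id": "thread-1",
        "checkpoint_ns": "",  # optional; .get() defaults it to ""
        "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
    }
}
```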
@@ -533,10 +533,10 @@ async def aput_writes(
533533 try :
534534 # Use a transaction pipeline for atomicity
535535 pipeline = self ._redis .pipeline (transaction = True )
536-
536+
537537 # Determine if this is an upsert case
538538 upsert_case = all (w [0 ] in WRITES_IDX_MAP for w in writes )
539-
539+
540540 # Add all write operations to the pipeline
541541 for write_obj in writes_objects :
542542 key = self ._make_redis_checkpoint_writes_key (
@@ -546,13 +546,15 @@ async def aput_writes(
546546 task_id ,
547547 write_obj ["idx" ], # type: ignore[arg-type]
548548 )
549-
549+
550550 if upsert_case :
551551 # For upsert case, we need to check if the key exists and update differently
552552 exists = await self ._redis .exists (key )
553553 if exists :
554554 # Update existing key
555- await pipeline .json ().set (key , "$.channel" , write_obj ["channel" ])
555+ await pipeline .json ().set (
556+ key , "$.channel" , write_obj ["channel" ]
557+ )
556558 await pipeline .json ().set (key , "$.type" , write_obj ["type" ])
557559 await pipeline .json ().set (key , "$.blob" , write_obj ["blob" ])
558560 else :
@@ -563,16 +565,16 @@ async def aput_writes(
                     exists = await self._redis.exists(key)
                     if not exists:
                         await pipeline.json().set(key, "$", write_obj)
-
+
             # Execute all operations atomically
             await pipeline.execute()
-
+
         except asyncio.CancelledError:
-            # Handle cancellation/interruption
+            # Handle cancellation/interruption
             # Pipeline will be automatically discarded
             # Either all operations succeed or none do
             raise
-
+
         except Exception as e:
             # Re-raise other exceptions
             raise e
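On the upsert split above: `WRITES_IDX_MAP` (imported from `langgraph.checkpoint.base`) assigns a fixed index to special channels, so a batch made up entirely of such channels always maps to the same Redis keys and can safely be rewritten field-by-field via the `$.channel` / `$.type` / `$.blob` JSON paths; any other write gets a sequential index and is insert-only. A rough sketch of the decision (the channel name and payload are illustrative, not taken from this PR):

```python
from langgraph.checkpoint.base import WRITES_IDX_MAP

# (channel, value) pairs as passed to aput_writes; payload is made up
writes = [("__error__", "boom")]

# True only when every write targets a fixed-index channel, i.e. a retry
# would hit the same keys and should overwrite in place rather than insert.
upsert_case = all(channel in WRITES_IDX_MAP for channel, _ in writes)
print(upsert_case)
```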