@@ -119,7 +119,7 @@ def put(
119119 metadata : CheckpointMetadata ,
120120 new_versions : ChannelVersions ,
121121 ) -> RunnableConfig :
122- """Store only the latest checkpoint."""
122+ """Store only the latest checkpoint and clean up old blobs ."""
123123 configurable = config ["configurable" ].copy ()
124124 thread_id = configurable .pop ("thread_id" )
125125 checkpoint_ns = configurable .pop ("checkpoint_ns" )
@@ -146,6 +146,9 @@ def put(
146146 if all (key in metadata for key in ["source" , "step" ]):
147147 checkpoint_data ["source" ] = metadata ["source" ]
148148 checkpoint_data ["step" ] = metadata ["step" ]
149+
150+ # Note: Keep a copy of the current channel versions so the cleanup below knows which blobs to keep
151+ current_channel_versions = new_versions .copy ()
149152
150153 self .checkpoints_index .load (
151154 [checkpoint_data ],
@@ -155,6 +158,38 @@ def put(
155158 )
156159 ],
157160 )
161+
162+ # Before storing the new blobs, clean up old ones that are no longer needed:
163+ # - Get a list of all blob keys for this thread_id and checkpoint_ns
164+ # - Then delete the ones whose channel/version pair is not in new_versions
165+ cleanup_pipeline = self ._redis .json ().pipeline (transaction = False )
166+
167+ # Get all blob keys for this thread/namespace
168+ blob_key_pattern = ShallowRedisSaver ._make_shallow_redis_checkpoint_blob_key_pattern (
169+ thread_id , checkpoint_ns
170+ )
171+ existing_blob_keys = self ._redis .keys (blob_key_pattern )
172+
173+ # Process each existing blob key to determine if it should be kept or deleted
174+ if existing_blob_keys :
175+ for blob_key in existing_blob_keys :
176+ key_parts = blob_key .decode ().split (REDIS_KEY_SEPARATOR )
177+ # The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
178+ if len (key_parts ) >= 5 :
179+ channel = key_parts [3 ]
180+ version = key_parts [4 ]
181+
182+ # Only keep the blob if it's referenced by the current versions
183+ if (channel in current_channel_versions and
184+ current_channel_versions [channel ] == version ):
185+ # This is a current version, keep it
186+ continue
187+ else :
188+ # This is an old version, delete it
189+ cleanup_pipeline .delete (blob_key )
190+
191+ # Execute the cleanup
192+ cleanup_pipeline .execute ()
158193
159194 # Store blob values
160195 blobs = self ._dump_blobs (
@@ -408,7 +443,7 @@ def put_writes(
408443 task_id : str ,
409444 task_path : str = "" ,
410445 ) -> None :
411- """Store intermediate writes linked to a checkpoint.
446+ """Store intermediate writes linked to a checkpoint and clean up old writes .
412447
413448 Args:
414449 config: Configuration of the related checkpoint.
@@ -436,6 +471,30 @@ def put_writes(
436471 "blob" : blob ,
437472 }
438473 writes_objects .append (write_obj )
474+
475+ # First, clean up existing writes for this thread and namespace that belong to a different checkpoint_id
476+ cleanup_pipeline = self ._redis .json ().pipeline (transaction = False )
477+
478+ # Get all writes keys for this thread/namespace
479+ writes_key_pattern = ShallowRedisSaver ._make_shallow_redis_checkpoint_writes_key_pattern (
480+ thread_id , checkpoint_ns
481+ )
482+ existing_writes_keys = self ._redis .keys (writes_key_pattern )
483+
484+ # Process each existing writes key to determine if it should be kept or deleted
485+ if existing_writes_keys :
486+ for write_key in existing_writes_keys :
487+ key_parts = write_key .decode ().split (REDIS_KEY_SEPARATOR )
488+ # The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
489+ if len (key_parts ) >= 5 :
490+ key_checkpoint_id = key_parts [3 ]
491+
492+ # If the write is for a different checkpoint_id, delete it
493+ if key_checkpoint_id != checkpoint_id :
494+ cleanup_pipeline .delete (write_key )
495+
496+ # Execute the cleanup
497+ cleanup_pipeline .execute ()
439498
440499 # For each write, check existence and then perform appropriate operation
441500 with self ._redis .json ().pipeline (transaction = False ) as pipeline :
@@ -470,18 +529,25 @@ def _dump_blobs(
470529 values : dict [str , Any ],
471530 versions : ChannelVersions ,
472531 ) -> List [Tuple [str , dict [str , Any ]]]:
532+ """Convert blob data for Redis storage.
533+
534+ In the shallow implementation, we use the version in the key to allow
535+ storing multiple versions without conflicts and to facilitate cleanup.
536+ """
473537 if not versions :
474538 return []
475539
476540 return [
477541 (
478- ShallowRedisSaver ._make_shallow_redis_checkpoint_blob_key (
479- thread_id , checkpoint_ns , k
542+ # Use the base Redis checkpoint blob key to include version, enabling version tracking
543+ BaseRedisSaver ._make_redis_checkpoint_blob_key (
544+ thread_id , checkpoint_ns , k , ver
480545 ),
481546 {
482547 "thread_id" : thread_id ,
483548 "checkpoint_ns" : checkpoint_ns ,
484549 "channel" : k ,
550+ "version" : ver , # Include version in the data as well
485551 "type" : (
486552 self ._get_type_and_blob (values [k ])[0 ]
487553 if k in values
@@ -581,12 +647,24 @@ def _load_pending_sends(
581647
582648 @staticmethod
583649 def _make_shallow_redis_checkpoint_key (thread_id : str , checkpoint_ns : str ) -> str :
650+ """Create a key for shallow checkpoints using only thread_id and checkpoint_ns."""
584651 return REDIS_KEY_SEPARATOR .join ([CHECKPOINT_PREFIX , thread_id , checkpoint_ns ])
585652
586653 @staticmethod
587654 def _make_shallow_redis_checkpoint_blob_key (
588655 thread_id : str , checkpoint_ns : str , channel : str
589656 ) -> str :
657+ """Create a key for a blob in a shallow checkpoint."""
590658 return REDIS_KEY_SEPARATOR .join (
591659 [CHECKPOINT_BLOB_PREFIX , thread_id , checkpoint_ns , channel ]
592660 )
661+
662+ @staticmethod
663+ def _make_shallow_redis_checkpoint_blob_key_pattern (thread_id : str , checkpoint_ns : str ) -> str :
664+ """Create a pattern to match all blob keys for a thread and namespace."""
665+ return REDIS_KEY_SEPARATOR .join ([CHECKPOINT_BLOB_PREFIX , thread_id , checkpoint_ns ]) + ":*"
666+
667+ @staticmethod
668+ def _make_shallow_redis_checkpoint_writes_key_pattern (thread_id : str , checkpoint_ns : str ) -> str :
669+ """Create a pattern to match all writes keys for a thread and namespace."""
670+ return REDIS_KEY_SEPARATOR .join ([CHECKPOINT_WRITE_PREFIX , thread_id , checkpoint_ns ]) + ":*"
0 commit comments