@@ -119,7 +119,7 @@ def put(
119
119
metadata : CheckpointMetadata ,
120
120
new_versions : ChannelVersions ,
121
121
) -> RunnableConfig :
122
- """Store only the latest checkpoint."""
122
+ """Store only the latest checkpoint and clean up old blobs ."""
123
123
configurable = config ["configurable" ].copy ()
124
124
thread_id = configurable .pop ("thread_id" )
125
125
checkpoint_ns = configurable .pop ("checkpoint_ns" )
@@ -146,6 +146,9 @@ def put(
146
146
if all (key in metadata for key in ["source" , "step" ]):
147
147
checkpoint_data ["source" ] = metadata ["source" ]
148
148
checkpoint_data ["step" ] = metadata ["step" ]
149
+
150
+ # Note: Need to keep track of the current versions to keep
151
+ current_channel_versions = new_versions .copy ()
149
152
150
153
self .checkpoints_index .load (
151
154
[checkpoint_data ],
@@ -155,6 +158,38 @@ def put(
155
158
)
156
159
],
157
160
)
161
+
162
+ # Before storing the new blobs, clean up old ones that won't be needed
163
+ # - Get a list of all blob keys for this thread_id and checkpoint_ns
164
+ # - Then delete the ones that aren't in new_versions
165
+ cleanup_pipeline = self ._redis .json ().pipeline (transaction = False )
166
+
167
+ # Get all blob keys for this thread/namespace
168
+ blob_key_pattern = ShallowRedisSaver ._make_shallow_redis_checkpoint_blob_key_pattern (
169
+ thread_id , checkpoint_ns
170
+ )
171
+ existing_blob_keys = self ._redis .keys (blob_key_pattern )
172
+
173
+ # Process each existing blob key to determine if it should be kept or deleted
174
+ if existing_blob_keys :
175
+ for blob_key in existing_blob_keys :
176
+ key_parts = blob_key .decode ().split (REDIS_KEY_SEPARATOR )
177
+ # The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
178
+ if len (key_parts ) >= 5 :
179
+ channel = key_parts [3 ]
180
+ version = key_parts [4 ]
181
+
182
+ # Only keep the blob if it's referenced by the current versions
183
+ if (channel in current_channel_versions and
184
+ current_channel_versions [channel ] == version ):
185
+ # This is a current version, keep it
186
+ continue
187
+ else :
188
+ # This is an old version, delete it
189
+ cleanup_pipeline .delete (blob_key )
190
+
191
+ # Execute the cleanup
192
+ cleanup_pipeline .execute ()
158
193
159
194
# Store blob values
160
195
blobs = self ._dump_blobs (
@@ -408,7 +443,7 @@ def put_writes(
408
443
task_id : str ,
409
444
task_path : str = "" ,
410
445
) -> None :
411
- """Store intermediate writes linked to a checkpoint.
446
+ """Store intermediate writes linked to a checkpoint and clean up old writes .
412
447
413
448
Args:
414
449
config: Configuration of the related checkpoint.
@@ -436,6 +471,30 @@ def put_writes(
436
471
"blob" : blob ,
437
472
}
438
473
writes_objects .append (write_obj )
474
+
475
+ # First clean up old writes for this thread and namespace if they're for a different checkpoint_id
476
+ cleanup_pipeline = self ._redis .json ().pipeline (transaction = False )
477
+
478
+ # Get all writes keys for this thread/namespace
479
+ writes_key_pattern = ShallowRedisSaver ._make_shallow_redis_checkpoint_writes_key_pattern (
480
+ thread_id , checkpoint_ns
481
+ )
482
+ existing_writes_keys = self ._redis .keys (writes_key_pattern )
483
+
484
+ # Process each existing writes key to determine if it should be kept or deleted
485
+ if existing_writes_keys :
486
+ for write_key in existing_writes_keys :
487
+ key_parts = write_key .decode ().split (REDIS_KEY_SEPARATOR )
488
+ # The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
489
+ if len (key_parts ) >= 5 :
490
+ key_checkpoint_id = key_parts [3 ]
491
+
492
+ # If the write is for a different checkpoint_id, delete it
493
+ if key_checkpoint_id != checkpoint_id :
494
+ cleanup_pipeline .delete (write_key )
495
+
496
+ # Execute the cleanup
497
+ cleanup_pipeline .execute ()
439
498
440
499
# For each write, check existence and then perform appropriate operation
441
500
with self ._redis .json ().pipeline (transaction = False ) as pipeline :
@@ -470,18 +529,25 @@ def _dump_blobs(
470
529
values : dict [str , Any ],
471
530
versions : ChannelVersions ,
472
531
) -> List [Tuple [str , dict [str , Any ]]]:
532
+ """Convert blob data for Redis storage.
533
+
534
+ In the shallow implementation, we use the version in the key to allow
535
+ storing multiple versions without conflicts and to facilitate cleanup.
536
+ """
473
537
if not versions :
474
538
return []
475
539
476
540
return [
477
541
(
478
- ShallowRedisSaver ._make_shallow_redis_checkpoint_blob_key (
479
- thread_id , checkpoint_ns , k
542
+ # Use the base Redis checkpoint blob key to include version, enabling version tracking
543
+ BaseRedisSaver ._make_redis_checkpoint_blob_key (
544
+ thread_id , checkpoint_ns , k , ver
480
545
),
481
546
{
482
547
"thread_id" : thread_id ,
483
548
"checkpoint_ns" : checkpoint_ns ,
484
549
"channel" : k ,
550
+ "version" : ver , # Include version in the data as well
485
551
"type" : (
486
552
self ._get_type_and_blob (values [k ])[0 ]
487
553
if k in values
@@ -581,12 +647,24 @@ def _load_pending_sends(
581
647
582
648
@staticmethod
583
649
def _make_shallow_redis_checkpoint_key (thread_id : str , checkpoint_ns : str ) -> str :
650
+ """Create a key for shallow checkpoints using only thread_id and checkpoint_ns."""
584
651
return REDIS_KEY_SEPARATOR .join ([CHECKPOINT_PREFIX , thread_id , checkpoint_ns ])
585
652
586
653
@staticmethod
587
654
def _make_shallow_redis_checkpoint_blob_key (
588
655
thread_id : str , checkpoint_ns : str , channel : str
589
656
) -> str :
657
+ """Create a key for a blob in a shallow checkpoint."""
590
658
return REDIS_KEY_SEPARATOR .join (
591
659
[CHECKPOINT_BLOB_PREFIX , thread_id , checkpoint_ns , channel ]
592
660
)
661
+
662
+ @staticmethod
663
+ def _make_shallow_redis_checkpoint_blob_key_pattern (thread_id : str , checkpoint_ns : str ) -> str :
664
+ """Create a pattern to match all blob keys for a thread and namespace."""
665
+ return REDIS_KEY_SEPARATOR .join ([CHECKPOINT_BLOB_PREFIX , thread_id , checkpoint_ns ]) + ":*"
666
+
667
+ @staticmethod
668
+ def _make_shallow_redis_checkpoint_writes_key_pattern (thread_id : str , checkpoint_ns : str ) -> str :
669
+ """Create a pattern to match all writes keys for a thread and namespace."""
670
+ return REDIS_KEY_SEPARATOR .join ([CHECKPOINT_WRITE_PREFIX , thread_id , checkpoint_ns ]) + ":*"
0 commit comments