Commit d21ee55

fix: prevent blob and write accumulation in ShallowRedisSaver classes (#13)
Add cleanup logic to AsyncShallowRedisSaver and ShallowRedisSaver to delete old blobs and writes when storing new checkpoints. This prevents memory bloat when using shallow savers, which should keep only the latest checkpoint state. Add a comprehensive test to verify the fix works correctly.
1 parent 2dbe143 commit d21ee55

3 files changed: +395 -24 lines changed
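Why the cleanup is needed: the shallow savers embed the channel version in each blob key, so every checkpoint written to a thread can strand blobs for superseded versions unless they are explicitly deleted. Below is a minimal standalone sketch of the keep-or-delete rule the fix applies (plain Python; the sample keys are hypothetical but follow the checkpoint_blob:thread_id:checkpoint_ns:channel:version layout documented in the diff below):

REDIS_KEY_SEPARATOR = ":"  # assumed separator, matching the key formats shown in the diff

def blobs_to_delete(existing_keys, current_versions):
    """Return blob keys whose channel/version pair is no longer referenced."""
    stale = []
    for key in existing_keys:
        parts = key.split(REDIS_KEY_SEPARATOR)
        # Key layout: checkpoint_blob:thread_id:checkpoint_ns:channel:version
        if len(parts) >= 5:
            channel, version = parts[3], parts[4]
            if current_versions.get(channel) != version:
                stale.append(key)
    return stale

# Two versions of the "messages" channel exist; the new checkpoint references only v2.
keys = [
    "checkpoint_blob:thread-1::messages:v1",
    "checkpoint_blob:thread-1::messages:v2",
]
print(blobs_to_delete(keys, {"messages": "v2"}))
# -> ['checkpoint_blob:thread-1::messages:v1']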

langgraph/checkpoint/redis/ashallow.py

Lines changed: 92 additions & 20 deletions
@@ -172,7 +172,7 @@ async def aput(
         metadata: CheckpointMetadata,
         new_versions: ChannelVersions,
     ) -> RunnableConfig:
-        """Store only the latest checkpoint asynchronously."""
+        """Store only the latest checkpoint asynchronously and clean up old blobs."""
         configurable = config["configurable"].copy()
         thread_id = configurable.pop("thread_id")
         checkpoint_ns = configurable.pop("checkpoint_ns")
@@ -200,6 +200,10 @@ async def aput(
             checkpoint_data["source"] = metadata["source"]
             checkpoint_data["step"] = metadata["step"]
 
+        # Note: Need to keep track of the current versions to keep
+        current_channel_versions = new_versions.copy()
+
+        # Store the new checkpoint, which replaces any existing one due to the shallow key
         await self.checkpoints_index.load(
             [checkpoint_data],
             keys=[
@@ -209,7 +213,39 @@ async def aput(
             ],
         )
 
-        # Store blob values
+        # Before storing the new blobs, clean up old ones that won't be needed
+        # - Get a list of all blob keys for this thread_id and checkpoint_ns
+        # - Then delete the ones that aren't in new_versions
+        cleanup_pipeline = self._redis.pipeline()
+
+        # Get all blob keys for this thread/namespace
+        blob_key_pattern = AsyncShallowRedisSaver._make_shallow_redis_checkpoint_blob_key_pattern(
+            thread_id, checkpoint_ns
+        )
+        existing_blob_keys = await self._redis.keys(blob_key_pattern)
+
+        # Process each existing blob key to determine if it should be kept or deleted
+        if existing_blob_keys:
+            for blob_key in existing_blob_keys:
+                key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
+                # The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
+                if len(key_parts) >= 5:
+                    channel = key_parts[3]
+                    version = key_parts[4]
+
+                    # Only keep the blob if it's referenced by the current versions
+                    if (channel in current_channel_versions and
+                            current_channel_versions[channel] == version):
+                        # This is a current version, keep it
+                        continue
+                    else:
+                        # This is an old version, delete it
+                        cleanup_pipeline.delete(blob_key)
+
+            # Execute the cleanup
+            await cleanup_pipeline.execute()
+
+        # Store the new blob values
         blobs = self._dump_blobs(
             thread_id,
             checkpoint_ns,
@@ -385,7 +421,7 @@ async def aput_writes(
         task_id: str,
         task_path: str = "",
     ) -> None:
-        """Store intermediate writes for the latest checkpoint.
+        """Store intermediate writes for the latest checkpoint and clean up old writes.
 
         Args:
             config (RunnableConfig): Configuration of the related checkpoint.
@@ -413,23 +449,48 @@ async def aput_writes(
                 "blob": blob,
             }
             writes_objects.append(write_obj)
-
-        upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
-        for write_obj in writes_objects:
-            key = self._make_redis_checkpoint_writes_key(
-                thread_id,
-                checkpoint_ns,
-                checkpoint_id,
-                task_id,
-                write_obj["idx"],
-            )
-            if upsert_case:
-                tx = partial(_write_obj_tx, key=key, write_obj=write_obj)
-                await self._redis.transaction(tx, key)
-            else:
-                # Unlike AsyncRedisSaver, the shallow implementation always overwrites
-                # the full object, so we don't check for key existence here.
-                await self._redis.json().set(key, "$", write_obj)
+
+        # First clean up old writes for this thread and namespace if they're for a different checkpoint_id
+        cleanup_pipeline = self._redis.pipeline()
+
+        # Get all writes keys for this thread/namespace
+        writes_key_pattern = AsyncShallowRedisSaver._make_shallow_redis_checkpoint_writes_key_pattern(
+            thread_id, checkpoint_ns
+        )
+        existing_writes_keys = await self._redis.keys(writes_key_pattern)
+
+        # Process each existing writes key to determine if it should be kept or deleted
+        if existing_writes_keys:
+            for write_key in existing_writes_keys:
+                key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
+                # The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
+                if len(key_parts) >= 5:
+                    key_checkpoint_id = key_parts[3]
+
+                    # If the write is for a different checkpoint_id, delete it
+                    if key_checkpoint_id != checkpoint_id:
+                        cleanup_pipeline.delete(write_key)
+
+            # Execute the cleanup
+            await cleanup_pipeline.execute()
+
+        # Now store the new writes
+        upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
+        for write_obj in writes_objects:
+            key = self._make_redis_checkpoint_writes_key(
+                thread_id,
+                checkpoint_ns,
+                checkpoint_id,
+                task_id,
+                write_obj["idx"],
+            )
+            if upsert_case:
+                tx = partial(_write_obj_tx, key=key, write_obj=write_obj)
+                await self._redis.transaction(tx, key)
+            else:
+                # Unlike AsyncRedisSaver, the shallow implementation always overwrites
+                # the full object, so we don't check for key existence here.
+                await self._redis.json().set(key, "$", write_obj)
 
     async def aget_channel_values(
         self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
@@ -622,4 +683,15 @@ def put_writes(
 
     @staticmethod
     def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str:
+        """Create a key for shallow checkpoints using only thread_id and checkpoint_ns."""
         return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns])
+
+    @staticmethod
+    def _make_shallow_redis_checkpoint_blob_key_pattern(thread_id: str, checkpoint_ns: str) -> str:
+        """Create a pattern to match all blob keys for a thread and namespace."""
+        return REDIS_KEY_SEPARATOR.join([CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns]) + ":*"
+
+    @staticmethod
+    def _make_shallow_redis_checkpoint_writes_key_pattern(thread_id: str, checkpoint_ns: str) -> str:
+        """Create a pattern to match all writes keys for a thread and namespace."""
+        return REDIS_KEY_SEPARATOR.join([CHECKPOINT_WRITE_PREFIX, thread_id, checkpoint_ns]) + ":*"
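Not part of the commit, but a hedged way to observe the effect of the async cleanup above (assumes redis-py's asyncio client and a local Redis at redis://localhost:6379): count the blob keys for a thread with the same pattern the cleanup uses. With the fix, the count stays bounded by the number of channels rather than growing with every checkpoint written to the thread.

import asyncio
from redis.asyncio import Redis  # redis-py's asyncio client

async def count_blobs(thread_id: str, checkpoint_ns: str) -> int:
    """Count blob keys for one thread/namespace using the cleanup's key pattern."""
    r = Redis.from_url("redis://localhost:6379")  # assumed local instance
    try:
        pattern = ":".join(["checkpoint_blob", thread_id, checkpoint_ns]) + ":*"
        keys = await r.keys(pattern)
        return len(keys)
    finally:
        await r.aclose()  # redis-py >= 5; use close() on older versions

print(asyncio.run(count_blobs("thread-1", "")))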

langgraph/checkpoint/redis/shallow.py

Lines changed: 82 additions & 4 deletions
@@ -119,7 +119,7 @@ def put(
         metadata: CheckpointMetadata,
         new_versions: ChannelVersions,
     ) -> RunnableConfig:
-        """Store only the latest checkpoint."""
+        """Store only the latest checkpoint and clean up old blobs."""
        configurable = config["configurable"].copy()
        thread_id = configurable.pop("thread_id")
        checkpoint_ns = configurable.pop("checkpoint_ns")
@@ -146,6 +146,9 @@ def put(
         if all(key in metadata for key in ["source", "step"]):
             checkpoint_data["source"] = metadata["source"]
             checkpoint_data["step"] = metadata["step"]
+
+        # Note: Need to keep track of the current versions to keep
+        current_channel_versions = new_versions.copy()
 
         self.checkpoints_index.load(
             [checkpoint_data],
@@ -155,6 +158,38 @@ def put(
                 )
             ],
         )
+
+        # Before storing the new blobs, clean up old ones that won't be needed
+        # - Get a list of all blob keys for this thread_id and checkpoint_ns
+        # - Then delete the ones that aren't in new_versions
+        cleanup_pipeline = self._redis.json().pipeline(transaction=False)
+
+        # Get all blob keys for this thread/namespace
+        blob_key_pattern = ShallowRedisSaver._make_shallow_redis_checkpoint_blob_key_pattern(
+            thread_id, checkpoint_ns
+        )
+        existing_blob_keys = self._redis.keys(blob_key_pattern)
+
+        # Process each existing blob key to determine if it should be kept or deleted
+        if existing_blob_keys:
+            for blob_key in existing_blob_keys:
+                key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
+                # The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
+                if len(key_parts) >= 5:
+                    channel = key_parts[3]
+                    version = key_parts[4]
+
+                    # Only keep the blob if it's referenced by the current versions
+                    if (channel in current_channel_versions and
+                            current_channel_versions[channel] == version):
+                        # This is a current version, keep it
+                        continue
+                    else:
+                        # This is an old version, delete it
+                        cleanup_pipeline.delete(blob_key)
+
+            # Execute the cleanup
+            cleanup_pipeline.execute()
 
         # Store blob values
         blobs = self._dump_blobs(
@@ -408,7 +443,7 @@ def put_writes(
         task_id: str,
         task_path: str = "",
     ) -> None:
-        """Store intermediate writes linked to a checkpoint.
+        """Store intermediate writes linked to a checkpoint and clean up old writes.
 
         Args:
             config: Configuration of the related checkpoint.
@@ -436,6 +471,30 @@ def put_writes(
                 "blob": blob,
             }
             writes_objects.append(write_obj)
+
+        # First clean up old writes for this thread and namespace if they're for a different checkpoint_id
+        cleanup_pipeline = self._redis.json().pipeline(transaction=False)
+
+        # Get all writes keys for this thread/namespace
+        writes_key_pattern = ShallowRedisSaver._make_shallow_redis_checkpoint_writes_key_pattern(
+            thread_id, checkpoint_ns
+        )
+        existing_writes_keys = self._redis.keys(writes_key_pattern)
+
+        # Process each existing writes key to determine if it should be kept or deleted
+        if existing_writes_keys:
+            for write_key in existing_writes_keys:
+                key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
+                # The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
+                if len(key_parts) >= 5:
+                    key_checkpoint_id = key_parts[3]
+
+                    # If the write is for a different checkpoint_id, delete it
+                    if key_checkpoint_id != checkpoint_id:
+                        cleanup_pipeline.delete(write_key)
+
+            # Execute the cleanup
+            cleanup_pipeline.execute()
 
         # For each write, check existence and then perform appropriate operation
         with self._redis.json().pipeline(transaction=False) as pipeline:
@@ -470,18 +529,25 @@ def _dump_blobs(
         values: dict[str, Any],
         versions: ChannelVersions,
     ) -> List[Tuple[str, dict[str, Any]]]:
+        """Convert blob data for Redis storage.
+
+        In the shallow implementation, we use the version in the key to allow
+        storing multiple versions without conflicts and to facilitate cleanup.
+        """
         if not versions:
             return []
 
         return [
             (
-                ShallowRedisSaver._make_shallow_redis_checkpoint_blob_key(
-                    thread_id, checkpoint_ns, k
+                # Use the base Redis checkpoint blob key to include version, enabling version tracking
+                BaseRedisSaver._make_redis_checkpoint_blob_key(
+                    thread_id, checkpoint_ns, k, ver
                 ),
                 {
                     "thread_id": thread_id,
                     "checkpoint_ns": checkpoint_ns,
                     "channel": k,
+                    "version": ver,  # Include version in the data as well
                     "type": (
                         self._get_type_and_blob(values[k])[0]
                         if k in values
@@ -581,12 +647,24 @@ def _load_pending_sends(
 
     @staticmethod
     def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str:
+        """Create a key for shallow checkpoints using only thread_id and checkpoint_ns."""
         return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns])
 
     @staticmethod
     def _make_shallow_redis_checkpoint_blob_key(
         thread_id: str, checkpoint_ns: str, channel: str
     ) -> str:
+        """Create a key for a blob in a shallow checkpoint."""
         return REDIS_KEY_SEPARATOR.join(
             [CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns, channel]
         )
+
+    @staticmethod
+    def _make_shallow_redis_checkpoint_blob_key_pattern(thread_id: str, checkpoint_ns: str) -> str:
+        """Create a pattern to match all blob keys for a thread and namespace."""
+        return REDIS_KEY_SEPARATOR.join([CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns]) + ":*"
+
+    @staticmethod
+    def _make_shallow_redis_checkpoint_writes_key_pattern(thread_id: str, checkpoint_ns: str) -> str:
+        """Create a pattern to match all writes keys for a thread and namespace."""
+        return REDIS_KEY_SEPARATOR.join([CHECKPOINT_WRITE_PREFIX, thread_id, checkpoint_ns]) + ":*"
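The writes cleanup above uses a simpler rule than the blob cleanup: a write key is stale as soon as its embedded checkpoint_id differs from the checkpoint being stored. A standalone sketch of that rule (the sample keys are hypothetical; the layout matches the checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx format noted in the diff):

REDIS_KEY_SEPARATOR = ":"  # assumed separator

def writes_to_delete(existing_keys, current_checkpoint_id):
    """Return write keys that belong to a superseded checkpoint."""
    stale = []
    for key in existing_keys:
        parts = key.split(REDIS_KEY_SEPARATOR)
        # Key layout: checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
        if len(parts) >= 5 and parts[3] != current_checkpoint_id:
            stale.append(key)
    return stale

keys = [
    "checkpoint_write:thread-1::ckpt-1:task-a:0",
    "checkpoint_write:thread-1::ckpt-2:task-a:0",
]
print(writes_to_delete(keys, "ckpt-2"))
# -> ['checkpoint_write:thread-1::ckpt-1:task-a:0']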
