Skip to content

Commit 1c051ec

Browse files
committed
fix: cleanup blobs and writes for shallow classes
1 parent ea9391d commit 1c051ec

File tree

4 files changed

+70
-33
lines changed

4 files changed

+70
-33
lines changed

langgraph/checkpoint/redis/ashallow.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import json
77
import os
88
from contextlib import asynccontextmanager
9-
from functools import partial
109
from types import TracebackType
1110
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast
1211

@@ -25,7 +24,6 @@
2524
from redisvl.index import AsyncSearchIndex
2625
from redisvl.query import FilterQuery
2726
from redisvl.query.filter import Num, Tag
28-
from redisvl.redis.connection import RedisConnectionFactory
2927

3028
from langgraph.checkpoint.redis.base import (
3129
CHECKPOINT_BLOB_PREFIX,
@@ -34,6 +32,10 @@
3432
REDIS_KEY_SEPARATOR,
3533
BaseRedisSaver,
3634
)
35+
from langgraph.checkpoint.redis.util import (
36+
to_storage_safe_id,
37+
to_storage_safe_str,
38+
)
3739

3840
SCHEMAS = [
3941
{
@@ -794,26 +796,38 @@ def put_writes(
794796
@staticmethod
795797
def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str:
796798
"""Create a key for shallow checkpoints using only thread_id and checkpoint_ns."""
797-
return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns])
799+
return REDIS_KEY_SEPARATOR.join(
800+
[
801+
CHECKPOINT_PREFIX,
802+
str(to_storage_safe_id(thread_id)),
803+
to_storage_safe_str(checkpoint_ns),
804+
]
805+
)
798806

799807
@staticmethod
800808
def _make_shallow_redis_checkpoint_blob_key_pattern(
801809
thread_id: str, checkpoint_ns: str
802810
) -> str:
803811
"""Create a pattern to match all blob keys for a thread and namespace."""
804-
return (
805-
REDIS_KEY_SEPARATOR.join([CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns])
806-
+ ":*"
812+
return REDIS_KEY_SEPARATOR.join(
813+
[
814+
CHECKPOINT_BLOB_PREFIX,
815+
str(to_storage_safe_id(thread_id)),
816+
to_storage_safe_str(checkpoint_ns),
817+
"*",
818+
]
807819
)
808820

809821
@staticmethod
810822
def _make_shallow_redis_checkpoint_writes_key_pattern(
811823
thread_id: str, checkpoint_ns: str
812824
) -> str:
813825
"""Create a pattern to match all writes keys for a thread and namespace."""
814-
return (
815-
REDIS_KEY_SEPARATOR.join(
816-
[CHECKPOINT_WRITE_PREFIX, thread_id, checkpoint_ns]
817-
)
818-
+ ":*"
826+
return REDIS_KEY_SEPARATOR.join(
827+
[
828+
CHECKPOINT_WRITE_PREFIX,
829+
str(to_storage_safe_id(thread_id)),
830+
to_storage_safe_str(checkpoint_ns),
831+
"*",
832+
]
819833
)

langgraph/checkpoint/redis/shallow.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
REDIS_KEY_SEPARATOR,
2727
BaseRedisSaver,
2828
)
29+
from langgraph.checkpoint.redis.util import (
30+
to_storage_safe_id,
31+
to_storage_safe_str,
32+
)
2933

3034
SCHEMAS = [
3135
{
@@ -688,35 +692,38 @@ def _load_pending_sends(
688692
@staticmethod
689693
def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str:
690694
"""Create a key for shallow checkpoints using only thread_id and checkpoint_ns."""
691-
return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns])
692-
693-
@staticmethod
694-
def _make_shallow_redis_checkpoint_blob_key(
695-
thread_id: str, checkpoint_ns: str, channel: str
696-
) -> str:
697-
"""Create a key for a blob in a shallow checkpoint."""
698695
return REDIS_KEY_SEPARATOR.join(
699-
[CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns, channel]
696+
[
697+
CHECKPOINT_PREFIX,
698+
str(to_storage_safe_id(thread_id)),
699+
to_storage_safe_str(checkpoint_ns),
700+
]
700701
)
701702

702703
@staticmethod
703704
def _make_shallow_redis_checkpoint_blob_key_pattern(
704705
thread_id: str, checkpoint_ns: str
705706
) -> str:
706707
"""Create a pattern to match all blob keys for a thread and namespace."""
707-
return (
708-
REDIS_KEY_SEPARATOR.join([CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns])
709-
+ ":*"
708+
return REDIS_KEY_SEPARATOR.join(
709+
[
710+
CHECKPOINT_BLOB_PREFIX,
711+
str(to_storage_safe_id(thread_id)),
712+
to_storage_safe_str(checkpoint_ns),
713+
"*",
714+
]
710715
)
711716

712717
@staticmethod
713718
def _make_shallow_redis_checkpoint_writes_key_pattern(
714719
thread_id: str, checkpoint_ns: str
715720
) -> str:
716721
"""Create a pattern to match all writes keys for a thread and namespace."""
717-
return (
718-
REDIS_KEY_SEPARATOR.join(
719-
[CHECKPOINT_WRITE_PREFIX, thread_id, checkpoint_ns]
720-
)
721-
+ ":*"
722+
return REDIS_KEY_SEPARATOR.join(
723+
[
724+
CHECKPOINT_WRITE_PREFIX,
725+
str(to_storage_safe_id(thread_id)),
726+
to_storage_safe_str(checkpoint_ns),
727+
"*",
728+
]
722729
)

tests/test_shallow_async.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from redis.exceptions import ConnectionError as RedisConnectionError
1313

1414
from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver
15+
from langgraph.checkpoint.redis.base import CHECKPOINT_BLOB_PREFIX
1516

1617

1718
@pytest.fixture
@@ -96,7 +97,10 @@ async def test_only_latest_checkpoint(
9697
}
9798
)
9899
checkpoint_1 = test_data["checkpoints"][0]
99-
await saver.aput(config_1, checkpoint_1, test_data["metadata"][0], {})
100+
channel_versions_1 = {"test_channel": "1"}
101+
await saver.aput(
102+
config_1, checkpoint_1, test_data["metadata"][0], channel_versions_1
103+
)
100104

101105
# Create second checkpoint
102106
config_2 = RunnableConfig(
@@ -108,13 +112,19 @@ async def test_only_latest_checkpoint(
108112
}
109113
)
110114
checkpoint_2 = test_data["checkpoints"][1]
111-
await saver.aput(config_2, checkpoint_2, test_data["metadata"][1], {})
115+
channel_versions_2 = {"test_channel": "2"}
116+
await saver.aput(
117+
config_2, checkpoint_2, test_data["metadata"][1], channel_versions_2
118+
)
112119

113-
# Verify only latest checkpoint exists
120+
# Verify only latest checkpoint and blobs exists
114121
results = [c async for c in saver.alist(None)]
115122
assert len(results) == 1
116123
assert results[0].config["configurable"]["checkpoint_id"] == checkpoint_2["id"]
117124

125+
blobs = list(await saver._redis.keys(CHECKPOINT_BLOB_PREFIX + ":*"))
126+
assert len(blobs) == 1
127+
118128

119129
@pytest.mark.asyncio
120130
@pytest.mark.parametrize(

tests/test_shallow_sync.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from redis import Redis
1313
from redis.exceptions import ConnectionError as RedisConnectionError
1414

15+
from langgraph.checkpoint.redis.base import CHECKPOINT_BLOB_PREFIX
1516
from langgraph.checkpoint.redis.shallow import ShallowRedisSaver
1617

1718

@@ -102,7 +103,8 @@ def test_only_latest_checkpoint(
102103
}
103104
}
104105
checkpoint_1 = test_data["checkpoints"][0]
105-
saver.put(config_1, checkpoint_1, test_data["metadata"][0], {})
106+
channel_versions_1 = {"test_channel": "1"}
107+
saver.put(config_1, checkpoint_1, test_data["metadata"][0], channel_versions_1)
106108

107109
# Create second checkpoint
108110
config_2 = {
@@ -112,13 +114,17 @@ def test_only_latest_checkpoint(
112114
}
113115
}
114116
checkpoint_2 = test_data["checkpoints"][1]
115-
saver.put(config_2, checkpoint_2, test_data["metadata"][1], {})
117+
channel_versions_2 = {"test_channel": "2"}
118+
saver.put(config_2, checkpoint_2, test_data["metadata"][1], channel_versions_2)
116119

117-
# Verify only latest checkpoint exists
120+
# Verify only latest checkpoint and blobs exists
118121
results = list(saver.list(None))
119122
assert len(results) == 1
120123
assert results[0].config["configurable"]["checkpoint_id"] == checkpoint_2["id"]
121124

125+
blobs = list(saver._redis.keys(CHECKPOINT_BLOB_PREFIX + ":*"))
126+
assert len(blobs) == 1
127+
122128

123129
@pytest.mark.parametrize(
124130
"query, expected_count",

0 commit comments

Comments
 (0)