Skip to content

Commit 5944791

Browse files
committed
fix: handle both bytes and string Redis keys when decode_responses=True
Resolves issue #24 where Redis key decoding would fail when a client was configured with decode_responses=True. Adds a new `safely_decode` utility function that handles both bytes and string keys.
1 parent 43f2b6b commit 5944791

File tree

7 files changed

+164
-18
lines changed

7 files changed

+164
-18
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
EMPTY_ID_SENTINEL,
2828
from_storage_safe_id,
2929
from_storage_safe_str,
30+
safely_decode,
3031
to_storage_safe_id,
3132
to_storage_safe_str,
3233
)
@@ -518,4 +519,5 @@ def _load_pending_sends(
518519
"BaseRedisSaver",
519520
"ShallowRedisSaver",
520521
"AsyncShallowRedisSaver",
522+
"safely_decode",
521523
]

langgraph/checkpoint/redis/aio.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
EMPTY_ID_SENTINEL,
3434
from_storage_safe_id,
3535
from_storage_safe_str,
36+
safely_decode,
3637
to_storage_safe_id,
3738
to_storage_safe_str,
3839
)
@@ -823,21 +824,28 @@ async def _aload_pending_writes(
823824
"*",
824825
None,
825826
)
827+
# The result from self._redis.keys() can vary based on client implementation
826828
matching_keys = await self._redis.keys(pattern=writes_key)
829+
827830
parsed_keys = [
828-
BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
831+
BaseRedisSaver._parse_redis_checkpoint_writes_key(safely_decode(key))
829832
for key in matching_keys
830833
]
834+
# Create key-parsed_key pairs and sort them
835+
pairs = [
836+
(key, parsed_key) for key, parsed_key in zip(matching_keys, parsed_keys)
837+
]
838+
sorted_pairs = sorted(pairs, key=lambda x: x[1]["idx"])
839+
840+
# Build the dictionary with the sorted pairs
831841
pending_writes = BaseRedisSaver._load_writes(
832842
self.serde,
833843
{
834844
(
835845
parsed_key["task_id"],
836846
parsed_key["idx"],
837847
): await self._redis.json().get(key)
838-
for key, parsed_key in sorted(
839-
zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
840-
)
848+
for key, parsed_key in sorted_pairs
841849
},
842850
)
843851
return pending_writes

langgraph/checkpoint/redis/ashallow.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
REDIS_KEY_SEPARATOR,
3535
BaseRedisSaver,
3636
)
37+
from langgraph.checkpoint.redis.util import safely_decode
3738

3839
SCHEMAS = [
3940
{
@@ -246,7 +247,7 @@ async def aput(
246247
# Process each existing blob key to determine if it should be kept or deleted
247248
if existing_blob_keys:
248249
for blob_key in existing_blob_keys:
249-
key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
250+
key_parts = safely_decode(blob_key).split(REDIS_KEY_SEPARATOR)
250251
# The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
251252
if len(key_parts) >= 5:
252253
channel = key_parts[3]
@@ -503,7 +504,7 @@ async def aput_writes(
503504
# Process each existing writes key to determine if it should be kept or deleted
504505
if existing_writes_keys:
505506
for write_key in existing_writes_keys:
506-
key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
507+
key_parts = safely_decode(write_key).split(REDIS_KEY_SEPARATOR)
507508
# The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
508509
if len(key_parts) >= 5:
509510
key_checkpoint_id = key_parts[3]
@@ -648,21 +649,28 @@ async def _aload_pending_writes(
648649
writes_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
649650
thread_id, checkpoint_ns, checkpoint_id, "*", None
650651
)
652+
# The result from self._redis.keys() can vary based on client implementation
651653
matching_keys = await self._redis.keys(pattern=writes_key)
654+
652655
parsed_keys = [
653-
BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
656+
BaseRedisSaver._parse_redis_checkpoint_writes_key(safely_decode(key))
654657
for key in matching_keys
655658
]
659+
# Create key-parsed_key pairs and sort them
660+
pairs = [
661+
(key, parsed_key) for key, parsed_key in zip(matching_keys, parsed_keys)
662+
]
663+
sorted_pairs = sorted(pairs, key=lambda x: x[1]["idx"])
664+
665+
# Build the dictionary with the sorted pairs
656666
pending_writes = BaseRedisSaver._load_writes(
657667
self.serde,
658668
{
659669
(
660670
parsed_key["task_id"],
661671
parsed_key["idx"],
662672
): await self._redis.json().get(key)
663-
for key, parsed_key in sorted(
664-
zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
665-
)
673+
for key, parsed_key in sorted_pairs
666674
},
667675
)
668676
return pending_writes

langgraph/checkpoint/redis/base.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from langgraph.checkpoint.serde.types import ChannelProtocol
1919

2020
from langgraph.checkpoint.redis.util import (
21+
safely_decode,
2122
to_storage_safe_id,
2223
to_storage_safe_str,
2324
)
@@ -479,20 +480,24 @@ def _load_pending_writes(
479480
None,
480481
)
481482

482-
# Cast the result to List[bytes] to help type checker
483-
matching_keys: List[bytes] = self._redis.keys(pattern=writes_key) # type: ignore[assignment]
483+
# The result from self._redis.keys() can vary based on client implementation
484+
matching_keys = self._redis.keys(pattern=writes_key)
484485

485486
parsed_keys = [
486-
BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
487+
BaseRedisSaver._parse_redis_checkpoint_writes_key(safely_decode(key))
487488
for key in matching_keys
488489
]
490+
# Create key-parsed_key pairs and sort them
491+
# Using type ignore because Redis client implementations can vary
492+
pairs = [(key, parsed_key) for key, parsed_key in zip(matching_keys, parsed_keys)] # type: ignore
493+
sorted_pairs = sorted(pairs, key=lambda x: x[1]["idx"])
494+
495+
# Build the dictionary with the sorted pairs
489496
pending_writes = BaseRedisSaver._load_writes(
490497
self.serde,
491498
{
492499
(parsed_key["task_id"], parsed_key["idx"]): self._redis.json().get(key)
493-
for key, parsed_key in sorted(
494-
zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
495-
)
500+
for key, parsed_key in sorted_pairs
496501
},
497502
)
498503
return pending_writes

langgraph/checkpoint/redis/shallow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
REDIS_KEY_SEPARATOR,
2727
BaseRedisSaver,
2828
)
29+
from langgraph.checkpoint.redis.util import safely_decode
2930

3031
SCHEMAS = [
3132
{
@@ -175,7 +176,7 @@ def put(
175176
# Process each existing blob key to determine if it should be kept or deleted
176177
if existing_blob_keys:
177178
for blob_key in existing_blob_keys:
178-
key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
179+
key_parts = safely_decode(blob_key).split(REDIS_KEY_SEPARATOR)
179180
# The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
180181
if len(key_parts) >= 5:
181182
channel = key_parts[3]
@@ -490,7 +491,7 @@ def put_writes(
490491
# Process each existing writes key to determine if it should be kept or deleted
491492
if existing_writes_keys:
492493
for write_key in existing_writes_keys:
493-
key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
494+
key_parts = safely_decode(write_key).split(REDIS_KEY_SEPARATOR)
494495
# The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
495496
if len(key_parts) >= 5:
496497
key_checkpoint_id = key_parts[3]

langgraph/checkpoint/redis/util.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
generally be correct.
88
"""
99

10+
from typing import Union
11+
1012
EMPTY_STRING_SENTINEL = "__empty__"
1113
EMPTY_ID_SENTINEL = "00000000-0000-0000-0000-000000000000"
1214

@@ -81,3 +83,22 @@ def from_storage_safe_id(value: str) -> str:
8183
return ""
8284
else:
8385
return value
86+
87+
88+
def safely_decode(key: Union[bytes, str]) -> str:
89+
"""
90+
Safely decode a Redis key regardless of whether it's bytes or string.
91+
92+
This function handles both cases:
93+
- When Redis client is configured with decode_responses=False (returns bytes)
94+
- When Redis client is configured with decode_responses=True (returns strings)
95+
96+
Args:
97+
key: The Redis key, either bytes or string
98+
99+
Returns:
100+
The decoded key as a string
101+
"""
102+
if isinstance(key, bytes):
103+
return key.decode()
104+
return key

tests/test_decode_responses.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Test Redis key decoding functionality to ensure it works with both decode_responses=True and False."""
2+
3+
import pytest
4+
from redis import Redis
5+
from redisvl.redis.connection import RedisConnectionFactory
6+
7+
from langgraph.checkpoint.redis.base import BaseRedisSaver
8+
from langgraph.checkpoint.redis.util import safely_decode
9+
10+
11+
def test_safely_decode():
12+
"""Test the safely_decode function with both bytes and strings."""
13+
# Test with bytes
14+
assert safely_decode(b"test_key") == "test_key"
15+
16+
# Test with string
17+
assert safely_decode("test_key") == "test_key"
18+
19+
20+
@pytest.fixture
21+
def redis_client_decoded():
22+
"""Redis client with decode_responses=True."""
23+
client = Redis.from_url("redis://localhost:6379", decode_responses=True)
24+
yield client
25+
client.close()
26+
27+
28+
@pytest.fixture
29+
def redis_client_bytes():
30+
"""Redis client with decode_responses=False (default)."""
31+
client = Redis.from_url("redis://localhost:6379", decode_responses=False)
32+
yield client
33+
client.close()
34+
35+
36+
def test_redis_keys_with_decode_responses(redis_client_decoded, redis_client_bytes):
37+
"""Test that redis.keys() behaves as expected with different decode_responses settings."""
38+
# Generate a unique key prefix for this test
39+
test_key_prefix = "test_decode_responses_"
40+
41+
# Create some test keys
42+
for i in range(3):
43+
key = f"{test_key_prefix}{i}"
44+
redis_client_bytes.set(key, f"value{i}")
45+
46+
try:
47+
# Test with decode_responses=False (returns bytes)
48+
keys_bytes = redis_client_bytes.keys(f"{test_key_prefix}*")
49+
assert all(isinstance(k, bytes) for k in keys_bytes)
50+
51+
# Test with decode_responses=True (returns strings)
52+
keys_str = redis_client_decoded.keys(f"{test_key_prefix}*")
53+
assert all(isinstance(k, str) for k in keys_str)
54+
55+
# Test that our safely_decode function works with both
56+
decoded_bytes = [safely_decode(k) for k in keys_bytes]
57+
decoded_str = [safely_decode(k) for k in keys_str]
58+
59+
# Both should now be lists of strings
60+
assert all(isinstance(k, str) for k in decoded_bytes)
61+
assert all(isinstance(k, str) for k in decoded_str)
62+
63+
# Both should contain the same keys
64+
assert sorted(decoded_bytes) == sorted(decoded_str)
65+
66+
finally:
67+
# Clean up
68+
for i in range(3):
69+
redis_client_bytes.delete(f"{test_key_prefix}{i}")
70+
71+
72+
def test_parse_redis_key_with_different_clients(
73+
redis_client_decoded, redis_client_bytes
74+
):
75+
"""Test that our _parse_redis_checkpoint_writes_key method works correctly."""
76+
# Create a test key using the format expected by the parser
77+
from langgraph.checkpoint.redis.base import (
78+
CHECKPOINT_WRITE_PREFIX,
79+
REDIS_KEY_SEPARATOR,
80+
)
81+
82+
test_key = f"{CHECKPOINT_WRITE_PREFIX}{REDIS_KEY_SEPARATOR}thread1{REDIS_KEY_SEPARATOR}ns1{REDIS_KEY_SEPARATOR}cp1{REDIS_KEY_SEPARATOR}task1{REDIS_KEY_SEPARATOR}0"
83+
84+
# Test parsing with bytes key (as would come from decode_responses=False)
85+
bytes_key = test_key.encode()
86+
parsed_bytes = BaseRedisSaver._parse_redis_checkpoint_writes_key(
87+
safely_decode(bytes_key)
88+
)
89+
90+
# Test parsing with string key (as would come from decode_responses=True)
91+
parsed_str = BaseRedisSaver._parse_redis_checkpoint_writes_key(
92+
safely_decode(test_key)
93+
)
94+
95+
# Both should produce the same result
96+
assert parsed_bytes == parsed_str
97+
assert parsed_bytes["thread_id"] == "thread1"
98+
assert parsed_bytes["checkpoint_ns"] == "ns1"
99+
assert parsed_bytes["checkpoint_id"] == "cp1"
100+
assert parsed_bytes["task_id"] == "task1"
101+
assert parsed_bytes["idx"] == "0"

0 commit comments

Comments
 (0)