diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py
index 34538bf..65d72ff 100644
--- a/langgraph/checkpoint/redis/aio.py
+++ b/langgraph/checkpoint/redis/aio.py
@@ -33,6 +33,7 @@
     EMPTY_ID_SENTINEL,
     from_storage_safe_id,
     from_storage_safe_str,
+    safely_decode,
     to_storage_safe_id,
     to_storage_safe_str,
 )
@@ -212,12 +213,14 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
         # Get the blob keys
         blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
         blob_keys = await self._redis.keys(blob_key_pattern)
-        blob_keys = [key.decode() for key in blob_keys]
+        # Use safely_decode to handle both string and bytes responses
+        blob_keys = [safely_decode(key) for key in blob_keys]
 
         # Also get checkpoint write keys that should have the same TTL
         write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
         write_keys = await self._redis.keys(write_key_pattern)
-        write_keys = [key.decode() for key in write_keys]
+        # Use safely_decode to handle both string and bytes responses
+        write_keys = [safely_decode(key) for key in write_keys]
 
         # Apply TTL to checkpoint, blob keys, and write keys
         ttl_minutes = self.ttl_config.get("default_ttl")
@@ -895,9 +898,11 @@ async def _aload_pending_writes(
             None,
         )
         matching_keys = await self._redis.keys(pattern=writes_key)
+        # Use safely_decode to handle both string and bytes responses
+        decoded_keys = [safely_decode(key) for key in matching_keys]
         parsed_keys = [
-            BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
-            for key in matching_keys
+            BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
+            for key in decoded_keys
         ]
         pending_writes = BaseRedisSaver._load_writes(
             self.serde,
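
Note for reviewers: the bare `key.decode()` calls removed above only work while redis-py returns `bytes`. With a client constructed using `decode_responses=True`, `keys()` returns `str`, and `.decode()` raises. A minimal standalone sketch of the failure mode (not part of the patch; the key value is illustrative):

```python
# With decode_responses=True, redis-py returns str, and str has no .decode().
key = "checkpoint:thread-1:ns"  # illustrative value of what keys() yields
try:
    key.decode()  # AttributeError: 'str' object has no attribute 'decode'
except AttributeError as exc:
    print(exc)
```
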
diff --git a/langgraph/checkpoint/redis/ashallow.py b/langgraph/checkpoint/redis/ashallow.py
index da99885..58548c5 100644
--- a/langgraph/checkpoint/redis/ashallow.py
+++ b/langgraph/checkpoint/redis/ashallow.py
@@ -34,6 +34,7 @@
     REDIS_KEY_SEPARATOR,
     BaseRedisSaver,
 )
+from langgraph.checkpoint.redis.util import safely_decode
 
 SCHEMAS = [
     {
@@ -252,7 +253,9 @@ async def aput(
         # Process each existing blob key to determine if it should be kept or deleted
         if existing_blob_keys:
             for blob_key in existing_blob_keys:
-                key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
+                # Use safely_decode to handle both string and bytes responses
+                decoded_key = safely_decode(blob_key)
+                key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
                 # The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
                 if len(key_parts) >= 5:
                     channel = key_parts[3]
@@ -428,7 +431,8 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
             )
         )
         blob_keys = await self._redis.keys(blob_key_pattern)
-        blob_keys = [key.decode() for key in blob_keys]
+        # Use safely_decode to handle both string and bytes responses
+        blob_keys = [safely_decode(key) for key in blob_keys]
 
         # Apply TTL
         ttl_minutes = self.ttl_config.get("default_ttl")
@@ -554,7 +558,9 @@ async def aput_writes(
         # Process each existing writes key to determine if it should be kept or deleted
         if existing_writes_keys:
             for write_key in existing_writes_keys:
-                key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
+                # Use safely_decode to handle both string and bytes responses
+                decoded_key = safely_decode(write_key)
+                key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
                 # The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
                 if len(key_parts) >= 5:
                     key_checkpoint_id = key_parts[3]
@@ -700,9 +706,11 @@ async def _aload_pending_writes(
             thread_id, checkpoint_ns, checkpoint_id, "*", None
         )
         matching_keys = await self._redis.keys(pattern=writes_key)
+        # Use safely_decode to handle both string and bytes responses
+        decoded_keys = [safely_decode(key) for key in matching_keys]
         parsed_keys = [
-            BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
-            for key in matching_keys
+            BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
+            for key in decoded_keys
         ]
         pending_writes = BaseRedisSaver._load_writes(
             self.serde,
diff --git a/langgraph/checkpoint/redis/base.py b/langgraph/checkpoint/redis/base.py
index f38f549..f227a33 100644
--- a/langgraph/checkpoint/redis/base.py
+++ b/langgraph/checkpoint/redis/base.py
@@ -18,6 +18,7 @@
 from langgraph.checkpoint.serde.types import ChannelProtocol
 
 from langgraph.checkpoint.redis.util import (
+    safely_decode,
     to_storage_safe_id,
     to_storage_safe_str,
 )
@@ -509,9 +510,12 @@ def _load_pending_writes(
         # Cast the result to List[bytes] to help type checker
         matching_keys: List[bytes] = self._redis.keys(pattern=writes_key)  # type: ignore[assignment]
 
+        # Use safely_decode to handle both string and bytes responses
+        decoded_keys = [safely_decode(key) for key in matching_keys]
+
         parsed_keys = [
-            BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
-            for key in matching_keys
+            BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
+            for key in decoded_keys
         ]
         pending_writes = BaseRedisSaver._load_writes(
             self.serde,
@@ -541,6 +545,9 @@ def _load_writes(
 
     @staticmethod
     def _parse_redis_checkpoint_writes_key(redis_key: str) -> dict:
+        # Ensure redis_key is a string
+        redis_key = safely_decode(redis_key)
+
         parts = redis_key.split(REDIS_KEY_SEPARATOR)
         # Ensure we have at least 6 parts
         if len(parts) < 6:
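
The defensive call added to `_parse_redis_checkpoint_writes_key` makes the parser indifferent to the client's `decode_responses` setting. A hedged sketch of the equivalent normalize-then-split logic (standalone; the helper name `parse_parts` is hypothetical, only the key layout comes from the comments above):

```python
# Hypothetical stand-in for the parser's first step: normalize, then split.
REDIS_KEY_SEPARATOR = ":"

def parse_parts(redis_key) -> list:
    # bytes arrive when decode_responses=False; str when it is True
    if isinstance(redis_key, bytes):
        redis_key = redis_key.decode("utf-8")
    return redis_key.split(REDIS_KEY_SEPARATOR)

assert parse_parts(b"checkpoint_write:t1:ns:c1:task:0") == parse_parts(
    "checkpoint_write:t1:ns:c1:task:0"
)
```
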
diff --git a/langgraph/checkpoint/redis/shallow.py b/langgraph/checkpoint/redis/shallow.py
index 93fc4c6..5190729 100644
--- a/langgraph/checkpoint/redis/shallow.py
+++ b/langgraph/checkpoint/redis/shallow.py
@@ -26,6 +26,7 @@
     REDIS_KEY_SEPARATOR,
     BaseRedisSaver,
 )
+from langgraph.checkpoint.redis.util import safely_decode
 
 SCHEMAS = [
     {
@@ -179,7 +180,9 @@ def put(
         # Process each existing blob key to determine if it should be kept or deleted
         if existing_blob_keys:
             for blob_key in existing_blob_keys:
-                key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
+                # Use safely_decode to handle both string and bytes responses
+                decoded_key = safely_decode(blob_key)
+                key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
                 # The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
                 if len(key_parts) >= 5:
                     channel = key_parts[3]
@@ -387,7 +390,10 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
                 thread_id, checkpoint_ns
             )
         )
-        blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)]
+        # Use safely_decode to handle both string and bytes responses
+        blob_keys = [
+            safely_decode(key) for key in self._redis.keys(blob_key_pattern)
+        ]
 
         # Apply TTL
         self._apply_ttl_to_keys(checkpoint_key, blob_keys)
@@ -524,7 +530,9 @@ def put_writes(
         # Process each existing writes key to determine if it should be kept or deleted
         if existing_writes_keys:
            for write_key in existing_writes_keys:
-                key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
+                # Use safely_decode to handle both string and bytes responses
+                decoded_key = safely_decode(write_key)
+                key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
                 # The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
                 if len(key_parts) >= 5:
                     key_checkpoint_id = key_parts[3]
+ """ + if obj is None: + return None + elif isinstance(obj, bytes): + try: + return obj.decode("utf-8") + except UnicodeDecodeError: + # If decoding fails, return the original bytes + return obj + elif isinstance(obj, str): + return obj + elif isinstance(obj, dict): + return {safely_decode(k): safely_decode(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [safely_decode(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(safely_decode(item) for item in obj) + elif isinstance(obj, set): + return {safely_decode(item) for item in obj} + else: + # For other types (int, float, bool, etc.), return as is + return obj diff --git a/tests/test_decode_responses.py b/tests/test_decode_responses.py new file mode 100644 index 0000000..92e3c0b --- /dev/null +++ b/tests/test_decode_responses.py @@ -0,0 +1,148 @@ +"""Tests for Redis key decoding functionality.""" + +import os +import time +import uuid +from typing import Any, Dict, Optional + +import pytest +from redis import Redis + +from langgraph.checkpoint.redis.util import safely_decode + + +def test_safely_decode_basic_types(): + """Test safely_decode function with basic type inputs.""" + # Test with bytes + assert safely_decode(b"test") == "test" + + # Test with string + assert safely_decode("test") == "test" + + # Test with None + assert safely_decode(None) is None + + # Test with other types + assert safely_decode(123) == 123 + assert safely_decode(1.23) == 1.23 + assert safely_decode(True) is True + + +def test_safely_decode_nested_structures(): + """Test safely_decode function with nested data structures.""" + # Test with dictionary + assert safely_decode({b"key": b"value"}) == {"key": "value"} + assert safely_decode({b"key1": b"value1", "key2": 123}) == { + "key1": "value1", + "key2": 123, + } + + # Test with nested dictionary + nested_dict = {b"outer": {b"inner": b"value"}} + assert safely_decode(nested_dict) == {"outer": {"inner": "value"}} + + # Test with list + assert safely_decode([b"item1", b"item2"]) == ["item1", "item2"] + + # Test with tuple + assert safely_decode((b"item1", b"item2")) == ("item1", "item2") + + # Test with set + decoded_set = safely_decode({b"item1", b"item2"}) + assert isinstance(decoded_set, set) + assert "item1" in decoded_set + assert "item2" in decoded_set + + # Test with complex nested structure + complex_struct = { + b"key1": [b"list_item1", {b"nested_key": b"nested_value"}], + b"key2": (b"tuple_item", 123), + b"key3": {b"set_item1", b"set_item2"}, + } + decoded = safely_decode(complex_struct) + assert decoded["key1"][0] == "list_item1" + assert decoded["key1"][1]["nested_key"] == "nested_value" + assert decoded["key2"][0] == "tuple_item" + assert decoded["key2"][1] == 123 + assert isinstance(decoded["key3"], set) + assert "set_item1" in decoded["key3"] + assert "set_item2" in decoded["key3"] + + +@pytest.mark.parametrize("decode_responses", [True, False]) +def test_safely_decode_with_redis(decode_responses: bool, redis_url): + """Test safely_decode function with actual Redis responses using TestContainers.""" + r = Redis.from_url(redis_url, decode_responses=decode_responses) + + try: + # Clean up before test to ensure a clean state + r.delete("test:string") + r.delete("test:hash") + r.delete("test:list") + r.delete("test:set") + + # Set up test data + r.set("test:string", "value") + r.hset("test:hash", mapping={"field1": "value1", "field2": "value2"}) + r.rpush("test:list", "item1", "item2", "item3") + r.sadd("test:set", "member1", "member2") + + # Test string value + 
diff --git a/tests/test_decode_responses.py b/tests/test_decode_responses.py
new file mode 100644
index 0000000..92e3c0b
--- /dev/null
+++ b/tests/test_decode_responses.py
@@ -0,0 +1,148 @@
+"""Tests for Redis key decoding functionality."""
+
+import os
+import time
+import uuid
+from typing import Any, Dict, Optional
+
+import pytest
+from redis import Redis
+
+from langgraph.checkpoint.redis.util import safely_decode
+
+
+def test_safely_decode_basic_types():
+    """Test safely_decode function with basic type inputs."""
+    # Test with bytes
+    assert safely_decode(b"test") == "test"
+
+    # Test with string
+    assert safely_decode("test") == "test"
+
+    # Test with None
+    assert safely_decode(None) is None
+
+    # Test with other types
+    assert safely_decode(123) == 123
+    assert safely_decode(1.23) == 1.23
+    assert safely_decode(True) is True
+
+
+def test_safely_decode_nested_structures():
+    """Test safely_decode function with nested data structures."""
+    # Test with dictionary
+    assert safely_decode({b"key": b"value"}) == {"key": "value"}
+    assert safely_decode({b"key1": b"value1", "key2": 123}) == {
+        "key1": "value1",
+        "key2": 123,
+    }
+
+    # Test with nested dictionary
+    nested_dict = {b"outer": {b"inner": b"value"}}
+    assert safely_decode(nested_dict) == {"outer": {"inner": "value"}}
+
+    # Test with list
+    assert safely_decode([b"item1", b"item2"]) == ["item1", "item2"]
+
+    # Test with tuple
+    assert safely_decode((b"item1", b"item2")) == ("item1", "item2")
+
+    # Test with set
+    decoded_set = safely_decode({b"item1", b"item2"})
+    assert isinstance(decoded_set, set)
+    assert "item1" in decoded_set
+    assert "item2" in decoded_set
+
+    # Test with complex nested structure
+    complex_struct = {
+        b"key1": [b"list_item1", {b"nested_key": b"nested_value"}],
+        b"key2": (b"tuple_item", 123),
+        b"key3": {b"set_item1", b"set_item2"},
+    }
+    decoded = safely_decode(complex_struct)
+    assert decoded["key1"][0] == "list_item1"
+    assert decoded["key1"][1]["nested_key"] == "nested_value"
+    assert decoded["key2"][0] == "tuple_item"
+    assert decoded["key2"][1] == 123
+    assert isinstance(decoded["key3"], set)
+    assert "set_item1" in decoded["key3"]
+    assert "set_item2" in decoded["key3"]
+
+
+@pytest.mark.parametrize("decode_responses", [True, False])
+def test_safely_decode_with_redis(decode_responses: bool, redis_url):
+    """Test safely_decode function with actual Redis responses using TestContainers."""
+    r = Redis.from_url(redis_url, decode_responses=decode_responses)
+
+    try:
+        # Clean up before test to ensure a clean state
+        r.delete("test:string")
+        r.delete("test:hash")
+        r.delete("test:list")
+        r.delete("test:set")
+
+        # Set up test data
+        r.set("test:string", "value")
+        r.hset("test:hash", mapping={"field1": "value1", "field2": "value2"})
+        r.rpush("test:list", "item1", "item2", "item3")
+        r.sadd("test:set", "member1", "member2")
+
+        # Test string value
+        string_val = r.get("test:string")
+        decoded_string = safely_decode(string_val)
+        assert decoded_string == "value"
+
+        # Test hash value
+        hash_val = r.hgetall("test:hash")
+        decoded_hash = safely_decode(hash_val)
+        assert decoded_hash == {"field1": "value1", "field2": "value2"}
+
+        # Test list value
+        list_val = r.lrange("test:list", 0, -1)
+        decoded_list = safely_decode(list_val)
+        assert decoded_list == ["item1", "item2", "item3"]
+
+        # Test set value
+        set_val = r.smembers("test:set")
+        decoded_set = safely_decode(set_val)
+        assert isinstance(decoded_set, set)
+        assert "member1" in decoded_set
+        assert "member2" in decoded_set
+
+        # Test key fetching
+        keys = r.keys("test:*")
+        decoded_keys = safely_decode(keys)
+        assert sorted(decoded_keys) == sorted(
+            ["test:string", "test:hash", "test:list", "test:set"]
+        )
+
+    finally:
+        # Clean up after test
+        r.delete("test:string")
+        r.delete("test:hash")
+        r.delete("test:list")
+        r.delete("test:set")
+        r.close()
+
+
+def test_safely_decode_unicode_error_handling():
+    """Test safely_decode function with invalid UTF-8 bytes."""
+    # Create bytes that will cause UnicodeDecodeError
+    invalid_utf8 = b"\xff\xfe\xfd"
+
+    # Should return the original bytes if it can't be decoded
+    result = safely_decode(invalid_utf8)
+    assert result == invalid_utf8
+
+    # Test with mixed valid and invalid in a complex structure
+    mixed = {
+        b"valid": b"This is valid UTF-8",
+        b"invalid": invalid_utf8,
+        b"nested": [b"valid", invalid_utf8],
+    }
+
+    result = safely_decode(mixed)
+    assert result["valid"] == "This is valid UTF-8"
+    assert result["invalid"] == invalid_utf8
+    assert result["nested"][0] == "valid"
+    assert result["nested"][1] == invalid_utf8
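
Finally, a hedged end-to-end sketch of the scenario this patch addresses: the same KEYS call, issued through clients with opposite `decode_responses` settings, yields identical results once passed through `safely_decode`. This assumes a reachable Redis at `redis://localhost:6379` and no other keys matching `checkpoint_write:*`; neither assumption is part of the patch.

```python
# Sanity check: safely_decode normalizes responses from both client modes.
from redis import Redis
from langgraph.checkpoint.redis.util import safely_decode

raw = Redis.from_url("redis://localhost:6379")  # returns bytes
decoded = Redis.from_url("redis://localhost:6379", decode_responses=True)  # returns str

raw.set("checkpoint_write:t1:ns:c1:task:0", "x")
assert sorted(safely_decode(raw.keys("checkpoint_write:*"))) == sorted(
    safely_decode(decoded.keys("checkpoint_write:*"))
)
```
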