Skip to content

Commit d21047d

Browse files
committed
fix(checkpoint): handle bytes in nested structures with msgpack fallback
- Add bytes detection in _default_handler to trigger msgpack serialization
- Update dumps_typed to fall back to parent's msgpack for bytes in dicts
- Update _dump_checkpoint to handle both JSON and msgpack types
- Add __bytes__ marker for bytes in channel_values stored via msgpack
- Update _recursive_deserialize to decode __bytes__ markers
- Fix test assertions to expect bytes instead of str from dumps_typed
1 parent c64dca9 commit d21047d

File tree

3 files changed

+38
-10
lines changed

3 files changed

+38
-10
lines changed

langgraph/checkpoint/redis/base.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,21 @@ def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]:
310310
"""Convert checkpoint to Redis format."""
311311
type_, data = self.serde.dumps_typed(checkpoint)
312312

313-
# Since we're keeping JSON format, decode string data
314-
checkpoint_data = cast(dict, orjson.loads(data))
313+
# Decode the serialized data - handle both JSON and msgpack
314+
if type_ == "json":
315+
checkpoint_data = cast(dict, orjson.loads(data))
316+
else:
317+
# For msgpack or other types, deserialize with loads_typed
318+
checkpoint_data = cast(dict, self.serde.loads_typed((type_, data)))
319+
320+
# When using msgpack, bytes are preserved - but Redis JSON.SET can't handle them
321+
# Encode bytes in channel_values with type marker for JSON storage
322+
if "channel_values" in checkpoint_data:
323+
for key, value in checkpoint_data["channel_values"].items():
324+
if isinstance(value, bytes):
325+
checkpoint_data["channel_values"][key] = {
326+
"__bytes__": self._encode_blob(value)
327+
}
315328

316329
# Ensure channel_versions are always strings to fix issue #40
317330
if "channel_versions" in checkpoint_data:
@@ -379,6 +392,11 @@ def _recursive_deserialize(self, obj: Any) -> Any:
379392
The deserialized object, with LangChain objects properly reconstructed.
380393
"""
381394
if isinstance(obj, dict):
395+
# Check if this is a bytes marker from msgpack storage
396+
if "__bytes__" in obj and len(obj) == 1:
397+
# Decode base64-encoded bytes
398+
return self._decode_blob(obj["__bytes__"])
399+
382400
# Check if this is a LangChain serialized object
383401
if obj.get("lc") in (1, 2) and obj.get("type") == "constructor":
384402
try:

langgraph/checkpoint/redis/jsonplus_redis.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def _default_handler(self, obj: Any) -> Any:
3333
This handles LangChain objects by delegating to the parent's
3434
_encode_constructor_args method which creates the LC format.
3535
"""
36+
# Bytes/bytearray in nested structures require msgpack - signal to fallback
37+
if isinstance(obj, (bytes, bytearray)):
38+
raise TypeError("bytes/bytearray in nested structure - use msgpack")
39+
3640
# Try to encode using parent's constructor args encoder
3741
# This creates the {"lc": 2, "type": "constructor", ...} format
3842
try:
@@ -57,6 +61,8 @@ def _default_handler(self, obj: Any) -> Any:
5761
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
5862
"""Serialize using orjson for JSON.
5963
64+
Falls back to msgpack for structures containing bytes/bytearray.
65+
6066
Returns:
6167
tuple[str, bytes]: Type identifier and serialized bytes
6268
"""
@@ -67,9 +73,13 @@ def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
6773
elif obj is None:
6874
return "null", b""
6975
else:
70-
# Use orjson for JSON serialization with custom default handler
71-
json_bytes = orjson.dumps(obj, default=self._default_handler)
72-
return "json", json_bytes
76+
try:
77+
# Try orjson first with custom default handler
78+
json_bytes = orjson.dumps(obj, default=self._default_handler)
79+
return "json", json_bytes
80+
except (TypeError, orjson.JSONEncodeError):
81+
# Fall back to parent's msgpack serialization for bytes in nested structures
82+
return super().dumps_typed(obj)
7383

7484
def loads_typed(self, data: tuple[str, bytes]) -> Any:
7585
"""Deserialize with custom revival for LangChain/LangGraph objects.

tests/test_checkpoint_serialization.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,13 @@ def test_issue_83_pending_sends_type_compatibility(redis_url: str) -> None:
166166
# Serialize
167167
type_str, blob = saver.serde.dumps_typed(test_data)
168168
assert isinstance(type_str, str)
169-
assert isinstance(blob, str) # JsonPlusRedisSerializer returns strings
169+
# Checkpoint 3.0: dumps_typed now returns bytes
170+
assert isinstance(blob, bytes)
170171

171-
# Deserialize - should work with both string and bytes
172+
# Deserialize with bytes (checkpoint 3.0 format)
172173
result1 = saver.serde.loads_typed((type_str, blob))
173-
result2 = saver.serde.loads_typed((type_str, blob.encode())) # bytes version
174174

175175
assert result1 == test_data
176-
assert result2 == test_data
177176

178177

179178
def test_load_blobs_method(redis_url: str) -> None:
@@ -508,7 +507,8 @@ def test_langchain_message_serialization(redis_url: str) -> None:
508507
# Serialize
509508
type_, data = serializer.dumps_typed(human_msg)
510509
assert type_ == "json"
511-
assert isinstance(data, str)
510+
# Checkpoint 3.0: dumps_typed now returns bytes
511+
assert isinstance(data, bytes)
512512

513513
# Deserialize
514514
deserialized = serializer.loads_typed((type_, data))

0 commit comments

Comments
 (0)