Skip to content

Commit bb8f07f

Browse files
Merge branch 'main' into fix/JsonPlusRedisSerializer
2 parents ed81536 + 4919e23 commit bb8f07f

File tree

3 files changed

+416
-9
lines changed

3 files changed

+416
-9
lines changed

langgraph/checkpoint/redis/jsonplus_redis.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,9 @@ def dumps(self, obj: Any) -> bytes:
5050
except ImportError:
5151
pass
5252

53-
try:
54-
# Fast path: Use orjson for JSON-serializable objects
55-
return orjson.dumps(obj)
56-
except TypeError:
57-
# Complex objects (Send, etc.) need parent's msgpack serialization
58-
return super().dumps(obj)
53+
# Use orjson with default handler for LangChain objects
54+
# The _default method from parent class handles LangChain serialization
55+
return orjson.dumps(obj, default=self._default)
5956

6057
def loads(self, data: bytes) -> Any:
6158
"""Use orjson for JSON parsing with reviver support, fallback to parent for msgpack data."""
@@ -64,9 +61,15 @@ def loads(self, data: bytes) -> Any:
6461
parsed = orjson.loads(data)
6562
# Apply reviver for LangChain objects (lc format)
6663
return self._revive_if_needed(parsed)
67-
except orjson.JSONDecodeError:
68-
# Fallback: Parent handles msgpack and other formats
69-
return super().loads(data)
64+
except (orjson.JSONDecodeError, TypeError):
65+
# Fallback: Parent handles msgpack and other formats via loads_typed
66+
# Attempt to detect type and use loads_typed
67+
try:
68+
# Try loading as msgpack via parent's loads_typed
69+
return super().loads_typed(("msgpack", data))
70+
except Exception:
71+
# If that fails, try loading as json string
72+
return super().loads_typed(("json", data))
7073

7174
def _revive_if_needed(self, obj: Any) -> Any:
7275
"""Recursively apply reviver to handle LangChain serialized objects.
@@ -126,6 +129,7 @@ def dumps_typed(self, obj: Any) -> tuple[str, str]: # type: ignore[override]
126129
if isinstance(obj, (bytes, bytearray)):
127130
return "base64", base64.b64encode(obj).decode("utf-8")
128131
else:
132+
# All objects should be JSON-serializable (LangChain objects are pre-serialized)
129133
return "json", self.dumps(obj).decode("utf-8")
130134

131135
def loads_typed(self, data: tuple[str, Union[str, bytes]]) -> Any:

test_jsonplus_redis_serializer.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""Standalone test to verify the JsonPlusRedisSerializer fix works.
2+
3+
This can be run directly without pytest infrastructure:
4+
python test_fix_standalone.py
5+
"""
6+
7+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
8+
from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer
9+
10+
11+
def test_human_message_serialization():
12+
"""Test that HumanMessage can be serialized without TypeError."""
13+
print("Testing HumanMessage serialization...")
14+
15+
serializer = JsonPlusRedisSerializer()
16+
msg = HumanMessage(content="What is the weather?", id="msg-1")
17+
18+
try:
19+
# This would raise TypeError before the fix
20+
serialized = serializer.dumps(msg)
21+
print(f" ✓ Serialized to {len(serialized)} bytes")
22+
23+
# Deserialize
24+
deserialized = serializer.loads(serialized)
25+
assert isinstance(deserialized, HumanMessage)
26+
assert deserialized.content == "What is the weather?"
27+
assert deserialized.id == "msg-1"
28+
print(f" ✓ Deserialized correctly: {deserialized.content}")
29+
30+
return True
31+
except TypeError as e:
32+
print(f" ✗ FAILED: {e}")
33+
return False
34+
35+
36+
def test_all_message_types():
37+
"""Test all LangChain message types."""
38+
print("\nTesting all message types...")
39+
40+
serializer = JsonPlusRedisSerializer()
41+
messages = [
42+
HumanMessage(content="Hello"),
43+
AIMessage(content="Hi!"),
44+
SystemMessage(content="System prompt"),
45+
]
46+
47+
for msg in messages:
48+
try:
49+
serialized = serializer.dumps(msg)
50+
deserialized = serializer.loads(serialized)
51+
assert type(deserialized) == type(msg)
52+
print(f" ✓ {type(msg).__name__} works")
53+
except Exception as e:
54+
print(f" ✗ {type(msg).__name__} FAILED: {e}")
55+
return False
56+
57+
return True
58+
59+
60+
def test_message_list():
61+
"""Test list of messages (common pattern in LangGraph)."""
62+
print("\nTesting message list...")
63+
64+
serializer = JsonPlusRedisSerializer()
65+
messages = [
66+
HumanMessage(content="Question 1"),
67+
AIMessage(content="Answer 1"),
68+
HumanMessage(content="Question 2"),
69+
]
70+
71+
try:
72+
serialized = serializer.dumps(messages)
73+
deserialized = serializer.loads(serialized)
74+
75+
assert isinstance(deserialized, list)
76+
assert len(deserialized) == 3
77+
assert all(isinstance(m, (HumanMessage, AIMessage)) for m in deserialized)
78+
print(f" ✓ List of {len(deserialized)} messages works")
79+
80+
return True
81+
except Exception as e:
82+
print(f" ✗ FAILED: {e}")
83+
return False
84+
85+
86+
def test_nested_structure():
87+
"""Test nested structure with messages (realistic LangGraph state)."""
88+
print("\nTesting nested structure with messages...")
89+
90+
serializer = JsonPlusRedisSerializer()
91+
state = {
92+
"messages": [
93+
HumanMessage(content="Query"),
94+
AIMessage(content="Response"),
95+
],
96+
"step": 1,
97+
}
98+
99+
try:
100+
serialized = serializer.dumps(state)
101+
deserialized = serializer.loads(serialized)
102+
103+
assert "messages" in deserialized
104+
assert len(deserialized["messages"]) == 2
105+
assert isinstance(deserialized["messages"][0], HumanMessage)
106+
assert isinstance(deserialized["messages"][1], AIMessage)
107+
print(f" ✓ Nested structure works")
108+
109+
return True
110+
except Exception as e:
111+
print(f" ✗ FAILED: {e}")
112+
return False
113+
114+
115+
def test_dumps_typed():
116+
"""Test dumps_typed (what checkpointer actually uses)."""
117+
print("\nTesting dumps_typed...")
118+
119+
serializer = JsonPlusRedisSerializer()
120+
msg = HumanMessage(content="Test", id="test-123")
121+
122+
try:
123+
type_str, blob = serializer.dumps_typed(msg)
124+
assert type_str == "json"
125+
assert isinstance(blob, str)
126+
print(f" ✓ dumps_typed returns: type='{type_str}', blob={len(blob)} chars")
127+
128+
deserialized = serializer.loads_typed((type_str, blob))
129+
assert isinstance(deserialized, HumanMessage)
130+
assert deserialized.content == "Test"
131+
print(f" ✓ loads_typed works correctly")
132+
133+
return True
134+
except Exception as e:
135+
print(f" ✗ FAILED: {e}")
136+
return False
137+
138+
139+
def test_backwards_compatibility():
140+
"""Test that regular objects still work."""
141+
print("\nTesting backwards compatibility...")
142+
143+
serializer = JsonPlusRedisSerializer()
144+
test_cases = [
145+
("string", "hello"),
146+
("int", 42),
147+
("dict", {"key": "value"}),
148+
("list", [1, 2, 3]),
149+
]
150+
151+
for name, obj in test_cases:
152+
try:
153+
serialized = serializer.dumps(obj)
154+
deserialized = serializer.loads(serialized)
155+
assert deserialized == obj
156+
print(f" ✓ {name} works")
157+
except Exception as e:
158+
print(f" ✗ {name} FAILED: {e}")
159+
return False
160+
161+
return True
162+
163+
164+
def main():
165+
"""Run all tests."""
166+
print("=" * 70)
167+
print("JsonPlusRedisSerializer Fix Validation")
168+
print("=" * 70)
169+
170+
tests = [
171+
test_human_message_serialization,
172+
test_all_message_types,
173+
test_message_list,
174+
test_nested_structure,
175+
test_dumps_typed,
176+
test_backwards_compatibility,
177+
]
178+
179+
results = []
180+
for test in tests:
181+
results.append(test())
182+
183+
print("\n" + "=" * 70)
184+
print(f"Results: {sum(results)}/{len(results)} tests passed")
185+
print("=" * 70)
186+
187+
if all(results):
188+
print("\n✅ ALL TESTS PASSED - Fix is working correctly!")
189+
return 0
190+
else:
191+
print("\n❌ SOME TESTS FAILED - Fix may not be working")
192+
return 1
193+
194+
195+
if __name__ == "__main__":
196+
exit(main())

0 commit comments

Comments
 (0)