Skip to content

Commit 77054cb

Browse files
committed
fix: JsonPlusRedisSerializer
1 parent 62fd939 commit 77054cb

File tree

2 files changed

+365
-0
lines changed

2 files changed

+365
-0
lines changed

langgraph/checkpoint/redis/jsonplus_redis.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@ class JsonPlusRedisSerializer(JsonPlusSerializer):
4040

4141
def dumps(self, obj: Any) -> bytes:
4242
"""Use orjson for simple objects, fallback to parent for complex objects."""
43+
try:
44+
# Check if this is an Interrupt object that needs special handling
45+
from langgraph.types import Interrupt
46+
if isinstance(obj, Interrupt):
47+
# Serialize Interrupt as a constructor format for proper deserialization
48+
return super().dumps(obj)
49+
except ImportError:
50+
pass
51+
4352
try:
4453
# Fast path: Use orjson for JSON-serializable objects
4554
return orjson.dumps(obj)
@@ -66,6 +75,10 @@ def _revive_if_needed(self, obj: Any) -> Any:
6675
reconstructed. Without this, messages would remain as dictionaries with
6776
'lc', 'type', and 'constructor' fields, causing errors when the application
6877
expects actual message objects with 'role' and 'content' attributes.
78+
79+
This also handles Interrupt objects that may be stored as plain dictionaries
80+
with 'value' and 'id' keys, reconstructing them as proper Interrupt instances
81+
to prevent AttributeError when accessing the 'id' attribute.
6982
7083
Args:
7184
obj: The object to potentially revive, which may be a dict, list, or primitive.
@@ -80,6 +93,24 @@ def _revive_if_needed(self, obj: Any) -> Any:
8093
# This converts {'lc': 1, 'type': 'constructor', ...} back to
8194
# the actual LangChain object (e.g., HumanMessage, AIMessage)
8295
return self._reviver(obj)
96+
97+
# Check if this looks like an Interrupt object stored as a plain dict
98+
# Interrupt objects have 'value' and 'id' keys, and possibly nothing else
99+
# We need to be careful not to accidentally convert other dicts
100+
if (
101+
"value" in obj
102+
and "id" in obj
103+
and len(obj) == 2
104+
and isinstance(obj.get("id"), str)
105+
):
106+
# Try to reconstruct as an Interrupt object
107+
try:
108+
from langgraph.types import Interrupt
109+
return Interrupt(value=obj["value"], id=obj["id"])
110+
except (ImportError, TypeError, ValueError):
111+
# If we can't import or construct Interrupt, fall through
112+
pass
113+
83114
# Recursively process nested dicts
84115
return {k: self._revive_if_needed(v) for k, v in obj.items()}
85116
elif isinstance(obj, list):
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
"""Test for Interrupt serialization fix (GitHub Issue #33556).
2+
3+
This test verifies that Interrupt objects are properly serialized and deserialized
4+
by the JsonPlusRedisSerializer, preventing the AttributeError that occurs when
5+
code tries to access the 'id' attribute on what it expects to be an Interrupt
6+
object but is actually a plain dictionary.
7+
8+
Issue: https://github.com/langchain-ai/langchain/issues/33556
9+
"""
10+
11+
import asyncio
12+
import json
13+
import uuid
14+
from typing import Any
15+
16+
import pytest
17+
from langchain_core.runnables import RunnableConfig
18+
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata
19+
from langgraph.types import Interrupt, interrupt
20+
21+
from langgraph.checkpoint.redis import RedisSaver
22+
from langgraph.checkpoint.redis.aio import AsyncRedisSaver
23+
from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer
24+
25+
26+
class TestInterruptSerialization:
27+
"""Test suite for Interrupt object serialization and deserialization."""
28+
29+
def test_interrupt_direct_serialization(self):
30+
"""Test that Interrupt objects are properly serialized and deserialized."""
31+
serializer = JsonPlusRedisSerializer()
32+
33+
# Create an Interrupt object
34+
interrupt_obj = Interrupt(
35+
value={"tool_name": "external_action", "message": "Need approval"},
36+
id="test-interrupt-123"
37+
)
38+
39+
# Test serialization/deserialization
40+
serialized = serializer.dumps(interrupt_obj)
41+
deserialized = serializer.loads(serialized)
42+
43+
# Verify it's an Interrupt object with the correct attributes
44+
assert isinstance(deserialized, Interrupt), f"Expected Interrupt, got {type(deserialized)}"
45+
assert hasattr(deserialized, 'id'), "Deserialized object should have 'id' attribute"
46+
assert deserialized.id == "test-interrupt-123", f"ID mismatch: {deserialized.id}"
47+
assert deserialized.value == {"tool_name": "external_action", "message": "Need approval"}
48+
49+
def test_interrupt_constructor_format(self):
50+
"""Test that Interrupt objects are serialized in LangChain constructor format."""
51+
serializer = JsonPlusRedisSerializer()
52+
53+
interrupt_obj = Interrupt(
54+
value={"data": "test"},
55+
id="constructor-test-id"
56+
)
57+
58+
serialized = serializer.dumps(interrupt_obj)
59+
60+
# Parse the JSON to check the format
61+
parsed = json.loads(serialized)
62+
assert parsed.get("lc") == 2, "Should have lc=2 for constructor format"
63+
assert parsed.get("type") == "constructor", "Should have type=constructor"
64+
assert parsed.get("id") == ["langgraph", "types", "Interrupt"], "Should have correct id path"
65+
assert "kwargs" in parsed, "Should have kwargs field"
66+
assert parsed["kwargs"]["id"] == "constructor-test-id"
67+
68+
def test_plain_dict_reconstruction(self):
69+
"""Test that plain dicts with value/id keys are reconstructed as Interrupt objects."""
70+
serializer = JsonPlusRedisSerializer()
71+
72+
# This simulates what happens when Interrupt is stored as plain dict
73+
plain_dict_interrupt = {"value": {"data": "test"}, "id": "plain-id"}
74+
serialized = serializer.dumps(plain_dict_interrupt)
75+
deserialized = serializer.loads(serialized)
76+
77+
# Should be reconstructed as an Interrupt
78+
assert isinstance(deserialized, Interrupt), f"Expected Interrupt, got {type(deserialized)}"
79+
assert hasattr(deserialized, 'id'), "Should have 'id' attribute"
80+
assert deserialized.id == "plain-id", f"ID should be preserved: {deserialized.id}"
81+
assert deserialized.value == {"data": "test"}
82+
83+
def test_nested_interrupt_in_list(self):
84+
"""Test Interrupt serialization in nested structures like pending_writes."""
85+
serializer = JsonPlusRedisSerializer()
86+
87+
# Simulate pending_writes structure
88+
interrupt_obj = Interrupt(value={"interrupt": "data"}, id="nested-id")
89+
nested_data = [
90+
("task1", interrupt_obj),
91+
("task2", {"regular": "dict"})
92+
]
93+
94+
serialized = serializer.dumps(nested_data)
95+
deserialized = serializer.loads(serialized)
96+
97+
# Verify the Interrupt in the nested structure
98+
assert len(deserialized) == 2
99+
task1_value = deserialized[0][1]
100+
task2_value = deserialized[1][1]
101+
102+
assert isinstance(task1_value, Interrupt), "task1 should have Interrupt"
103+
assert task1_value.id == "nested-id"
104+
assert isinstance(task2_value, dict), "task2 should remain dict"
105+
106+
def test_plain_dict_in_nested_structure(self):
107+
"""Test that plain dicts with value/id in nested structures are reconstructed."""
108+
serializer = JsonPlusRedisSerializer()
109+
110+
# Simulate the problematic case from the issue
111+
nested_structure = [
112+
("task1", {"value": {"interrupt": "data"}, "id": "interrupt-1"}),
113+
("task2", {"normal": "dict", "no": "conversion"}),
114+
]
115+
116+
serialized = serializer.dumps(nested_structure)
117+
deserialized = serializer.loads(serialized)
118+
119+
task1_value = deserialized[0][1]
120+
task2_value = deserialized[1][1]
121+
122+
# task1 should be reconstructed as Interrupt
123+
assert isinstance(task1_value, Interrupt), f"task1 should have Interrupt, got {type(task1_value)}"
124+
assert task1_value.id == "interrupt-1"
125+
# This is the line that would fail in the original bug
126+
interrupt_id = task1_value.id # Should not raise AttributeError
127+
assert interrupt_id == "interrupt-1"
128+
129+
# task2 should remain a dict
130+
assert isinstance(task2_value, dict), f"task2 should remain dict, got {type(task2_value)}"
131+
132+
def test_edge_cases_not_converted(self):
133+
"""Test that dicts that shouldn't be converted to Interrupt remain as dicts."""
134+
serializer = JsonPlusRedisSerializer()
135+
136+
# Dict with non-string id - should not convert
137+
non_string_id = {"value": "test", "id": 123}
138+
result = serializer.loads(serializer.dumps(non_string_id))
139+
assert isinstance(result, dict), "Should not convert when id is not string"
140+
141+
# Dict with extra fields - should not convert
142+
extra_fields = {"value": "test", "id": "test-id", "extra": "field"}
143+
result = serializer.loads(serializer.dumps(extra_fields))
144+
assert isinstance(result, dict), "Should not convert when extra fields present"
145+
146+
# Dict with only value - should not convert
147+
only_value = {"value": "test"}
148+
result = serializer.loads(serializer.dumps(only_value))
149+
assert isinstance(result, dict), "Should not convert with only value field"
150+
151+
# Dict with only id - should not convert
152+
only_id = {"id": "test-id"}
153+
result = serializer.loads(serializer.dumps(only_id))
154+
assert isinstance(result, dict), "Should not convert with only id field"
155+
156+
def test_complex_interrupt_value(self):
157+
"""Test Interrupt with complex nested value structures."""
158+
serializer = JsonPlusRedisSerializer()
159+
160+
complex_value = {
161+
"tool_name": "external_action",
162+
"tool_args": {
163+
"name": "Foo",
164+
"config": {"timeout": 30, "retries": 3},
165+
"nested": {"deep": {"structure": ["a", "b", "c"]}}
166+
},
167+
"metadata": {"timestamp": "2024-01-01", "user_id": "user123"}
168+
}
169+
170+
interrupt_obj = Interrupt(value=complex_value, id="complex-id")
171+
172+
serialized = serializer.dumps(interrupt_obj)
173+
deserialized = serializer.loads(serialized)
174+
175+
assert isinstance(deserialized, Interrupt)
176+
assert deserialized.id == "complex-id"
177+
assert deserialized.value == complex_value
178+
assert deserialized.value["tool_args"]["nested"]["deep"]["structure"] == ["a", "b", "c"]
179+
180+
181+
@pytest.mark.asyncio
182+
class TestInterruptSerializationAsync:
183+
"""Async tests for Interrupt serialization with Redis checkpointers."""
184+
185+
async def test_interrupt_in_checkpoint_async(self, redis_url: str):
186+
"""Test that Interrupt objects in checkpoints are properly handled."""
187+
async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer:
188+
thread_id = f"test-interrupt-{uuid.uuid4()}"
189+
config = {
190+
"configurable": {
191+
"thread_id": thread_id,
192+
"checkpoint_ns": "",
193+
"checkpoint_id": str(uuid.uuid4()),
194+
}
195+
}
196+
197+
# Create an Interrupt object
198+
interrupt_obj = Interrupt(
199+
value={
200+
"tool_name": "external_action",
201+
"tool_args": {"name": "TestArg"},
202+
"message": "Need external system call",
203+
},
204+
id="async-interrupt-id"
205+
)
206+
207+
# Create checkpoint with Interrupt in pending_writes
208+
checkpoint = {
209+
"v": 1,
210+
"ts": "2024-01-01T00:00:00+00:00",
211+
"id": config["configurable"]["checkpoint_id"],
212+
"channel_values": {"messages": ["test message"]},
213+
"channel_versions": {},
214+
"versions_seen": {},
215+
"pending_writes": [
216+
("interrupt_task", interrupt_obj),
217+
],
218+
}
219+
220+
metadata = {"source": "test", "step": 1, "writes": {}}
221+
222+
# Save the checkpoint
223+
await checkpointer.aput(config, checkpoint, metadata, {})
224+
225+
# Retrieve the checkpoint
226+
checkpoint_tuple = await checkpointer.aget_tuple(config)
227+
228+
assert checkpoint_tuple is not None
229+
230+
# Verify pending_writes contains an Interrupt object
231+
assert len(checkpoint_tuple.pending_writes) == 1
232+
task_id, value = checkpoint_tuple.pending_writes[0]
233+
234+
assert task_id == "interrupt_task"
235+
assert isinstance(value, Interrupt), f"Expected Interrupt, got {type(value)}"
236+
assert hasattr(value, 'id'), "Should have 'id' attribute"
237+
assert value.id == "async-interrupt-id"
238+
239+
# This simulates the code that was failing in the issue
240+
# It should not raise AttributeError
241+
pending_interrupts = {}
242+
for task_id, val in checkpoint_tuple.pending_writes:
243+
if isinstance(val, Interrupt):
244+
pending_interrupts[task_id] = val.id
245+
246+
assert pending_interrupts == {"interrupt_task": "async-interrupt-id"}
247+
248+
async def test_multiple_interrupts_async(self, redis_url: str):
249+
"""Test handling multiple Interrupt objects in a checkpoint."""
250+
async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer:
251+
thread_id = f"test-multi-interrupt-{uuid.uuid4()}"
252+
config = {
253+
"configurable": {
254+
"thread_id": thread_id,
255+
"checkpoint_ns": "",
256+
"checkpoint_id": str(uuid.uuid4()),
257+
}
258+
}
259+
260+
# Create multiple Interrupts
261+
interrupts = [
262+
("task1", Interrupt(value={"action": "approve"}, id="interrupt-1")),
263+
("task2", Interrupt(value={"action": "deny"}, id="interrupt-2")),
264+
("task3", {"regular": "dict", "not": "interrupt"}),
265+
("task4", Interrupt(value={"action": "retry"}, id="interrupt-3")),
266+
]
267+
268+
checkpoint = {
269+
"v": 1,
270+
"ts": "2024-01-01T00:00:00+00:00",
271+
"id": config["configurable"]["checkpoint_id"],
272+
"channel_values": {},
273+
"channel_versions": {},
274+
"versions_seen": {},
275+
"pending_writes": interrupts,
276+
}
277+
278+
metadata = {"source": "test", "step": 1}
279+
280+
await checkpointer.aput(config, checkpoint, metadata, {})
281+
checkpoint_tuple = await checkpointer.aget_tuple(config)
282+
283+
assert checkpoint_tuple is not None
284+
assert len(checkpoint_tuple.pending_writes) == 4
285+
286+
# Verify each item
287+
for i, (task_id, value) in enumerate(checkpoint_tuple.pending_writes):
288+
if task_id in ["task1", "task2", "task4"]:
289+
assert isinstance(value, Interrupt), f"{task_id} should have Interrupt"
290+
assert hasattr(value, 'id')
291+
# Verify we can access the id without error
292+
_ = value.id
293+
elif task_id == "task3":
294+
assert isinstance(value, dict), "task3 should remain dict"
295+
296+
297+
class TestInterruptSerializationSync:
298+
"""Sync tests for Interrupt serialization with Redis checkpointers."""
299+
300+
def test_interrupt_with_empty_value(self):
301+
"""Test Interrupt with None or empty value."""
302+
serializer = JsonPlusRedisSerializer()
303+
304+
# Interrupt with None value
305+
interrupt_none = Interrupt(value=None, id="none-value-id")
306+
result = serializer.loads(serializer.dumps(interrupt_none))
307+
assert isinstance(result, Interrupt)
308+
assert result.value is None
309+
assert result.id == "none-value-id"
310+
311+
# Interrupt with empty dict value
312+
interrupt_empty = Interrupt(value={}, id="empty-value-id")
313+
result = serializer.loads(serializer.dumps(interrupt_empty))
314+
assert isinstance(result, Interrupt)
315+
assert result.value == {}
316+
assert result.id == "empty-value-id"
317+
318+
def test_backwards_compatibility(self):
319+
"""Test that the fix doesn't break existing non-Interrupt data."""
320+
serializer = JsonPlusRedisSerializer()
321+
322+
# Various data types that should work as before
323+
test_cases = [
324+
{"message": "regular dict", "type": "test"},
325+
["list", "of", "strings"],
326+
{"nested": {"structure": {"with": ["mixed", "types", 123]}}},
327+
{"value": "has value key but not id"},
328+
{"id": "has id key but not value"},
329+
{"value": 123, "id": "non-string-value", "extra": "field"},
330+
]
331+
332+
for original in test_cases:
333+
result = serializer.loads(serializer.dumps(original))
334+
assert result == original, f"Data should be unchanged: {original}"

0 commit comments

Comments
 (0)