|
| 1 | +"""Test for Interrupt serialization fix (GitHub Issue #33556). |
| 2 | +
|
| 3 | +This test verifies that Interrupt objects are properly serialized and deserialized |
| 4 | +by the JsonPlusRedisSerializer, preventing the AttributeError that occurs when |
| 5 | +code tries to access the 'id' attribute on what it expects to be an Interrupt |
| 6 | +object but is actually a plain dictionary. |
| 7 | +
|
| 8 | +Issue: https://github.com/langchain-ai/langchain/issues/33556 |
| 9 | +""" |
| 10 | + |
| 11 | +import asyncio |
| 12 | +import json |
| 13 | +import uuid |
| 14 | +from typing import Any |
| 15 | + |
| 16 | +import pytest |
| 17 | +from langchain_core.runnables import RunnableConfig |
| 18 | +from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata |
| 19 | +from langgraph.types import Interrupt, interrupt |
| 20 | + |
| 21 | +from langgraph.checkpoint.redis import RedisSaver |
| 22 | +from langgraph.checkpoint.redis.aio import AsyncRedisSaver |
| 23 | +from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer |
| 24 | + |
| 25 | + |
| 26 | +class TestInterruptSerialization: |
| 27 | + """Test suite for Interrupt object serialization and deserialization.""" |
| 28 | + |
| 29 | + def test_interrupt_direct_serialization(self): |
| 30 | + """Test that Interrupt objects are properly serialized and deserialized.""" |
| 31 | + serializer = JsonPlusRedisSerializer() |
| 32 | + |
| 33 | + # Create an Interrupt object |
| 34 | + interrupt_obj = Interrupt( |
| 35 | + value={"tool_name": "external_action", "message": "Need approval"}, |
| 36 | + id="test-interrupt-123" |
| 37 | + ) |
| 38 | + |
| 39 | + # Test serialization/deserialization |
| 40 | + serialized = serializer.dumps(interrupt_obj) |
| 41 | + deserialized = serializer.loads(serialized) |
| 42 | + |
| 43 | + # Verify it's an Interrupt object with the correct attributes |
| 44 | + assert isinstance(deserialized, Interrupt), f"Expected Interrupt, got {type(deserialized)}" |
| 45 | + assert hasattr(deserialized, 'id'), "Deserialized object should have 'id' attribute" |
| 46 | + assert deserialized.id == "test-interrupt-123", f"ID mismatch: {deserialized.id}" |
| 47 | + assert deserialized.value == {"tool_name": "external_action", "message": "Need approval"} |
| 48 | + |
| 49 | + def test_interrupt_constructor_format(self): |
| 50 | + """Test that Interrupt objects are serialized in LangChain constructor format.""" |
| 51 | + serializer = JsonPlusRedisSerializer() |
| 52 | + |
| 53 | + interrupt_obj = Interrupt( |
| 54 | + value={"data": "test"}, |
| 55 | + id="constructor-test-id" |
| 56 | + ) |
| 57 | + |
| 58 | + serialized = serializer.dumps(interrupt_obj) |
| 59 | + |
| 60 | + # Parse the JSON to check the format |
| 61 | + parsed = json.loads(serialized) |
| 62 | + assert parsed.get("lc") == 2, "Should have lc=2 for constructor format" |
| 63 | + assert parsed.get("type") == "constructor", "Should have type=constructor" |
| 64 | + assert parsed.get("id") == ["langgraph", "types", "Interrupt"], "Should have correct id path" |
| 65 | + assert "kwargs" in parsed, "Should have kwargs field" |
| 66 | + assert parsed["kwargs"]["id"] == "constructor-test-id" |
| 67 | + |
| 68 | + def test_plain_dict_reconstruction(self): |
| 69 | + """Test that plain dicts with value/id keys are reconstructed as Interrupt objects.""" |
| 70 | + serializer = JsonPlusRedisSerializer() |
| 71 | + |
| 72 | + # This simulates what happens when Interrupt is stored as plain dict |
| 73 | + plain_dict_interrupt = {"value": {"data": "test"}, "id": "plain-id"} |
| 74 | + serialized = serializer.dumps(plain_dict_interrupt) |
| 75 | + deserialized = serializer.loads(serialized) |
| 76 | + |
| 77 | + # Should be reconstructed as an Interrupt |
| 78 | + assert isinstance(deserialized, Interrupt), f"Expected Interrupt, got {type(deserialized)}" |
| 79 | + assert hasattr(deserialized, 'id'), "Should have 'id' attribute" |
| 80 | + assert deserialized.id == "plain-id", f"ID should be preserved: {deserialized.id}" |
| 81 | + assert deserialized.value == {"data": "test"} |
| 82 | + |
| 83 | + def test_nested_interrupt_in_list(self): |
| 84 | + """Test Interrupt serialization in nested structures like pending_writes.""" |
| 85 | + serializer = JsonPlusRedisSerializer() |
| 86 | + |
| 87 | + # Simulate pending_writes structure |
| 88 | + interrupt_obj = Interrupt(value={"interrupt": "data"}, id="nested-id") |
| 89 | + nested_data = [ |
| 90 | + ("task1", interrupt_obj), |
| 91 | + ("task2", {"regular": "dict"}) |
| 92 | + ] |
| 93 | + |
| 94 | + serialized = serializer.dumps(nested_data) |
| 95 | + deserialized = serializer.loads(serialized) |
| 96 | + |
| 97 | + # Verify the Interrupt in the nested structure |
| 98 | + assert len(deserialized) == 2 |
| 99 | + task1_value = deserialized[0][1] |
| 100 | + task2_value = deserialized[1][1] |
| 101 | + |
| 102 | + assert isinstance(task1_value, Interrupt), "task1 should have Interrupt" |
| 103 | + assert task1_value.id == "nested-id" |
| 104 | + assert isinstance(task2_value, dict), "task2 should remain dict" |
| 105 | + |
| 106 | + def test_plain_dict_in_nested_structure(self): |
| 107 | + """Test that plain dicts with value/id in nested structures are reconstructed.""" |
| 108 | + serializer = JsonPlusRedisSerializer() |
| 109 | + |
| 110 | + # Simulate the problematic case from the issue |
| 111 | + nested_structure = [ |
| 112 | + ("task1", {"value": {"interrupt": "data"}, "id": "interrupt-1"}), |
| 113 | + ("task2", {"normal": "dict", "no": "conversion"}), |
| 114 | + ] |
| 115 | + |
| 116 | + serialized = serializer.dumps(nested_structure) |
| 117 | + deserialized = serializer.loads(serialized) |
| 118 | + |
| 119 | + task1_value = deserialized[0][1] |
| 120 | + task2_value = deserialized[1][1] |
| 121 | + |
| 122 | + # task1 should be reconstructed as Interrupt |
| 123 | + assert isinstance(task1_value, Interrupt), f"task1 should have Interrupt, got {type(task1_value)}" |
| 124 | + assert task1_value.id == "interrupt-1" |
| 125 | + # This is the line that would fail in the original bug |
| 126 | + interrupt_id = task1_value.id # Should not raise AttributeError |
| 127 | + assert interrupt_id == "interrupt-1" |
| 128 | + |
| 129 | + # task2 should remain a dict |
| 130 | + assert isinstance(task2_value, dict), f"task2 should remain dict, got {type(task2_value)}" |
| 131 | + |
| 132 | + def test_edge_cases_not_converted(self): |
| 133 | + """Test that dicts that shouldn't be converted to Interrupt remain as dicts.""" |
| 134 | + serializer = JsonPlusRedisSerializer() |
| 135 | + |
| 136 | + # Dict with non-string id - should not convert |
| 137 | + non_string_id = {"value": "test", "id": 123} |
| 138 | + result = serializer.loads(serializer.dumps(non_string_id)) |
| 139 | + assert isinstance(result, dict), "Should not convert when id is not string" |
| 140 | + |
| 141 | + # Dict with extra fields - should not convert |
| 142 | + extra_fields = {"value": "test", "id": "test-id", "extra": "field"} |
| 143 | + result = serializer.loads(serializer.dumps(extra_fields)) |
| 144 | + assert isinstance(result, dict), "Should not convert when extra fields present" |
| 145 | + |
| 146 | + # Dict with only value - should not convert |
| 147 | + only_value = {"value": "test"} |
| 148 | + result = serializer.loads(serializer.dumps(only_value)) |
| 149 | + assert isinstance(result, dict), "Should not convert with only value field" |
| 150 | + |
| 151 | + # Dict with only id - should not convert |
| 152 | + only_id = {"id": "test-id"} |
| 153 | + result = serializer.loads(serializer.dumps(only_id)) |
| 154 | + assert isinstance(result, dict), "Should not convert with only id field" |
| 155 | + |
| 156 | + def test_complex_interrupt_value(self): |
| 157 | + """Test Interrupt with complex nested value structures.""" |
| 158 | + serializer = JsonPlusRedisSerializer() |
| 159 | + |
| 160 | + complex_value = { |
| 161 | + "tool_name": "external_action", |
| 162 | + "tool_args": { |
| 163 | + "name": "Foo", |
| 164 | + "config": {"timeout": 30, "retries": 3}, |
| 165 | + "nested": {"deep": {"structure": ["a", "b", "c"]}} |
| 166 | + }, |
| 167 | + "metadata": {"timestamp": "2024-01-01", "user_id": "user123"} |
| 168 | + } |
| 169 | + |
| 170 | + interrupt_obj = Interrupt(value=complex_value, id="complex-id") |
| 171 | + |
| 172 | + serialized = serializer.dumps(interrupt_obj) |
| 173 | + deserialized = serializer.loads(serialized) |
| 174 | + |
| 175 | + assert isinstance(deserialized, Interrupt) |
| 176 | + assert deserialized.id == "complex-id" |
| 177 | + assert deserialized.value == complex_value |
| 178 | + assert deserialized.value["tool_args"]["nested"]["deep"]["structure"] == ["a", "b", "c"] |
| 179 | + |
| 180 | + |
| 181 | +@pytest.mark.asyncio |
| 182 | +class TestInterruptSerializationAsync: |
| 183 | + """Async tests for Interrupt serialization with Redis checkpointers.""" |
| 184 | + |
| 185 | + async def test_interrupt_in_checkpoint_async(self, redis_url: str): |
| 186 | + """Test that Interrupt objects in checkpoints are properly handled.""" |
| 187 | + async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer: |
| 188 | + thread_id = f"test-interrupt-{uuid.uuid4()}" |
| 189 | + config = { |
| 190 | + "configurable": { |
| 191 | + "thread_id": thread_id, |
| 192 | + "checkpoint_ns": "", |
| 193 | + "checkpoint_id": str(uuid.uuid4()), |
| 194 | + } |
| 195 | + } |
| 196 | + |
| 197 | + # Create an Interrupt object |
| 198 | + interrupt_obj = Interrupt( |
| 199 | + value={ |
| 200 | + "tool_name": "external_action", |
| 201 | + "tool_args": {"name": "TestArg"}, |
| 202 | + "message": "Need external system call", |
| 203 | + }, |
| 204 | + id="async-interrupt-id" |
| 205 | + ) |
| 206 | + |
| 207 | + # Create checkpoint with Interrupt in pending_writes |
| 208 | + checkpoint = { |
| 209 | + "v": 1, |
| 210 | + "ts": "2024-01-01T00:00:00+00:00", |
| 211 | + "id": config["configurable"]["checkpoint_id"], |
| 212 | + "channel_values": {"messages": ["test message"]}, |
| 213 | + "channel_versions": {}, |
| 214 | + "versions_seen": {}, |
| 215 | + "pending_writes": [ |
| 216 | + ("interrupt_task", interrupt_obj), |
| 217 | + ], |
| 218 | + } |
| 219 | + |
| 220 | + metadata = {"source": "test", "step": 1, "writes": {}} |
| 221 | + |
| 222 | + # Save the checkpoint |
| 223 | + await checkpointer.aput(config, checkpoint, metadata, {}) |
| 224 | + |
| 225 | + # Retrieve the checkpoint |
| 226 | + checkpoint_tuple = await checkpointer.aget_tuple(config) |
| 227 | + |
| 228 | + assert checkpoint_tuple is not None |
| 229 | + |
| 230 | + # Verify pending_writes contains an Interrupt object |
| 231 | + assert len(checkpoint_tuple.pending_writes) == 1 |
| 232 | + task_id, value = checkpoint_tuple.pending_writes[0] |
| 233 | + |
| 234 | + assert task_id == "interrupt_task" |
| 235 | + assert isinstance(value, Interrupt), f"Expected Interrupt, got {type(value)}" |
| 236 | + assert hasattr(value, 'id'), "Should have 'id' attribute" |
| 237 | + assert value.id == "async-interrupt-id" |
| 238 | + |
| 239 | + # This simulates the code that was failing in the issue |
| 240 | + # It should not raise AttributeError |
| 241 | + pending_interrupts = {} |
| 242 | + for task_id, val in checkpoint_tuple.pending_writes: |
| 243 | + if isinstance(val, Interrupt): |
| 244 | + pending_interrupts[task_id] = val.id |
| 245 | + |
| 246 | + assert pending_interrupts == {"interrupt_task": "async-interrupt-id"} |
| 247 | + |
| 248 | + async def test_multiple_interrupts_async(self, redis_url: str): |
| 249 | + """Test handling multiple Interrupt objects in a checkpoint.""" |
| 250 | + async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer: |
| 251 | + thread_id = f"test-multi-interrupt-{uuid.uuid4()}" |
| 252 | + config = { |
| 253 | + "configurable": { |
| 254 | + "thread_id": thread_id, |
| 255 | + "checkpoint_ns": "", |
| 256 | + "checkpoint_id": str(uuid.uuid4()), |
| 257 | + } |
| 258 | + } |
| 259 | + |
| 260 | + # Create multiple Interrupts |
| 261 | + interrupts = [ |
| 262 | + ("task1", Interrupt(value={"action": "approve"}, id="interrupt-1")), |
| 263 | + ("task2", Interrupt(value={"action": "deny"}, id="interrupt-2")), |
| 264 | + ("task3", {"regular": "dict", "not": "interrupt"}), |
| 265 | + ("task4", Interrupt(value={"action": "retry"}, id="interrupt-3")), |
| 266 | + ] |
| 267 | + |
| 268 | + checkpoint = { |
| 269 | + "v": 1, |
| 270 | + "ts": "2024-01-01T00:00:00+00:00", |
| 271 | + "id": config["configurable"]["checkpoint_id"], |
| 272 | + "channel_values": {}, |
| 273 | + "channel_versions": {}, |
| 274 | + "versions_seen": {}, |
| 275 | + "pending_writes": interrupts, |
| 276 | + } |
| 277 | + |
| 278 | + metadata = {"source": "test", "step": 1} |
| 279 | + |
| 280 | + await checkpointer.aput(config, checkpoint, metadata, {}) |
| 281 | + checkpoint_tuple = await checkpointer.aget_tuple(config) |
| 282 | + |
| 283 | + assert checkpoint_tuple is not None |
| 284 | + assert len(checkpoint_tuple.pending_writes) == 4 |
| 285 | + |
| 286 | + # Verify each item |
| 287 | + for i, (task_id, value) in enumerate(checkpoint_tuple.pending_writes): |
| 288 | + if task_id in ["task1", "task2", "task4"]: |
| 289 | + assert isinstance(value, Interrupt), f"{task_id} should have Interrupt" |
| 290 | + assert hasattr(value, 'id') |
| 291 | + # Verify we can access the id without error |
| 292 | + _ = value.id |
| 293 | + elif task_id == "task3": |
| 294 | + assert isinstance(value, dict), "task3 should remain dict" |
| 295 | + |
| 296 | + |
| 297 | +class TestInterruptSerializationSync: |
| 298 | + """Sync tests for Interrupt serialization with Redis checkpointers.""" |
| 299 | + |
| 300 | + def test_interrupt_with_empty_value(self): |
| 301 | + """Test Interrupt with None or empty value.""" |
| 302 | + serializer = JsonPlusRedisSerializer() |
| 303 | + |
| 304 | + # Interrupt with None value |
| 305 | + interrupt_none = Interrupt(value=None, id="none-value-id") |
| 306 | + result = serializer.loads(serializer.dumps(interrupt_none)) |
| 307 | + assert isinstance(result, Interrupt) |
| 308 | + assert result.value is None |
| 309 | + assert result.id == "none-value-id" |
| 310 | + |
| 311 | + # Interrupt with empty dict value |
| 312 | + interrupt_empty = Interrupt(value={}, id="empty-value-id") |
| 313 | + result = serializer.loads(serializer.dumps(interrupt_empty)) |
| 314 | + assert isinstance(result, Interrupt) |
| 315 | + assert result.value == {} |
| 316 | + assert result.id == "empty-value-id" |
| 317 | + |
| 318 | + def test_backwards_compatibility(self): |
| 319 | + """Test that the fix doesn't break existing non-Interrupt data.""" |
| 320 | + serializer = JsonPlusRedisSerializer() |
| 321 | + |
| 322 | + # Various data types that should work as before |
| 323 | + test_cases = [ |
| 324 | + {"message": "regular dict", "type": "test"}, |
| 325 | + ["list", "of", "strings"], |
| 326 | + {"nested": {"structure": {"with": ["mixed", "types", 123]}}}, |
| 327 | + {"value": "has value key but not id"}, |
| 328 | + {"id": "has id key but not value"}, |
| 329 | + {"value": 123, "id": "non-string-value", "extra": "field"}, |
| 330 | + ] |
| 331 | + |
| 332 | + for original in test_cases: |
| 333 | + result = serializer.loads(serializer.dumps(original)) |
| 334 | + assert result == original, f"Data should be unchanged: {original}" |
0 commit comments