diff --git a/langgraph/checkpoint/redis/__init__.py b/langgraph/checkpoint/redis/__init__.py
index 1bdbaf0..2c0f123 100644
--- a/langgraph/checkpoint/redis/__init__.py
+++ b/langgraph/checkpoint/redis/__init__.py
@@ -354,13 +354,42 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
             CHECKPOINT_WRITE_PREFIX,
         )
 
-        # Get the blob keys
-        blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
-        blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)]
-
-        # Also get checkpoint write keys that should have the same TTL
-        write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
-        write_keys = [key.decode() for key in self._redis.keys(write_key_pattern)]
+        # Get the blob keys using the search index instead of keys()
+        blob_query = FilterQuery(
+            filter_expression=(
+                Tag("thread_id") == to_storage_safe_id(doc_thread_id)
+            )
+            & (Tag("checkpoint_ns") == to_storage_safe_str(doc_checkpoint_ns)),
+            return_fields=["channel", "version"],  # needed to rebuild the key names below
+            num_results=1000,
+        )
+        blob_results = self.checkpoint_blobs_index.search(blob_query)
+        blob_keys = [
+            f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{getattr(doc, 'channel', '')}:{getattr(doc, 'version', '')}"
+            for doc in blob_results.docs
+        ]
+
+        # Also get checkpoint write keys that should have the same TTL
+        write_query = FilterQuery(
+            filter_expression=(
+                Tag("thread_id") == to_storage_safe_id(doc_thread_id)
+            )
+            & (Tag("checkpoint_ns") == to_storage_safe_str(doc_checkpoint_ns))
+            & (Tag("checkpoint_id") == to_storage_safe_id(doc_checkpoint_id)),
+            return_fields=["task_id", "idx"],
+            num_results=1000,
+        )
+        write_results = self.checkpoint_writes_index.search(write_query)
+        write_keys = [
+            BaseRedisSaver._make_redis_checkpoint_writes_key(
+                to_storage_safe_id(doc_thread_id),
+                to_storage_safe_str(doc_checkpoint_ns),
+                to_storage_safe_id(doc_checkpoint_id),
+                getattr(doc, "task_id", ""),
+                getattr(doc, "idx", 0),
+            )
+            for doc in write_results.docs
+        ]
 
         # Apply TTL to checkpoint, blob keys, and write keys
         all_related_keys = blob_keys + write_keys
@@ -489,12 +518,15 @@ def get_channel_values(
             blob_results = self.checkpoint_blobs_index.search(blob_query)
             if blob_results.docs:
                 blob_doc = blob_results.docs[0]
-                blob_type = blob_doc.type
+                blob_type = getattr(blob_doc, "type", None)
                 blob_data = getattr(blob_doc, "$.blob", None)
 
-                if blob_data and blob_type != "empty":
+                if blob_data and blob_type and blob_type != "empty":
+                    # Ensure blob_data is bytes for deserialization
+                    if isinstance(blob_data, str):
+                        blob_data = blob_data.encode("utf-8")
                     channel_values[channel] = self.serde.loads_typed(
-                        (blob_type, blob_data)
+                        (str(blob_type), blob_data)
                     )
 
         return channel_values
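The hunk above swaps redis.keys() pattern scans for RediSearch index queries.
A minimal standalone sketch of the same pattern, assuming a redisvl
SearchIndex whose schema tags thread_id and checkpoint_ns (the index handle,
field names, and "checkpoint_blob" prefix here are illustrative, not the
library's actual values):

    from redisvl.query import FilterQuery
    from redisvl.query.filter import Tag

    def find_blob_keys(index, thread_id: str, checkpoint_ns: str) -> list:
        # One indexed query instead of a full keyspace scan; this is what
        # keeps the lookup legal on a Redis cluster.
        query = FilterQuery(
            filter_expression=(Tag("thread_id") == thread_id)
            & (Tag("checkpoint_ns") == checkpoint_ns),
            return_fields=["channel", "version"],
            num_results=1000,  # one page; paginate for very large threads
        )
        results = index.search(query)
        return [
            f"checkpoint_blob:{thread_id}:{checkpoint_ns}:"
            f"{getattr(doc, 'channel', '')}:{getattr(doc, 'version', '')}"
            for doc in results.docs
        ]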
diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py
index 65d72ff..e157ee1 100644
--- a/langgraph/checkpoint/redis/aio.py
+++ b/langgraph/checkpoint/redis/aio.py
@@ -890,30 +890,38 @@ async def _aload_pending_writes(
         if checkpoint_id is None:
             return []  # Early return if no checkpoint_id
 
-        writes_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
-            to_storage_safe_id(thread_id),
-            to_storage_safe_str(checkpoint_ns),
-            to_storage_safe_id(checkpoint_id),
-            "*",
-            None,
-        )
-        matching_keys = await self._redis.keys(pattern=writes_key)
-        # Use safely_decode to handle both string and bytes responses
-        decoded_keys = [safely_decode(key) for key in matching_keys]
-        parsed_keys = [
-            BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
-            for key in decoded_keys
-        ]
-        pending_writes = BaseRedisSaver._load_writes(
-            self.serde,
-            {
-                (
-                    parsed_key["task_id"],
-                    parsed_key["idx"],
-                ): await self._redis.json().get(key)
-                for key, parsed_key in sorted(
-                    zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
-                )
-            },
+        # Use the search index instead of keys() to avoid CrossSlot errors.
+        # Note: For checkpoint_ns, we use the raw value for tag searches
+        # because RediSearch may not handle sentinel values correctly in tag fields.
+        writes_query = FilterQuery(
+            filter_expression=(Tag("thread_id") == to_storage_safe_id(thread_id))
+            & (Tag("checkpoint_ns") == checkpoint_ns)
+            & (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)),
+            return_fields=["task_id", "idx", "channel", "type", "$.blob"],
+            num_results=1000,  # Adjust as needed
         )
+
+        writes_results = await self.checkpoint_writes_index.search(writes_query)
+
+        # Sort results by idx to maintain order (idx is stored as a string,
+        # so compare numerically)
+        sorted_writes = sorted(
+            writes_results.docs, key=lambda x: int(getattr(x, "idx", 0))
+        )
+
+        # Build the writes dictionary keyed by (task_id, idx)
+        writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {}
+        for doc in sorted_writes:
+            task_id = str(getattr(doc, "task_id", ""))
+            idx = str(getattr(doc, "idx", 0))
+            blob_data = getattr(doc, "$.blob", "")
+            # Ensure the blob is bytes for deserialization
+            if isinstance(blob_data, str):
+                blob_data = blob_data.encode("utf-8")
+            writes_dict[(task_id, idx)] = {
+                "task_id": task_id,
+                "idx": idx,
+                "channel": str(getattr(doc, "channel", "")),
+                "type": str(getattr(doc, "type", "")),
+                "blob": blob_data,
+            }
+
+        pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict)
         return pending_writes
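Background on the CrossSlot errors this change removes: keys() returns keys
that generally hash to different cluster slots, and any follow-up multi-key
command over them (pipelined EXPIRE, DEL, and so on) is rejected with
"CROSSSLOT Keys in request don't hash to the same slot". A quick way to see
the slot spread, assuming redis-py's redis.crc.key_slot helper (the key names
are invented for illustration):

    from redis.crc import key_slot

    keys = [
        b"checkpoint_write:thread-1:ns:cp-1:task:0",
        b"checkpoint_write:thread-1:ns:cp-1:task:1",
        b"checkpoint_blob:thread-1:ns:messages:2",
    ]
    # Without a shared {hash tag}, related keys usually land in different
    # slots, so one multi-key command cannot touch all of them.
    print({key: key_slot(key) for key in keys})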
+ return_fields=["task_id", "idx", "channel", "type", "$.blob"], + num_results=1000, # Adjust as needed ) - # Cast the result to List[bytes] to help type checker - matching_keys: List[bytes] = self._redis.keys(pattern=writes_key) # type: ignore[assignment] - - # Use safely_decode to handle both string and bytes responses - decoded_keys = [safely_decode(key) for key in matching_keys] + writes_results = self.checkpoint_writes_index.search(writes_query) + + # Sort results by idx to maintain order + sorted_writes = sorted(writes_results.docs, key=lambda x: getattr(x, "idx", 0)) + + # Build the writes dictionary + writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {} + for doc in sorted_writes: + task_id = str(getattr(doc, "task_id", "")) + idx = str(getattr(doc, "idx", 0)) + blob_data = getattr(doc, "$.blob", "") + # Ensure blob is bytes for deserialization + if isinstance(blob_data, str): + blob_data = blob_data.encode("utf-8") + writes_dict[(task_id, idx)] = { + "task_id": task_id, + "idx": idx, + "channel": str(getattr(doc, "channel", "")), + "type": str(getattr(doc, "type", "")), + "blob": blob_data, + } - parsed_keys = [ - BaseRedisSaver._parse_redis_checkpoint_writes_key(key) - for key in decoded_keys - ] - pending_writes = BaseRedisSaver._load_writes( - self.serde, - { - (parsed_key["task_id"], parsed_key["idx"]): self._redis.json().get(key) - for key, parsed_key in sorted( - zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"] - ) - }, - ) + pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict) return pending_writes @staticmethod diff --git a/tests/test_crossslot_integration.py b/tests/test_crossslot_integration.py new file mode 100644 index 0000000..33543a7 --- /dev/null +++ b/tests/test_crossslot_integration.py @@ -0,0 +1,170 @@ +"""Integration tests for CrossSlot error fix in checkpoint operations.""" + +import pytest +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + Checkpoint, + CheckpointMetadata, + create_checkpoint, + empty_checkpoint, +) + +from langgraph.checkpoint.redis import RedisSaver + + +def test_checkpoint_operations_no_crossslot_errors(redis_url: str) -> None: + """Test that checkpoint operations work without CrossSlot errors. + + This test verifies that the fix for using search indexes instead of keys() + works correctly in a real Redis environment. 
+ """ + # Create a saver + saver = RedisSaver(redis_url) + saver.setup() + + # Create test data + thread_id = "test-thread-crossslot" + checkpoint_ns = "test-ns" + + # Create checkpoints with unique IDs + checkpoint1 = create_checkpoint(empty_checkpoint(), {}, 1) + checkpoint2 = create_checkpoint(checkpoint1, {"messages": ["hello"]}, 2) + checkpoint3 = create_checkpoint(checkpoint2, {"messages": ["hello", "world"]}, 3) + + # Create metadata + metadata1 = {"source": "input", "step": 1, "writes": {"task1": "value1"}} + metadata2 = {"source": "loop", "step": 2, "writes": {"task2": "value2"}} + metadata3 = {"source": "loop", "step": 3, "writes": {"task3": "value3"}} + + # Put checkpoints with writes + config1 = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}} + config2 = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}} + config3 = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}} + + # Put checkpoints first to get configs with checkpoint_ids + saved_config1 = saver.put(config1, checkpoint1, metadata1, {}) + saved_config2 = saver.put(config2, checkpoint2, metadata2, {}) + saved_config3 = saver.put(config3, checkpoint3, metadata3, {}) + + # Add some pending writes using saved configs + saver.put_writes( + saved_config1, + [ + ("channel1", {"value": "data1"}), + ("channel2", {"value": "data2"}), + ], + "task-1", + ) + + # Now test operations that previously used keys() and would fail in cluster mode + + # Test 1: Load pending writes (uses _load_pending_writes) + # This should work without CrossSlot errors + tuple1 = saver.get_tuple(saved_config1) + assert tuple1 is not None + # Verify pending writes were loaded + assert len(tuple1.pending_writes) == 2 + pending_channels = [w[1] for w in tuple1.pending_writes] + assert "channel1" in pending_channels + assert "channel2" in pending_channels + + # Test 2: Get tuple with TTL (uses get_tuple which searches for blob and write keys) + saver_with_ttl = RedisSaver(redis_url, ttl={"checkpoint": 3600}) + saver_with_ttl.setup() + + # Put a checkpoint with TTL + config_ttl = { + "configurable": {"thread_id": "ttl-thread", "checkpoint_ns": "ttl-ns"} + } + saver_with_ttl.put(config_ttl, checkpoint1, metadata1, {}) + + # Get the checkpoint - this triggers TTL application which uses key searches + tuple_ttl = saver_with_ttl.get_tuple(config_ttl) + assert tuple_ttl is not None + + # Test 3: List checkpoints - this should work without CrossSlot errors + # List returns only the latest checkpoint by default + checkpoints = list(saver.list(config1)) + assert len(checkpoints) >= 1 + + # The latest checkpoint should have the pending writes from checkpoint1 + latest_checkpoint = checkpoints[0] + assert len(latest_checkpoint.pending_writes) == 2 + + # The important part is that all these operations work without CrossSlot errors + # In a Redis cluster, the old keys() based approach would have failed by now + + +def test_subgraph_checkpoint_operations(redis_url: str) -> None: + """Test checkpoint operations with subgraphs work without CrossSlot errors.""" + saver = RedisSaver(redis_url) + saver.setup() + + # Create nested namespace checkpoints + thread_id = "test-thread-subgraph" + + # Parent checkpoint + parent_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": "", + } + } + parent_checkpoint = empty_checkpoint() + parent_metadata = {"source": "input", "step": 1} + + # Child checkpoint in subgraph + child_config = { + "configurable": { + "thread_id": thread_id, + 
"checkpoint_ns": "subgraph1", + } + } + child_checkpoint = create_checkpoint(parent_checkpoint, {"subgraph": "data"}, 1) + child_metadata = {"source": "loop", "step": 1} + + # Grandchild checkpoint in nested subgraph + grandchild_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": "subgraph1:subgraph2", + } + } + grandchild_checkpoint = create_checkpoint(child_checkpoint, {"nested": "data"}, 2) + grandchild_metadata = {"source": "loop", "step": 2} + + # Put all checkpoints first to get saved configs + saved_parent_config = saver.put( + parent_config, parent_checkpoint, parent_metadata, {} + ) + saved_child_config = saver.put(child_config, child_checkpoint, child_metadata, {}) + saved_grandchild_config = saver.put( + grandchild_config, grandchild_checkpoint, grandchild_metadata, {} + ) + + # Put checkpoints with writes using saved configs + saver.put_writes( + saved_parent_config, [("parent_channel", {"parent": "data"})], "parent-task" + ) + saver.put_writes( + saved_child_config, [("child_channel", {"child": "data"})], "child-task" + ) + saver.put_writes( + saved_grandchild_config, + [("grandchild_channel", {"grandchild": "data"})], + "grandchild-task", + ) + + # Test loading checkpoints with pending writes from different namespaces + parent_tuple = saver.get_tuple(parent_config) + assert parent_tuple is not None + + child_tuple = saver.get_tuple(child_config) + assert child_tuple is not None + + grandchild_tuple = saver.get_tuple(grandchild_config) + assert grandchild_tuple is not None + + # List all checkpoints - should work without CrossSlot errors + all_checkpoints = list(saver.list({"configurable": {"thread_id": thread_id}})) + assert len(all_checkpoints) >= 3