From 937da834899abaf277bcd41bed0a043277bea2cf Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 6 Mar 2025 14:49:16 -0800 Subject: [PATCH 1/2] Clean up some key and UUID->str conversion handling --- langgraph/checkpoint/redis/__init__.py | 6 ++- langgraph/checkpoint/redis/aio.py | 29 +++++++++---- langgraph/checkpoint/redis/base.py | 10 ++--- langgraph/checkpoint/redis/util.py | 60 +++++++------------------- 4 files changed, 44 insertions(+), 61 deletions(-) diff --git a/langgraph/checkpoint/redis/__init__.py b/langgraph/checkpoint/redis/__init__.py index 44825ca..7b9c464 100644 --- a/langgraph/checkpoint/redis/__init__.py +++ b/langgraph/checkpoint/redis/__init__.py @@ -214,8 +214,10 @@ def put( thread_id = configurable.pop("thread_id") checkpoint_ns = configurable.pop("checkpoint_ns") - checkpoint_id = checkpoint_id = configurable.pop( - "checkpoint_id", configurable.pop("thread_ts", "") + thread_ts = configurable.pop("thread_ts", "") + checkpoint_id = ( + configurable.pop("checkpoint_id", configurable.pop("thread_ts", "")) + or thread_ts ) # For values we store in Redis, we need to convert empty strings to the diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py index 8df73dc..6ce7b48 100644 --- a/langgraph/checkpoint/redis/aio.py +++ b/langgraph/checkpoint/redis/aio.py @@ -7,7 +7,6 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from functools import partial -from sys import thread_info from types import TracebackType from typing import Any, List, Optional, Sequence, Tuple, Type, cast @@ -375,10 +374,20 @@ async def aput( ) -> RunnableConfig: """Store a checkpoint to Redis.""" configurable = config["configurable"].copy() + thread_id = configurable.pop("thread_id") checkpoint_ns = configurable.pop("checkpoint_ns") thread_ts = configurable.pop("thread_ts", "") - checkpoint_id = configurable.pop("checkpoint_id", thread_ts) or thread_ts + checkpoint_id = ( + configurable.pop("checkpoint_id", configurable.pop("thread_ts", "")) + or thread_ts + ) + + # For values we store in Redis, we need to convert empty strings to the + # sentinel value. + storage_safe_thread_id = to_storage_safe_id(thread_id) + storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns) + storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id) copy = checkpoint.copy() next_config = { @@ -391,10 +400,10 @@ async def aput( # Store checkpoint data checkpoint_data = { - "thread_id": thread_id, - "checkpoint_ns": to_storage_safe_str(checkpoint_ns), - "checkpoint_id": to_storage_safe_id(checkpoint_id), - "parent_checkpoint_id": to_storage_safe_id(checkpoint_id), + "thread_id": storage_safe_thread_id, + "checkpoint_ns": storage_safe_checkpoint_ns, + "checkpoint_id": storage_safe_checkpoint_id, + "parent_checkpoint_id": storage_safe_checkpoint_id, "checkpoint": self._dump_checkpoint(copy), "metadata": self._dump_metadata(metadata), } @@ -408,15 +417,17 @@ async def aput( [checkpoint_data], keys=[ BaseRedisSaver._make_redis_checkpoint_key( - thread_id, checkpoint_ns, checkpoint_id + storage_safe_thread_id, + storage_safe_checkpoint_ns, + storage_safe_checkpoint_id, ) ], ) # Store blob values blobs = self._dump_blobs( - thread_id, - checkpoint_ns, + storage_safe_thread_id, + storage_safe_checkpoint_ns, copy.get("channel_values", {}), new_versions, ) diff --git a/langgraph/checkpoint/redis/base.py b/langgraph/checkpoint/redis/base.py index 0be078e..f00c5b3 100644 --- a/langgraph/checkpoint/redis/base.py +++ b/langgraph/checkpoint/redis/base.py @@ -457,9 +457,9 @@ def _make_redis_checkpoint_key( return REDIS_KEY_SEPARATOR.join( [ CHECKPOINT_PREFIX, - to_storage_safe_id(thread_id), + str(to_storage_safe_id(thread_id)), to_storage_safe_str(checkpoint_ns), - to_storage_safe_id(checkpoint_id), + str(to_storage_safe_id(checkpoint_id)), ] ) @@ -470,7 +470,7 @@ def _make_redis_checkpoint_blob_key( return REDIS_KEY_SEPARATOR.join( [ CHECKPOINT_BLOB_PREFIX, - to_storage_safe_str(thread_id), + str(to_storage_safe_id(thread_id)), to_storage_safe_str(checkpoint_ns), channel, version, @@ -485,9 +485,9 @@ def _make_redis_checkpoint_writes_key( task_id: str, idx: Optional[int], ) -> str: - storage_safe_thread_id = to_storage_safe_str(thread_id) + storage_safe_thread_id = str(to_storage_safe_id(thread_id)) storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns) - storage_safe_checkpoint_id = to_storage_safe_str(checkpoint_id) + storage_safe_checkpoint_id = str(to_storage_safe_id(checkpoint_id)) if idx is None: return REDIS_KEY_SEPARATOR.join( diff --git a/langgraph/checkpoint/redis/util.py b/langgraph/checkpoint/redis/util.py index d27fdeb..2489b4a 100644 --- a/langgraph/checkpoint/redis/util.py +++ b/langgraph/checkpoint/redis/util.py @@ -1,5 +1,3 @@ -from typing import Any, Callable, Optional, TypeVar, Union - """ RediSearch versions below 2.10 don't support indexing and querying empty strings, so we use a sentinel value to represent empty strings. @@ -8,14 +6,17 @@ sentinel values are from the first run of the graph, so this should generally be correct. """ + EMPTY_STRING_SENTINEL = "__empty__" EMPTY_ID_SENTINEL = "00000000-0000-0000-0000-000000000000" def to_storage_safe_str(value: str) -> str: """ - Convert any empty string to an empty string sentinel if found, - otherwise return the value unchanged. + Prepare a value for storage in Redis as a string. + + Convert an empty string to a sentinel value, otherwise return the + value as a string. Args: value (str): The value to convert. @@ -26,13 +27,13 @@ def to_storage_safe_str(value: str) -> str: if value == "": return EMPTY_STRING_SENTINEL else: - return value + return str(value) def from_storage_safe_str(value: str) -> str: """ - Convert a value from an empty string sentinel to an empty string - if found, otherwise return the value unchanged. + Convert a value from a sentinel value to an empty string if present, + otherwise return the value unchanged. Args: value (str): The value to convert. @@ -48,8 +49,10 @@ def from_storage_safe_str(value: str) -> str: def to_storage_safe_id(value: str) -> str: """ - Convert any empty ID string to an empty ID sentinel if found, - otherwise return the value unchanged. + Prepare a value for storage in Redis as an ID. + + Convert an empty string to a sentinel value for empty ID strings, otherwise + return the value as a string. Args: value (str): The value to convert. @@ -60,13 +63,13 @@ def to_storage_safe_id(value: str) -> str: if value == "": return EMPTY_ID_SENTINEL else: - return value + return str(value) def from_storage_safe_id(value: str) -> str: """ - Convert a value from an empty ID sentinel to an empty ID - if found, otherwise return the value unchanged. + Convert a value from a sentinel value for empty ID strings to an empty + ID string if present, otherwise return the value unchanged. Args: value (str): The value to convert. @@ -78,36 +81,3 @@ def from_storage_safe_id(value: str) -> str: return "" else: return value - - -def storage_safe_get( - doc: dict[str, Any], key: str, default: Any = None -) -> Optional[Any]: - """ - Get a value from a Redis document or dictionary, using a sentinel - value to represent empty strings. - - If the sentinel value is found, it is converted back to an empty string. - - Args: - doc (dict[str, Any]): The document to get the value from. - key (str): The key to get the value from. - default (Any): The default value to return if the key is not found. - Returns: - Optional[Any]: None if the key is not found, or else the value from - the document or dictionary, with empty strings converted - to the empty string sentinel and the sentinel converted - back to an empty string. - """ - try: - # NOTE: The Document class that comes back from `search()` support - # [key] access but not `get()` for some reason, so we use direct - # key access with an exception guard. - value = doc[key] - except KeyError: - value = None - - if value is None: - return default - - return to_storage_safe_str(value) From b7414979d86e37c3bde589a02de8634a7c8882ea Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 6 Mar 2025 14:54:36 -0800 Subject: [PATCH 2/2] Silence mypy error --- langgraph/checkpoint/redis/aio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py index 6ce7b48..90aacde 100644 --- a/langgraph/checkpoint/redis/aio.py +++ b/langgraph/checkpoint/redis/aio.py @@ -411,7 +411,7 @@ async def aput( # store at top-level for filters in list() if all(key in metadata for key in ["source", "step"]): checkpoint_data["source"] = metadata["source"] - checkpoint_data["step"] = metadata["step"] + checkpoint_data["step"] = metadata["step"] # type: ignore await self.checkpoints_index.load( [checkpoint_data],