22
33import json
44from contextlib import contextmanager
5- from typing import Any , Dict , Iterator , List , Optional , Tuple , cast
5+ import logging
6+ from typing import Any , Dict , Iterator , List , Optional , Tuple , Union , cast
67
78from langchain_core .runnables import RunnableConfig
89from langgraph .checkpoint .base import (
1415)
1516from langgraph .constants import TASKS
1617from redis import Redis
18+ from redis .cluster import RedisCluster
1719from redisvl .index import SearchIndex
1820from redisvl .query import FilterQuery
1921from redisvl .query .filter import Num , Tag
3234)
3335from langgraph .checkpoint .redis .version import __lib_name__ , __version__
3436
37+ logger = logging .getLogger (__name__ )
3538
36- class RedisSaver (BaseRedisSaver [Redis , SearchIndex ]):
39+
40+ class RedisSaver (BaseRedisSaver [Union [Redis , RedisCluster ], SearchIndex ]):
3741 """Standard Redis implementation for checkpoint saving."""
3842
43+ _redis : Union [Redis , RedisCluster ] # Support both standalone and cluster clients
44+ # Whether to assume the Redis server is a cluster; None triggers auto-detection
45+ cluster_mode : Optional [bool ] = None
46+
3947 def __init__ (
4048 self ,
4149 redis_url : Optional [str ] = None ,
4250 * ,
43- redis_client : Optional [Redis ] = None ,
51+ redis_client : Optional [Union [ Redis , RedisCluster ] ] = None ,
4452 connection_args : Optional [Dict [str , Any ]] = None ,
4553 ttl : Optional [Dict [str , Any ]] = None ,
4654 ) -> None :
@@ -54,7 +62,7 @@ def __init__(
5462 def configure_client (
5563 self ,
5664 redis_url : Optional [str ] = None ,
57- redis_client : Optional [Redis ] = None ,
65+ redis_client : Optional [Union [ Redis , RedisCluster ] ] = None ,
5866 connection_args : Optional [Dict [str , Any ]] = None ,
5967 ) -> None :
6068 """Configure the Redis client."""
@@ -74,6 +82,27 @@ def create_indexes(self) -> None:
7482 self .SCHEMAS [2 ], redis_client = self ._redis
7583 )
7684
85+ def setup (self ) -> None :
86+ """Initialize the indices in Redis and detect cluster mode."""
87+ self ._detect_cluster_mode ()
88+ super ().setup ()
89+
90+ def _detect_cluster_mode (self ) -> None :
91+ """Detect if the Redis client is a cluster client by inspecting its class."""
92+ if self .cluster_mode is not None :
93+ logger .info (
94+ f"Redis cluster_mode explicitly set to { self .cluster_mode } , skipping detection."
95+ )
96+ return
97+
98+ # Determine cluster mode based on client class
99+ if isinstance (self ._redis , RedisCluster ):
100+ logger .info ("Redis client is a cluster client" )
101+ self .cluster_mode = True
102+ else :
103+ logger .info ("Redis client is a standalone client" )
104+ self .cluster_mode = False
105+
77106 def list (
78107 self ,
79108 config : Optional [RunnableConfig ],
@@ -458,7 +487,7 @@ def from_conn_string(
458487 cls ,
459488 redis_url : Optional [str ] = None ,
460489 * ,
461- redis_client : Optional [Redis ] = None ,
490+ redis_client : Optional [Union [ Redis , RedisCluster ] ] = None ,
462491 connection_args : Optional [Dict [str , Any ]] = None ,
463492 ttl : Optional [Dict [str , Any ]] = None ,
464493 ) -> Iterator [RedisSaver ]:
@@ -592,8 +621,8 @@ def delete_thread(self, thread_id: str) -> None:
592621
593622 checkpoint_results = self .checkpoints_index .search (checkpoint_query )
594623
595- # Delete all checkpoint-related keys
596- pipeline = self . _redis . pipeline ()
624+ # Collect all keys to delete
625+ keys_to_delete = []
597626
598627 for doc in checkpoint_results .docs :
599628 checkpoint_ns = getattr (doc , "checkpoint_ns" , "" )
@@ -603,7 +632,7 @@ def delete_thread(self, thread_id: str) -> None:
603632 checkpoint_key = BaseRedisSaver ._make_redis_checkpoint_key (
604633 storage_safe_thread_id , checkpoint_ns , checkpoint_id
605634 )
606- pipeline . delete (checkpoint_key )
635+ keys_to_delete . append (checkpoint_key )
607636
608637 # Delete all blobs for this thread
609638 blob_query = FilterQuery (
@@ -622,7 +651,7 @@ def delete_thread(self, thread_id: str) -> None:
622651 blob_key = BaseRedisSaver ._make_redis_checkpoint_blob_key (
623652 storage_safe_thread_id , checkpoint_ns , channel , version
624653 )
625- pipeline . delete (blob_key )
654+ keys_to_delete . append (blob_key )
626655
627656 # Delete all writes for this thread
628657 writes_query = FilterQuery (
@@ -642,10 +671,19 @@ def delete_thread(self, thread_id: str) -> None:
642671 write_key = BaseRedisSaver ._make_redis_checkpoint_writes_key (
643672 storage_safe_thread_id , checkpoint_ns , checkpoint_id , task_id , idx
644673 )
645- pipeline . delete (write_key )
674+ keys_to_delete . append (write_key )
646675
647- # Execute all deletions
648- pipeline .execute ()
676+ # Execute all deletions based on cluster mode
677+ if self .cluster_mode :
678+ # For cluster mode, delete keys individually
679+ for key in keys_to_delete :
680+ self ._redis .delete (key )
681+ else :
682+ # For non-cluster mode, use pipeline for efficiency
683+ pipeline = self ._redis .pipeline ()
684+ for key in keys_to_delete :
685+ pipeline .delete (key )
686+ pipeline .execute ()
649687
650688
651689__all__ = [
0 commit comments