33import json
44import random
55from abc import abstractmethod
6- from collections .abc import Sequence
7- from typing import Any , Generic , List , Optional , cast
6+ from typing import Any , Dict , Generic , List , Optional , Sequence , Tuple , cast
87
98from langchain_core .runnables import RunnableConfig
109from langgraph .checkpoint .base import (
@@ -100,12 +99,16 @@ def __init__(
10099 redis_url : Optional [str ] = None ,
101100 * ,
102101 redis_client : Optional [RedisClientType ] = None ,
103- connection_args : Optional [dict [str , Any ]] = None ,
102+ connection_args : Optional [Dict [str , Any ]] = None ,
103+ ttl : Optional [Dict [str , Any ]] = None ,
104104 ) -> None :
105105 super ().__init__ (serde = JsonPlusRedisSerializer ())
106106 if redis_url is None and redis_client is None :
107107 raise ValueError ("Either redis_url or redis_client must be provided" )
108108
109+ # Store TTL configuration
110+ self .ttl_config = ttl
111+
109112 self .configure_client (
110113 redis_url = redis_url ,
111114 redis_client = redis_client ,
@@ -128,7 +131,7 @@ def configure_client(
128131 self ,
129132 redis_url : Optional [str ] = None ,
130133 redis_client : Optional [RedisClientType ] = None ,
131- connection_args : Optional [dict [str , Any ]] = None ,
134+ connection_args : Optional [Dict [str , Any ]] = None ,
132135 ) -> None :
133136 """Configure the Redis client."""
134137 pass
@@ -180,11 +183,46 @@ def setup(self) -> None:
180183 self .checkpoint_blobs_index .create (overwrite = False )
181184 self .checkpoint_writes_index .create (overwrite = False )
182185
186+ def _apply_ttl_to_keys (
187+ self ,
188+ main_key : str ,
189+ related_keys : Optional [List [str ]] = None ,
190+ ttl_minutes : Optional [float ] = None ,
191+ ) -> Any :
192+ """Apply Redis native TTL to keys.
193+
194+ Args:
195+ main_key: The primary Redis key
196+ related_keys: Additional Redis keys that should expire at the same time
197+ ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
198+
199+ Returns:
200+ Result of the Redis operation
201+ """
202+ if ttl_minutes is None :
203+ # Check if there's a default TTL in config
204+ if self .ttl_config and "default_ttl" in self .ttl_config :
205+ ttl_minutes = self .ttl_config .get ("default_ttl" )
206+
207+ if ttl_minutes is not None :
208+ ttl_seconds = int (ttl_minutes * 60 )
209+ pipeline = self ._redis .pipeline ()
210+
211+ # Set TTL for main key
212+ pipeline .expire (main_key , ttl_seconds )
213+
214+ # Set TTL for related keys
215+ if related_keys :
216+ for key in related_keys :
217+ pipeline .expire (key , ttl_seconds )
218+
219+ return pipeline .execute ()
220+
183221 def _load_checkpoint (
184222 self ,
185- checkpoint : dict [str , Any ],
186- channel_values : dict [str , Any ],
187- pending_sends : list [Any ],
223+ checkpoint : Dict [str , Any ],
224+ channel_values : Dict [str , Any ],
225+ pending_sends : List [Any ],
188226 ) -> Checkpoint :
189227 if not checkpoint :
190228 return {}
@@ -218,7 +256,7 @@ def _load_blobs(self, blob_values: dict[str, Any]) -> dict[str, Any]:
218256 if v ["type" ] != "empty"
219257 }
220258
221- def _get_type_and_blob (self , value : Any ) -> tuple [str , Optional [bytes ]]:
259+ def _get_type_and_blob (self , value : Any ) -> Tuple [str , Optional [bytes ]]:
222260 """Helper to get type and blob from a value."""
223261 t , b = self .serde .dumps_typed (value )
224262 return t , b
@@ -227,9 +265,9 @@ def _dump_blobs(
227265 self ,
228266 thread_id : str ,
229267 checkpoint_ns : str ,
230- values : dict [str , Any ],
268+ values : Dict [str , Any ],
231269 versions : ChannelVersions ,
232- ) -> list [ tuple [str , dict [str , Any ]]]:
270+ ) -> List [ Tuple [str , Dict [str , Any ]]]:
233271 """Convert blob data for Redis storage."""
234272 if not versions :
235273 return []
@@ -337,7 +375,7 @@ def _decode_blob(self, blob: str) -> bytes:
337375 # Handle both malformed base64 data and incorrect input types
338376 return blob .encode () if isinstance (blob , str ) else blob
339377
340- def _load_writes_from_redis (self , write_key : str ) -> list [ tuple [str , str , Any ]]:
378+ def _load_writes_from_redis (self , write_key : str ) -> List [ Tuple [str , str , Any ]]:
341379 """Load writes from Redis JSON storage by key."""
342380 if not write_key :
343381 return []
0 commit comments