@@ -190,6 +190,53 @@ async def _detect_cluster_mode(self) -> None:
190190 logger .info ("Redis client is a standalone client" )
191191 self .cluster_mode = False
192192
193+ async def _apply_ttl_to_keys (
194+ self ,
195+ main_key : str ,
196+ related_keys : Optional [list [str ]] = None ,
197+ ttl_minutes : Optional [float ] = None ,
198+ ) -> Any :
199+ """Apply Redis native TTL to keys asynchronously.
200+
201+ Args:
202+ main_key: The primary Redis key
203+ related_keys: Additional Redis keys that should expire at the same time
204+ ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
205+
206+ Returns:
207+ Result of the Redis operation
208+ """
209+ if ttl_minutes is None :
210+ # Check if there's a default TTL in config
211+ if self .ttl_config and "default_ttl" in self .ttl_config :
212+ ttl_minutes = self .ttl_config .get ("default_ttl" )
213+
214+ if ttl_minutes is not None :
215+ ttl_seconds = int (ttl_minutes * 60 )
216+
217+ if self .cluster_mode :
218+ # For cluster mode, execute TTL operations individually
219+ await self ._redis .expire (main_key , ttl_seconds )
220+
221+ if related_keys :
222+ for key in related_keys :
223+ await self ._redis .expire (key , ttl_seconds )
224+
225+ return True
226+ else :
227+ # For non-cluster mode, use pipeline for efficiency
228+ pipeline = self ._redis .pipeline ()
229+
230+ # Set TTL for main key
231+ pipeline .expire (main_key , ttl_seconds )
232+
233+ # Set TTL for related keys
234+ if related_keys :
235+ for key in related_keys :
236+ pipeline .expire (key , ttl_seconds )
237+
238+ return await pipeline .execute ()
239+
193240 async def aget_tuple (self , config : RunnableConfig ) -> Optional [CheckpointTuple ]:
194241 """Get a checkpoint tuple from Redis asynchronously."""
195242 thread_id = config ["configurable" ]["thread_id" ]
@@ -263,29 +310,10 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
263310 write_keys = [safely_decode (key ) for key in write_keys ]
264311
265312 # Apply TTL to checkpoint, blob keys, and write keys
266- ttl_minutes = self .ttl_config .get ("default_ttl" )
267- if ttl_minutes is not None :
268- ttl_seconds = int (ttl_minutes * 60 )
269-
270- if self .cluster_mode :
271- # For cluster mode, execute TTL operations individually
272- await self ._redis .expire (checkpoint_key , ttl_seconds )
273-
274- # Combine blob keys and write keys for TTL refresh
275- all_related_keys = blob_keys + write_keys
276- for key in all_related_keys :
277- await self ._redis .expire (key , ttl_seconds )
278- else :
279- # For non-cluster mode, use pipeline for TTL operations
280- pipeline = self ._redis .pipeline ()
281- pipeline .expire (checkpoint_key , ttl_seconds )
282-
283- # Combine blob keys and write keys for TTL refresh
284- all_related_keys = blob_keys + write_keys
285- for key in all_related_keys :
286- pipeline .expire (key , ttl_seconds )
287-
288- await pipeline .execute ()
313+ all_related_keys = blob_keys + write_keys
314+ await self ._apply_ttl_to_keys (
315+ checkpoint_key , all_related_keys if all_related_keys else None
316+ )
289317
290318 # Fetch channel_values
291319 channel_values = await self .aget_channel_values (
@@ -801,14 +829,10 @@ async def aput_writes(
801829 and self .ttl_config
802830 and "default_ttl" in self .ttl_config
803831 ):
804- ttl_minutes = self .ttl_config .get ("default_ttl" )
805- ttl_seconds = int (ttl_minutes * 60 )
806-
807- # Use a new pipeline for TTL operations
808- ttl_pipeline = self ._redis .pipeline ()
809- for key in created_keys :
810- ttl_pipeline .expire (key , ttl_seconds )
811- await ttl_pipeline .execute ()
832+ await self ._apply_ttl_to_keys (
833+ created_keys [0 ],
834+ created_keys [1 :] if len (created_keys ) > 1 else None ,
835+ )
812836
813837 except asyncio .CancelledError :
814838 # Handle cancellation/interruption
0 commit comments