diff --git a/langgraph/checkpoint/redis/__init__.py b/langgraph/checkpoint/redis/__init__.py index 7b9c464..b80caaa 100644 --- a/langgraph/checkpoint/redis/__init__.py +++ b/langgraph/checkpoint/redis/__init__.py @@ -1,9 +1,8 @@ from __future__ import annotations import json -from collections.abc import Iterator from contextlib import contextmanager -from typing import Any, List, Optional, Tuple, cast +from typing import Any, Dict, Iterator, List, Optional, Tuple, cast from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( @@ -42,19 +41,21 @@ def __init__( redis_url: Optional[str] = None, *, redis_client: Optional[Redis] = None, - connection_args: Optional[dict[str, Any]] = None, + connection_args: Optional[Dict[str, Any]] = None, + ttl: Optional[Dict[str, Any]] = None, ) -> None: super().__init__( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, + ttl=ttl, ) def configure_client( self, redis_url: Optional[str] = None, redis_client: Optional[Redis] = None, - connection_args: Optional[dict[str, Any]] = None, + connection_args: Optional[Dict[str, Any]] = None, ) -> None: """Configure the Redis client.""" self._owns_its_client = redis_client is None @@ -395,7 +396,8 @@ def from_conn_string( redis_url: Optional[str] = None, *, redis_client: Optional[Redis] = None, - connection_args: Optional[dict[str, Any]] = None, + connection_args: Optional[Dict[str, Any]] = None, + ttl: Optional[Dict[str, Any]] = None, ) -> Iterator[RedisSaver]: """Create a new RedisSaver instance.""" saver: Optional[RedisSaver] = None @@ -404,6 +406,7 @@ def from_conn_string( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, + ttl=ttl, ) yield saver @@ -414,7 +417,7 @@ def from_conn_string( def get_channel_values( self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = "" - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Retrieve channel_values dictionary with properly constructed message objects.""" storage_safe_thread_id = to_storage_safe_id(thread_id) storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns) diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py index 23228b6..324bc25 100644 --- a/langgraph/checkpoint/redis/aio.py +++ b/langgraph/checkpoint/redis/aio.py @@ -5,11 +5,10 @@ import asyncio import json import os -from collections.abc import AsyncIterator from contextlib import asynccontextmanager from functools import partial from types import TracebackType -from typing import Any, List, Optional, Sequence, Tuple, Type, cast +from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( @@ -42,7 +41,7 @@ async def _write_obj_tx( pipe: Pipeline, key: str, - write_obj: dict[str, Any], + write_obj: Dict[str, Any], upsert_case: bool, ) -> None: exists: int = await pipe.exists(key) @@ -73,12 +72,14 @@ def __init__( redis_url: Optional[str] = None, *, redis_client: Optional[AsyncRedis] = None, - connection_args: Optional[dict[str, Any]] = None, + connection_args: Optional[Dict[str, Any]] = None, + ttl: Optional[Dict[str, Any]] = None, ) -> None: super().__init__( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, + ttl=ttl, ) self.loop = asyncio.get_running_loop() @@ -86,7 +87,7 @@ def configure_client( self, redis_url: Optional[str] = None, redis_client: Optional[AsyncRedis] = None, - connection_args: Optional[dict[str, Any]] = None, + connection_args: Optional[Dict[str, Any]] = None, ) -> None: """Configure the Redis client.""" self._owns_its_client = redis_client is None @@ -706,18 +707,20 @@ async def from_conn_string( redis_url: Optional[str] = None, *, redis_client: Optional[AsyncRedis] = None, - connection_args: Optional[dict[str, Any]] = None, + connection_args: Optional[Dict[str, Any]] = None, + ttl: Optional[Dict[str, Any]] = None, ) -> AsyncIterator[AsyncRedisSaver]: async with cls( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, + ttl=ttl, ) as saver: yield saver async def aget_channel_values( self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = "" - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Retrieve channel_values dictionary with properly constructed message objects.""" storage_safe_thread_id = to_storage_safe_id(thread_id) storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns) @@ -767,7 +770,7 @@ async def aget_channel_values( async def _aload_pending_sends( self, thread_id: str, checkpoint_ns: str = "", parent_checkpoint_id: str = "" - ) -> list[tuple[str, bytes]]: + ) -> List[Tuple[str, bytes]]: """Load pending sends for a parent checkpoint. Args: diff --git a/langgraph/checkpoint/redis/base.py b/langgraph/checkpoint/redis/base.py index d6c7ff7..ce268cc 100644 --- a/langgraph/checkpoint/redis/base.py +++ b/langgraph/checkpoint/redis/base.py @@ -3,8 +3,7 @@ import json import random from abc import abstractmethod -from collections.abc import Sequence -from typing import Any, Generic, List, Optional, cast +from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, cast from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( @@ -100,12 +99,16 @@ def __init__( redis_url: Optional[str] = None, *, redis_client: Optional[RedisClientType] = None, - connection_args: Optional[dict[str, Any]] = None, + connection_args: Optional[Dict[str, Any]] = None, + ttl: Optional[Dict[str, Any]] = None, ) -> None: super().__init__(serde=JsonPlusRedisSerializer()) if redis_url is None and redis_client is None: raise ValueError("Either redis_url or redis_client must be provided") + # Store TTL configuration + self.ttl_config = ttl + self.configure_client( redis_url=redis_url, redis_client=redis_client, @@ -128,7 +131,7 @@ def configure_client( self, redis_url: Optional[str] = None, redis_client: Optional[RedisClientType] = None, - connection_args: Optional[dict[str, Any]] = None, + connection_args: Optional[Dict[str, Any]] = None, ) -> None: """Configure the Redis client.""" pass @@ -180,11 +183,46 @@ def setup(self) -> None: self.checkpoint_blobs_index.create(overwrite=False) self.checkpoint_writes_index.create(overwrite=False) + def _apply_ttl_to_keys( + self, + main_key: str, + related_keys: Optional[List[str]] = None, + ttl_minutes: Optional[float] = None, + ) -> Any: + """Apply Redis native TTL to keys. + + Args: + main_key: The primary Redis key + related_keys: Additional Redis keys that should expire at the same time + ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided + + Returns: + Result of the Redis operation + """ + if ttl_minutes is None: + # Check if there's a default TTL in config + if self.ttl_config and "default_ttl" in self.ttl_config: + ttl_minutes = self.ttl_config.get("default_ttl") + + if ttl_minutes is not None: + ttl_seconds = int(ttl_minutes * 60) + pipeline = self._redis.pipeline() + + # Set TTL for main key + pipeline.expire(main_key, ttl_seconds) + + # Set TTL for related keys + if related_keys: + for key in related_keys: + pipeline.expire(key, ttl_seconds) + + return pipeline.execute() + def _load_checkpoint( self, - checkpoint: dict[str, Any], - channel_values: dict[str, Any], - pending_sends: list[Any], + checkpoint: Dict[str, Any], + channel_values: Dict[str, Any], + pending_sends: List[Any], ) -> Checkpoint: if not checkpoint: return {} @@ -218,7 +256,7 @@ def _load_blobs(self, blob_values: dict[str, Any]) -> dict[str, Any]: if v["type"] != "empty" } - def _get_type_and_blob(self, value: Any) -> tuple[str, Optional[bytes]]: + def _get_type_and_blob(self, value: Any) -> Tuple[str, Optional[bytes]]: """Helper to get type and blob from a value.""" t, b = self.serde.dumps_typed(value) return t, b @@ -227,9 +265,9 @@ def _dump_blobs( self, thread_id: str, checkpoint_ns: str, - values: dict[str, Any], + values: Dict[str, Any], versions: ChannelVersions, - ) -> list[tuple[str, dict[str, Any]]]: + ) -> List[Tuple[str, Dict[str, Any]]]: """Convert blob data for Redis storage.""" if not versions: return [] @@ -337,7 +375,7 @@ def _decode_blob(self, blob: str) -> bytes: # Handle both malformed base64 data and incorrect input types return blob.encode() if isinstance(blob, str) else blob - def _load_writes_from_redis(self, write_key: str) -> list[tuple[str, str, Any]]: + def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]: """Load writes from Redis JSON storage by key.""" if not write_key: return []