diff --git a/README.md b/README.md index c2adbdc..ceb93dd 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,25 @@ with ShallowRedisSaver.from_conn_string("redis://localhost:6379") as checkpointe # ... rest of the implementation follows similar pattern ``` +## Redis Checkpoint TTL Support + +Both Redis checkpoint savers and stores support Time-To-Live (TTL) functionality for automatic key expiration: + +```python +# Configure TTL for checkpoint savers +ttl_config = { + "default_ttl": 60, # Default TTL in minutes + "refresh_on_read": True, # Refresh TTL when checkpoint is read +} + +# Use with any checkpoint saver implementation +with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as checkpointer: + checkpointer.setup() + # Use the checkpointer... +``` + +This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up. + ## Redis Stores Redis Stores provide a persistent key-value store with optional vector search capabilities. @@ -169,9 +188,19 @@ index_config = { "fields": ["text"], # Fields to index } -with RedisStore.from_conn_string("redis://localhost:6379", index=index_config) as store: +# With TTL configuration +ttl_config = { + "default_ttl": 60, # Default TTL in minutes + "refresh_on_read": True, # Refresh TTL when store entries are read +} + +with RedisStore.from_conn_string( + "redis://localhost:6379", + index=index_config, + ttl=ttl_config +) as store: store.setup() - # Use the store with vector search capabilities... + # Use the store with vector search and TTL capabilities... ``` ### Async Implementation @@ -180,7 +209,16 @@ with RedisStore.from_conn_string("redis://localhost:6379", index=index_config) a from langgraph.store.redis.aio import AsyncRedisStore async def main(): - async with AsyncRedisStore.from_conn_string("redis://localhost:6379") as store: + # TTL also works with async implementations + ttl_config = { + "default_ttl": 60, # Default TTL in minutes + "refresh_on_read": True, # Refresh TTL when store entries are read + } + + async with AsyncRedisStore.from_conn_string( + "redis://localhost:6379", + ttl=ttl_config + ) as store: await store.setup() # Use the store asynchronously... @@ -235,6 +273,16 @@ For Redis Stores with vector search: 1. **Store Index**: Main key-value store 2. **Vector Index**: Optional vector embeddings for similarity search +### TTL Implementation + +Both Redis checkpoint savers and stores leverage Redis's native key expiration: + +- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command +- **Automatic Cleanup**: Redis automatically removes expired keys +- **Configurable Default TTL**: Set a default TTL for all keys in minutes +- **TTL Refresh on Read**: Optionally refresh TTL when keys are accessed +- **Applied to All Related Keys**: TTL is applied to all related keys (checkpoint, blobs, writes) + ## Contributing We welcome contributions! Here's how you can help: diff --git a/langgraph/checkpoint/redis/__init__.py b/langgraph/checkpoint/redis/__init__.py index b80caaa..1bdbaf0 100644 --- a/langgraph/checkpoint/redis/__init__.py +++ b/langgraph/checkpoint/redis/__init__.py @@ -253,15 +253,16 @@ def put( checkpoint_data["source"] = metadata["source"] checkpoint_data["step"] = metadata["step"] # type: ignore + # Create the checkpoint key + checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( + storage_safe_thread_id, + storage_safe_checkpoint_ns, + storage_safe_checkpoint_id, + ) + self.checkpoints_index.load( [checkpoint_data], - keys=[ - BaseRedisSaver._make_redis_checkpoint_key( - storage_safe_thread_id, - storage_safe_checkpoint_ns, - storage_safe_checkpoint_id, - ) - ], + keys=[checkpoint_key], ) # Store blob values. @@ -272,10 +273,16 @@ def put( new_versions, ) + blob_keys = [] if blobs: # Unzip the list of tuples into separate lists for keys and data keys, data = zip(*blobs) - self.checkpoint_blobs_index.load(list(data), keys=list(keys)) + blob_keys = list(keys) + self.checkpoint_blobs_index.load(list(data), keys=blob_keys) + + # Apply TTL to checkpoint and blob keys if configured + if self.ttl_config and "default_ttl" in self.ttl_config: + self._apply_ttl_to_keys(checkpoint_key, blob_keys) return next_config @@ -332,6 +339,33 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"]) doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"]) + # If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys + if self.ttl_config and self.ttl_config.get("refresh_on_read"): + # Get the checkpoint key + checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( + to_storage_safe_id(doc_thread_id), + to_storage_safe_str(doc_checkpoint_ns), + to_storage_safe_id(doc_checkpoint_id), + ) + + # Get all blob keys related to this checkpoint + from langgraph.checkpoint.redis.base import ( + CHECKPOINT_BLOB_PREFIX, + CHECKPOINT_WRITE_PREFIX, + ) + + # Get the blob keys + blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*" + blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)] + + # Also get checkpoint write keys that should have the same TTL + write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*" + write_keys = [key.decode() for key in self._redis.keys(write_key_pattern)] + + # Apply TTL to checkpoint, blob keys, and write keys + all_related_keys = blob_keys + write_keys + self._apply_ttl_to_keys(checkpoint_key, all_related_keys) + # Fetch channel_values channel_values = self.get_channel_values( thread_id=doc_thread_id, diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py index 324bc25..d512848 100644 --- a/langgraph/checkpoint/redis/aio.py +++ b/langgraph/checkpoint/redis/aio.py @@ -192,6 +192,45 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"]) doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"]) + # If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys + if self.ttl_config and self.ttl_config.get("refresh_on_read"): + # Get the checkpoint key + checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( + to_storage_safe_id(doc_thread_id), + to_storage_safe_str(doc_checkpoint_ns), + to_storage_safe_id(doc_checkpoint_id), + ) + + # Get all blob keys related to this checkpoint + from langgraph.checkpoint.redis.base import ( + CHECKPOINT_BLOB_PREFIX, + CHECKPOINT_WRITE_PREFIX, + ) + + # Get the blob keys + blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*" + blob_keys = await self._redis.keys(blob_key_pattern) + blob_keys = [key.decode() for key in blob_keys] + + # Also get checkpoint write keys that should have the same TTL + write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*" + write_keys = await self._redis.keys(write_key_pattern) + write_keys = [key.decode() for key in write_keys] + + # Apply TTL to checkpoint, blob keys, and write keys + ttl_minutes = self.ttl_config.get("default_ttl") + if ttl_minutes is not None: + ttl_seconds = int(ttl_minutes * 60) + pipeline = self._redis.pipeline() + pipeline.expire(checkpoint_key, ttl_seconds) + + # Combine blob keys and write keys for TTL refresh + all_related_keys = blob_keys + write_keys + for key in all_related_keys: + pipeline.expire(key, ttl_seconds) + + await pipeline.execute() + # Fetch channel_values channel_values = await self.aget_channel_values( thread_id=doc_thread_id, @@ -476,6 +515,22 @@ async def aput( # Execute all operations atomically await pipeline.execute() + # Apply TTL to checkpoint and blob keys if configured + if self.ttl_config and "default_ttl" in self.ttl_config: + all_keys = ( + [checkpoint_key] + [key for key, _ in blobs] + if blobs + else [checkpoint_key] + ) + ttl_minutes = self.ttl_config.get("default_ttl") + ttl_seconds = int(ttl_minutes * 60) + + # Use a new pipeline for TTL operations + ttl_pipeline = self._redis.pipeline() + for key in all_keys: + ttl_pipeline.expire(key, ttl_seconds) + await ttl_pipeline.execute() + return next_config except asyncio.CancelledError: @@ -575,6 +630,7 @@ async def aput_writes( # Determine if this is an upsert case upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes) + created_keys = [] # Add all write operations to the pipeline for write_obj in writes_objects: @@ -599,15 +655,28 @@ async def aput_writes( else: # Create new key await pipeline.json().set(key, "$", write_obj) + created_keys.append(key) else: # For non-upsert case, only set if key doesn't exist exists = await self._redis.exists(key) if not exists: await pipeline.json().set(key, "$", write_obj) + created_keys.append(key) # Execute all operations atomically await pipeline.execute() + # Apply TTL to newly created keys + if created_keys and self.ttl_config and "default_ttl" in self.ttl_config: + ttl_minutes = self.ttl_config.get("default_ttl") + ttl_seconds = int(ttl_minutes * 60) + + # Use a new pipeline for TTL operations + ttl_pipeline = self._redis.pipeline() + for key in created_keys: + ttl_pipeline.expire(key, ttl_seconds) + await ttl_pipeline.execute() + except asyncio.CancelledError: # Handle cancellation/interruption # Pipeline will be automatically discarded diff --git a/langgraph/checkpoint/redis/ashallow.py b/langgraph/checkpoint/redis/ashallow.py index c920da9..90a4560 100644 --- a/langgraph/checkpoint/redis/ashallow.py +++ b/langgraph/checkpoint/redis/ashallow.py @@ -108,11 +108,13 @@ def __init__( *, redis_client: Optional[AsyncRedis] = None, connection_args: Optional[dict[str, Any]] = None, + ttl: Optional[dict[str, Any]] = None, ) -> None: super().__init__( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, + ttl=ttl, ) self.loop = asyncio.get_running_loop() @@ -149,12 +151,14 @@ async def from_conn_string( *, redis_client: Optional[AsyncRedis] = None, connection_args: Optional[dict[str, Any]] = None, + ttl: Optional[dict[str, Any]] = None, ) -> AsyncIterator[AsyncShallowRedisSaver]: """Create a new AsyncShallowRedisSaver instance.""" async with cls( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, + ttl=ttl, ) as saver: yield saver @@ -279,6 +283,22 @@ async def aput( # Execute all operations atomically await pipeline.execute() + # Apply TTL to checkpoint and blob keys if configured + if self.ttl_config and "default_ttl" in self.ttl_config: + # Prepare the list of keys to apply TTL + ttl_keys = [checkpoint_key] + if blobs: + ttl_keys.extend([key for key, _ in blobs]) + + # Apply TTL to all keys + ttl_minutes = self.ttl_config.get("default_ttl") + ttl_seconds = int(ttl_minutes * 60) + + ttl_pipeline = self._redis.pipeline() + for key in ttl_keys: + ttl_pipeline.expire(key, ttl_seconds) + await ttl_pipeline.execute() + return next_config except asyncio.CancelledError: @@ -389,6 +409,35 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: doc = results.docs[0] + # If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys + if self.ttl_config and self.ttl_config.get("refresh_on_read"): + thread_id = getattr(doc, "thread_id", "") + checkpoint_ns = getattr(doc, "checkpoint_ns", "") + + # Get the checkpoint key + checkpoint_key = AsyncShallowRedisSaver._make_shallow_redis_checkpoint_key( + thread_id, checkpoint_ns + ) + + # Get all blob keys related to this checkpoint + blob_key_pattern = ( + AsyncShallowRedisSaver._make_shallow_redis_checkpoint_blob_key_pattern( + thread_id, checkpoint_ns + ) + ) + blob_keys = await self._redis.keys(blob_key_pattern) + blob_keys = [key.decode() for key in blob_keys] + + # Apply TTL + ttl_minutes = self.ttl_config.get("default_ttl") + if ttl_minutes is not None: + ttl_seconds = int(ttl_minutes * 60) + pipeline = self._redis.pipeline() + pipeline.expire(checkpoint_key, ttl_seconds) + for key in blob_keys: + pipeline.expire(key, ttl_seconds) + await pipeline.execute() + checkpoint = json.loads(doc["$.checkpoint"]) # Fetch channel_values diff --git a/langgraph/checkpoint/redis/base.py b/langgraph/checkpoint/redis/base.py index ce268cc..cd88e24 100644 --- a/langgraph/checkpoint/redis/base.py +++ b/langgraph/checkpoint/redis/base.py @@ -102,6 +102,16 @@ def __init__( connection_args: Optional[Dict[str, Any]] = None, ttl: Optional[Dict[str, Any]] = None, ) -> None: + """Initialize Redis-backed checkpoint saver. + + Args: + redis_url: Redis connection URL + redis_client: Redis client instance to use (alternative to redis_url) + connection_args: Additional arguments for Redis connection + ttl: Optional TTL configuration dict with optional keys: + - default_ttl: TTL in minutes for all checkpoint keys + - refresh_on_read: Whether to refresh TTL on reads + """ super().__init__(serde=JsonPlusRedisSerializer()) if redis_url is None and redis_client is None: raise ValueError("Either redis_url or redis_client must be provided") @@ -183,10 +193,32 @@ def setup(self) -> None: self.checkpoint_blobs_index.create(overwrite=False) self.checkpoint_writes_index.create(overwrite=False) + def _load_checkpoint( + self, + checkpoint: Dict[str, Any], + channel_values: Dict[str, Any], + pending_sends: List[Any], + ) -> Checkpoint: + if not checkpoint: + return {} + + loaded = json.loads(checkpoint) # type: ignore[arg-type] + + # Note: TTL refresh is now handled in get_tuple() to ensure it works + # with all Redis operations, not just internal deserialization + + return { + **loaded, + "pending_sends": [ + self.serde.loads_typed((c.decode(), b)) for c, b in pending_sends or [] + ], + "channel_values": channel_values, + } + def _apply_ttl_to_keys( self, main_key: str, - related_keys: Optional[List[str]] = None, + related_keys: Optional[list[str]] = None, ttl_minutes: Optional[float] = None, ) -> Any: """Apply Redis native TTL to keys. @@ -218,25 +250,6 @@ def _apply_ttl_to_keys( return pipeline.execute() - def _load_checkpoint( - self, - checkpoint: Dict[str, Any], - channel_values: Dict[str, Any], - pending_sends: List[Any], - ) -> Checkpoint: - if not checkpoint: - return {} - - loaded = json.loads(checkpoint) # type: ignore[arg-type] - - return { - **loaded, - "pending_sends": [ - self.serde.loads_typed((c.decode(), b)) for c, b in pending_sends or [] - ], - "channel_values": channel_values, - } - def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]: """Convert checkpoint to Redis format.""" type_, data = self.serde.dumps_typed(checkpoint) @@ -436,6 +449,9 @@ def put_writes( # For each write, check existence and then perform appropriate operation with self._redis.json().pipeline(transaction=False) as pipeline: + # Keep track of keys we're creating + created_keys = [] + for write_obj in writes_objects: key = self._make_redis_checkpoint_writes_key( thread_id, @@ -458,13 +474,21 @@ def put_writes( else: # For new records, set the complete object pipeline.set(key, "$", write_obj) # type: ignore[arg-type] + created_keys.append(key) else: # INSERT case - only insert if doesn't exist if not key_exists: pipeline.set(key, "$", write_obj) # type: ignore[arg-type] + created_keys.append(key) pipeline.execute() + # Apply TTL to newly created keys + if created_keys and self.ttl_config and "default_ttl" in self.ttl_config: + self._apply_ttl_to_keys( + created_keys[0], created_keys[1:] if len(created_keys) > 1 else None + ) + def _load_pending_writes( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str ) -> List[PendingWrite]: diff --git a/langgraph/checkpoint/redis/shallow.py b/langgraph/checkpoint/redis/shallow.py index 650f34c..ad8bcd4 100644 --- a/langgraph/checkpoint/redis/shallow.py +++ b/langgraph/checkpoint/redis/shallow.py @@ -82,11 +82,13 @@ def __init__( *, redis_client: Optional[Redis] = None, connection_args: Optional[dict[str, Any]] = None, + ttl: Optional[dict[str, Any]] = None, ) -> None: super().__init__( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, + ttl=ttl, ) @classmethod @@ -97,6 +99,7 @@ def from_conn_string( *, redis_client: Optional[Redis] = None, connection_args: Optional[dict[str, Any]] = None, + ttl: Optional[dict[str, Any]] = None, ) -> Iterator[ShallowRedisSaver]: """Create a new ShallowRedisSaver instance.""" saver: Optional[ShallowRedisSaver] = None @@ -105,6 +108,7 @@ def from_conn_string( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, + ttl=ttl, ) yield saver finally: @@ -203,10 +207,19 @@ def put( new_versions, ) + blob_keys = [] if blobs: # Unzip the list of tuples into separate lists for keys and data keys, data = zip(*blobs) - self.checkpoint_blobs_index.load(list(data), keys=list(keys)) + blob_keys = list(keys) + self.checkpoint_blobs_index.load(list(data), keys=blob_keys) + + # Apply TTL to checkpoint and blob keys if configured + checkpoint_key = ShallowRedisSaver._make_shallow_redis_checkpoint_key( + thread_id, checkpoint_ns + ) + if self.ttl_config and "default_ttl" in self.ttl_config: + self._apply_ttl_to_keys(checkpoint_key, blob_keys) return next_config @@ -358,6 +371,27 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: doc = results.docs[0] + # If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys + if self.ttl_config and self.ttl_config.get("refresh_on_read"): + thread_id = getattr(doc, "thread_id", "") + checkpoint_ns = getattr(doc, "checkpoint_ns", "") + + # Get the checkpoint key + checkpoint_key = ShallowRedisSaver._make_shallow_redis_checkpoint_key( + thread_id, checkpoint_ns + ) + + # Get all blob keys related to this checkpoint + blob_key_pattern = ( + ShallowRedisSaver._make_shallow_redis_checkpoint_blob_key_pattern( + thread_id, checkpoint_ns + ) + ) + blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)] + + # Apply TTL + self._apply_ttl_to_keys(checkpoint_key, blob_keys) + checkpoint = json.loads(doc["$.checkpoint"]) # Fetch channel_values diff --git a/tests/test_checkpoint_ttl.py b/tests/test_checkpoint_ttl.py new file mode 100644 index 0000000..4b1fb38 --- /dev/null +++ b/tests/test_checkpoint_ttl.py @@ -0,0 +1,378 @@ +"""Tests for TTL functionality with RedisSaver.""" + +from __future__ import annotations + +import os +import time +from typing import Any, Dict, Generator, Iterator, Optional, TypedDict, cast + +import pytest +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, CheckpointTuple +from langgraph.graph import END, StateGraph +from redis import Redis + +from langgraph.checkpoint.redis import RedisSaver +from langgraph.checkpoint.redis.util import to_storage_safe_id + + +class State(TypedDict): + """Simple state with count.""" + + count: int + + +@pytest.fixture(scope="function") +def redis_url(redis_container) -> str: + """Get the Redis URL from the container.""" + host, port = redis_container.get_service_host_and_port("redis", 6379) + return f"redis://{host}:{port}" + + +@pytest.fixture(scope="function") +def redis_client(redis_url: str) -> Generator[Redis, None, None]: + """Create a Redis client for testing.""" + client = Redis.from_url(redis_url) + try: + yield client + finally: + # Clean up any test keys + keys = client.keys("checkpoint:test_ttl*") + if keys: + client.delete(*keys) + client.close() + + +@pytest.fixture(scope="function") +def ttl_checkpoint_saver(redis_client: Redis) -> Generator[RedisSaver, None, None]: + """Create a RedisSaver instance with TTL support.""" + saver = RedisSaver( + redis_client=redis_client, + ttl={ + "default_ttl": 0.1, + "refresh_on_read": True, + }, # 0.1 minutes = 6 seconds TTL + ) + saver.setup() + yield saver + + +def test_ttl_config_in_constructor(redis_client: Redis) -> None: + """Test that TTL config can be passed through constructor.""" + saver = RedisSaver( + redis_client=redis_client, + ttl={"default_ttl": 10, "refresh_on_read": True}, + ) + assert saver.ttl_config is not None + assert saver.ttl_config.get("default_ttl") == 10 + assert saver.ttl_config.get("refresh_on_read") is True + + +def test_checkpoint_expires(redis_client: Redis) -> None: + """Test that a checkpoint expires after the TTL period.""" + try: + # Create unique identifiers to avoid test collisions + unique_prefix = f"expires_test_{int(time.time())}" + + # Create a saver with TTL + ttl_checkpoint_saver = RedisSaver( + redis_client=redis_client, + ttl={ + "default_ttl": 0.1, # 0.1 minutes = 6 seconds TTL + "refresh_on_read": True, + }, + ) + ttl_checkpoint_saver.setup() + + # Create a checkpoint with unique thread ID + thread_id = f"{unique_prefix}_thread" + checkpoint_ns = f"{unique_prefix}_ns" + checkpoint_id = f"{unique_prefix}_checkpoint" + + config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + } + + checkpoint: Checkpoint = { + "id": checkpoint_id, + "channel_values": {"test_channel": "test_value"}, + "channel_versions": {"test_channel": "1.0"}, + "versions_seen": {}, + "pending_sends": [], + } + + metadata: CheckpointMetadata = { + "source": "test", + "step": 1, + } + + # Save the checkpoint (with default TTL of 0.1 minutes = 6 seconds) + ttl_checkpoint_saver.put(config, checkpoint, metadata, {"test_channel": "1.0"}) + + # Verify checkpoint exists immediately after creation + initial_result = ttl_checkpoint_saver.get_tuple(config) + assert initial_result is not None, "Checkpoint should exist after creation" + + # Wait for TTL to expire (plus a small buffer) + time.sleep(7) # 7 seconds > 6 seconds TTL + + # Verify checkpoint no longer exists + result = ttl_checkpoint_saver.get_tuple(config) + assert result is None, "Checkpoint with TTL should expire" + finally: + # Clean up + keys = redis_client.keys(f"checkpoint:*{thread_id}*") + if keys: + redis_client.delete(*keys) + # Do not close the client as it's provided by the fixture + + +def test_ttl_refresh_on_read(redis_client: Redis) -> None: + """Test that TTL is refreshed when reading a checkpoint if refresh_on_read is enabled.""" + try: + # Create unique identifiers to avoid test collisions + unique_prefix = f"refresh_test_{int(time.time())}" + + # Create a saver with TTL and refresh_on_read enabled + ttl_checkpoint_saver = RedisSaver( + redis_client=redis_client, + ttl={ + "default_ttl": 0.1, # 0.1 minutes = 6 seconds TTL + "refresh_on_read": True, + }, + ) + ttl_checkpoint_saver.setup() + + # Create a checkpoint with unique thread ID + thread_id = f"{unique_prefix}_thread" + checkpoint_ns = f"{unique_prefix}_ns" + checkpoint_id = f"{unique_prefix}_checkpoint" + + config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + } + + checkpoint: Checkpoint = { + "id": checkpoint_id, + "channel_values": {"test_channel": "test_value"}, + "channel_versions": {"test_channel": "1.0"}, + "versions_seen": {}, + "pending_sends": [], + } + + metadata: CheckpointMetadata = { + "source": "test", + "step": 1, + } + + # Save the checkpoint (with default TTL of 0.1 minutes = 6 seconds) + ttl_checkpoint_saver.put(config, checkpoint, metadata, {"test_channel": "1.0"}) + + # Verify checkpoint exists immediately after creation + initial_result = ttl_checkpoint_saver.get_tuple(config) + assert initial_result is not None, "Checkpoint should exist after creation" + + # Wait for 3 seconds (less than TTL) + time.sleep(3) + + # Read the checkpoint (should refresh TTL) + ttl_checkpoint_saver.get_tuple(config) + + # Wait another 2 seconds (would be 5 seconds total, less than original TTL) + time.sleep(2) + + # Wait extra time to account for any test delays + time.sleep(1) + + # Checkpoint should still exist because TTL was refreshed + result = ttl_checkpoint_saver.get_tuple(config) + assert result is not None, "Checkpoint should still exist after TTL refresh" + + # Wait for TTL to expire again + time.sleep(7) + + # Verify checkpoint no longer exists + result = ttl_checkpoint_saver.get_tuple(config) + assert result is None, "Checkpoint should expire after refreshed TTL" + finally: + # Clean up + keys = redis_client.keys(f"checkpoint:*{thread_id}*") + if keys: + redis_client.delete(*keys) + # Do not close the client as it's provided by the fixture + + +def test_put_writes_with_ttl(redis_client: Redis) -> None: + """Test that writes also expire with TTL.""" + try: + # Create unique identifiers to avoid test collisions + unique_prefix = f"writes_test_{int(time.time())}" + + # Create a saver with TTL + ttl_checkpoint_saver = RedisSaver( + redis_client=redis_client, + ttl={ + "default_ttl": 0.1, # 0.1 minutes = 6 seconds TTL + "refresh_on_read": True, + }, + ) + ttl_checkpoint_saver.setup() + + # Create a checkpoint with unique thread ID + thread_id = f"{unique_prefix}_thread" + checkpoint_ns = f"{unique_prefix}_ns" + checkpoint_id = f"{unique_prefix}_checkpoint" + + config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + } + + # Create some writes + ttl_checkpoint_saver.put_writes( + config, [("test_channel", "test_value")], "test_task_id" + ) + + # Verify writes exist immediately after creation + initial_writes = ttl_checkpoint_saver._load_pending_writes( + thread_id, checkpoint_ns, checkpoint_id + ) + assert len(initial_writes) > 0, "Writes should exist after creation" + + # Wait for TTL to expire + time.sleep(7) # 7 seconds > 6 seconds TTL + + # Verify writes no longer exist + writes = ttl_checkpoint_saver._load_pending_writes( + thread_id, checkpoint_ns, checkpoint_id + ) + assert len(writes) == 0, "Writes with TTL should expire" + finally: + # Clean up + keys = redis_client.keys(f"checkpoint:*{thread_id}*") + if keys: + redis_client.delete(*keys) + # Do not close the client as it's provided by the fixture + + +def test_no_ttl_when_not_configured(redis_client: Redis) -> None: + """Test that keys don't expire when TTL is not configured.""" + try: + # Create unique identifiers to avoid test collisions + unique_prefix = f"no_ttl_test_{int(time.time())}" + + # Create a saver without TTL + saver = RedisSaver(redis_client=redis_client) + saver.setup() + + # Create a checkpoint with unique thread ID + thread_id = f"{unique_prefix}_thread" + checkpoint_ns = f"{unique_prefix}_ns" + checkpoint_id = f"{unique_prefix}_checkpoint" + + config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + } + + checkpoint: Checkpoint = { + "id": checkpoint_id, + "channel_values": {"test_channel": "test_value"}, + "channel_versions": {"test_channel": "1.0"}, + "versions_seen": {}, + "pending_sends": [], + } + + metadata: CheckpointMetadata = { + "source": "test", + "step": 1, + } + + # Save the checkpoint (no TTL configured) + saver.put(config, checkpoint, metadata, {"test_channel": "1.0"}) + + # Verify checkpoint exists immediately after creation + initial_result = saver.get_tuple(config) + assert initial_result is not None, "Checkpoint should exist after creation" + + # Wait for the same amount of time that would cause TTL expiration + time.sleep(7) + + # Verify checkpoint still exists + result = saver.get_tuple(config) + assert result is not None, "Checkpoint without TTL should not expire" + finally: + # Clean up + keys = redis_client.keys(f"checkpoint:*{thread_id}*") + if keys: + redis_client.delete(*keys) + # Do not close the client as it's provided by the fixture + + +def test_simple_graph_with_ttl(redis_client: Redis) -> None: + """Test a simple graph with TTL configuration.""" + # Use an isolated Redis client to prevent interference from parallel tests + unique_prefix = f"graph_test_{int(time.time())}" + thread_id = f"{unique_prefix}_thread" + + def add_one(state): + """Add one to the state.""" + state["count"] = state.get("count", 0) + 1 + return state + + # Define a simple graph + builder = StateGraph(State) + builder.add_node("add_one", add_one) + builder.set_entry_point("add_one") + builder.set_finish_point("add_one") + + try: + # Create a checkpointer with TTL + with RedisSaver.from_conn_string( + redis_client=redis_client, + ttl={"default_ttl": 0.1, "refresh_on_read": True}, # 6 seconds TTL + ) as checkpointer: + checkpointer.setup() + + # Compile the graph with the checkpointer + graph = builder.compile(checkpointer=checkpointer) + + # Use the graph with a specific thread_id + config = {"configurable": {"thread_id": thread_id}} + + # Initial run + result = graph.invoke({"count": 0}, config=config) + assert result["count"] == 1, "Initial count should be 1" + + # Run again immediately - should continue from checkpoint + result = graph.invoke({}, config=config) + assert result["count"] == 2, "Count should increment to 2 from checkpoint" + + # Wait for TTL to expire + time.sleep(7) # Wait longer than the 6 second TTL + + # Run again - should start from beginning since checkpoint expired + result = graph.invoke({}, config=config) + assert ( + result["count"] == 1 + ), "Count should reset to 1 after checkpoint expired" + finally: + # Clean up + keys = redis_client.keys(f"checkpoint:*{thread_id}*") + if keys: + redis_client.delete(*keys) + # Do not close the client as it's provided by the fixture