Skip to content

Commit b186a39

Browse files
committed
feat: add TTL removal support for pinning checkpoints (#66)
Add support for removing TTL from Redis checkpoints to make them persistent. This enables "pinning" specific threads that should never expire while allowing others to be cleaned up automatically. Changes: - Add support for `ttl_minutes=-1` parameter to trigger Redis PERSIST command - Implement TTL removal in both sync and async checkpoint savers - Apply PERSIST to main key and all related keys (blobs, writes) - Add comprehensive test coverage for TTL removal functionality - Update README with documentation for the pinning feature
1 parent 48396d0 commit b186a39

File tree

4 files changed

+320
-2
lines changed

4 files changed

+320
-2
lines changed

README.md

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,32 @@ with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as ch
249249
# Use the checkpointer...
250250
```
251251

252-
This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up.
252+
### Removing TTL (Pinning Threads)
253+
254+
You can make specific checkpoints persistent by removing their TTL. This is useful for "pinning" important threads that should never expire:
255+
256+
```python
257+
from langgraph.checkpoint.redis import RedisSaver
258+
259+
# Create saver with default TTL
260+
saver = RedisSaver.from_conn_string("redis://localhost:6379", ttl={"default_ttl": 60})
261+
saver.setup()
262+
263+
# Save a checkpoint
264+
config = {"configurable": {"thread_id": "important-thread", "checkpoint_ns": ""}}
265+
saved_config = saver.put(config, checkpoint, metadata, {})
266+
267+
# Remove TTL from the checkpoint to make it persistent
268+
checkpoint_id = saved_config["configurable"]["checkpoint_id"]
269+
checkpoint_key = f"checkpoint:important-thread:__empty__:{checkpoint_id}"
270+
saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)
271+
272+
# The checkpoint is now persistent and won't expire
273+
```
274+
275+
When no TTL configuration is provided, checkpoints are persistent by default (no expiration).
276+
277+
This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up while keeping important data persistent.
253278

254279
## Redis Stores
255280

@@ -370,11 +395,13 @@ For Redis Stores with vector search:
370395

371396
Both Redis checkpoint savers and stores leverage Redis's native key expiration:
372397

373-
- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command
398+
- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command for setting TTL
399+
- **TTL Removal**: Uses Redis's `PERSIST` command to remove TTL (with `ttl_minutes=-1`)
374400
- **Automatic Cleanup**: Redis automatically removes expired keys
375401
- **Configurable Default TTL**: Set a default TTL for all keys in minutes
376402
- **TTL Refresh on Read**: Optionally refresh TTL when keys are accessed
377403
- **Applied to All Related Keys**: TTL is applied to all related keys (checkpoint, blobs, writes)
404+
- **Persistent by Default**: When no TTL is configured, keys are persistent (no expiration)
378405

379406
## Contributing
380407

langgraph/checkpoint/redis/aio.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ async def _apply_ttl_to_keys(
295295
main_key: The primary Redis key
296296
related_keys: Additional Redis keys that should expire at the same time
297297
ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
298+
Use -1 to remove TTL (make keys persistent)
298299
299300
Returns:
300301
Result of the Redis operation
@@ -305,6 +306,32 @@ async def _apply_ttl_to_keys(
305306
ttl_minutes = self.ttl_config.get("default_ttl")
306307

307308
if ttl_minutes is not None:
309+
# Special case: -1 means remove TTL (make persistent)
310+
if ttl_minutes == -1:
311+
if self.cluster_mode:
312+
# For cluster mode, execute PERSIST operations individually
313+
await self._redis.persist(main_key)
314+
315+
if related_keys:
316+
for key in related_keys:
317+
await self._redis.persist(key)
318+
319+
return True
320+
else:
321+
# For non-cluster mode, use pipeline for efficiency
322+
pipeline = self._redis.pipeline()
323+
324+
# Remove TTL for main key
325+
pipeline.persist(main_key)
326+
327+
# Remove TTL for related keys
328+
if related_keys:
329+
for key in related_keys:
330+
pipeline.persist(key)
331+
332+
return await pipeline.execute()
333+
334+
# Regular TTL setting
308335
ttl_seconds = int(ttl_minutes * 60)
309336

310337
if self.cluster_mode:

langgraph/checkpoint/redis/base.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def _apply_ttl_to_keys(
238238
main_key: The primary Redis key
239239
related_keys: Additional Redis keys that should expire at the same time
240240
ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
241+
Use -1 to remove TTL (make keys persistent)
241242
242243
Returns:
243244
Result of the Redis operation
@@ -248,6 +249,35 @@ def _apply_ttl_to_keys(
248249
ttl_minutes = self.ttl_config.get("default_ttl")
249250

250251
if ttl_minutes is not None:
252+
# Special case: -1 means remove TTL (make persistent)
253+
if ttl_minutes == -1:
254+
# Check if cluster mode is detected (for sync checkpoint savers)
255+
cluster_mode = getattr(self, "cluster_mode", False)
256+
257+
if cluster_mode:
258+
# For cluster mode, execute PERSIST operations individually
259+
self._redis.persist(main_key)
260+
261+
if related_keys:
262+
for key in related_keys:
263+
self._redis.persist(key)
264+
265+
return True
266+
else:
267+
# For non-cluster mode, use pipeline for efficiency
268+
pipeline = self._redis.pipeline()
269+
270+
# Remove TTL for main key
271+
pipeline.persist(main_key)
272+
273+
# Remove TTL for related keys
274+
if related_keys:
275+
for key in related_keys:
276+
pipeline.persist(key)
277+
278+
return pipeline.execute()
279+
280+
# Regular TTL setting
251281
ttl_seconds = int(ttl_minutes * 60)
252282

253283
# Check if cluster mode is detected (for sync checkpoint savers)

tests/test_ttl_removal.py

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
"""Tests for TTL removal feature (issue #66)."""
2+
3+
import time
4+
from uuid import uuid4
5+
6+
import pytest
7+
from langgraph.checkpoint.base import create_checkpoint, empty_checkpoint
8+
9+
from langgraph.checkpoint.redis import AsyncRedisSaver, RedisSaver
10+
11+
12+
def test_ttl_removal_with_negative_one(redis_url: str) -> None:
13+
"""Test that ttl_minutes=-1 removes TTL from keys."""
14+
saver = RedisSaver(redis_url, ttl={"default_ttl": 1}) # 1 minute default TTL
15+
saver.setup()
16+
17+
thread_id = str(uuid4())
18+
checkpoint = create_checkpoint(
19+
checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1
20+
)
21+
checkpoint["channel_values"]["messages"] = ["test"]
22+
23+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
24+
25+
# Save checkpoint (will have TTL)
26+
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
27+
28+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
29+
30+
# Verify TTL is set
31+
ttl = saver._redis.ttl(checkpoint_key)
32+
assert 50 <= ttl <= 60, f"TTL should be around 60 seconds, got {ttl}"
33+
34+
# Remove TTL using -1
35+
saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)
36+
37+
# Verify TTL is removed
38+
ttl_after = saver._redis.ttl(checkpoint_key)
39+
assert ttl_after == -1, "Key should be persistent after setting ttl_minutes=-1"
40+
41+
42+
def test_ttl_removal_with_related_keys(redis_url: str) -> None:
43+
"""Test that TTL removal works for main key and related keys."""
44+
saver = RedisSaver(redis_url, ttl={"default_ttl": 1})
45+
saver.setup()
46+
47+
thread_id = str(uuid4())
48+
49+
# Create a checkpoint with writes (to have related keys)
50+
checkpoint = create_checkpoint(
51+
checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1
52+
)
53+
checkpoint["channel_values"]["messages"] = ["test"]
54+
55+
config = {
56+
"configurable": {
57+
"thread_id": thread_id,
58+
"checkpoint_ns": "",
59+
"checkpoint_id": "test-checkpoint",
60+
}
61+
}
62+
63+
# Save checkpoint and writes
64+
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
65+
saver.put_writes(
66+
saved_config, [("channel1", "value1"), ("channel2", "value2")], "task-1"
67+
)
68+
69+
# Get the keys
70+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
71+
write_key1 = f"checkpoint_write:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}:task-1:0"
72+
write_key2 = f"checkpoint_write:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}:task-1:1"
73+
74+
# All keys should have TTL
75+
assert 50 <= saver._redis.ttl(checkpoint_key) <= 60
76+
assert 50 <= saver._redis.ttl(write_key1) <= 60
77+
assert 50 <= saver._redis.ttl(write_key2) <= 60
78+
79+
# Remove TTL from all keys
80+
saver._apply_ttl_to_keys(checkpoint_key, [write_key1, write_key2], ttl_minutes=-1)
81+
82+
# All keys should be persistent
83+
assert saver._redis.ttl(checkpoint_key) == -1
84+
assert saver._redis.ttl(write_key1) == -1
85+
assert saver._redis.ttl(write_key2) == -1
86+
87+
88+
def test_no_ttl_means_persistent(redis_url: str) -> None:
89+
"""Test that no TTL configuration means keys are persistent."""
90+
# Create saver with no TTL config
91+
saver = RedisSaver(redis_url) # No TTL config
92+
saver.setup()
93+
94+
thread_id = str(uuid4())
95+
checkpoint = create_checkpoint(
96+
checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1
97+
)
98+
checkpoint["channel_values"]["messages"] = ["test"]
99+
100+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
101+
102+
# Save checkpoint
103+
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
104+
105+
# Check TTL
106+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
107+
ttl = saver._redis.ttl(checkpoint_key)
108+
109+
# Should be -1 (persistent) when no TTL configured
110+
assert ttl == -1, "Key should be persistent when no TTL configured"
111+
112+
113+
def test_ttl_removal_preserves_data(redis_url: str) -> None:
114+
"""Test that removing TTL doesn't affect the data."""
115+
saver = RedisSaver(redis_url, ttl={"default_ttl": 1})
116+
saver.setup()
117+
118+
thread_id = str(uuid4())
119+
checkpoint = create_checkpoint(
120+
checkpoint=empty_checkpoint(), channels={"messages": ["original data"]}, step=1
121+
)
122+
checkpoint["channel_values"]["messages"] = ["original data"]
123+
124+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
125+
126+
# Save checkpoint
127+
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
128+
129+
# Load data before TTL removal
130+
loaded_before = saver.get_tuple(saved_config)
131+
assert loaded_before.checkpoint["channel_values"]["messages"] == ["original data"]
132+
133+
# Remove TTL
134+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
135+
saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)
136+
137+
# Load data after TTL removal
138+
loaded_after = saver.get_tuple(saved_config)
139+
assert loaded_after.checkpoint["channel_values"]["messages"] == ["original data"]
140+
141+
# Verify TTL is removed
142+
assert saver._redis.ttl(checkpoint_key) == -1
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_async_ttl_removal(redis_url: str) -> None:
147+
"""Test TTL removal with async saver."""
148+
async with AsyncRedisSaver.from_conn_string(
149+
redis_url, ttl={"default_ttl": 1}
150+
) as saver:
151+
thread_id = str(uuid4())
152+
checkpoint = create_checkpoint(
153+
checkpoint=empty_checkpoint(), channels={"messages": ["async test"]}, step=1
154+
)
155+
checkpoint["channel_values"]["messages"] = ["async test"]
156+
157+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
158+
159+
# Save checkpoint
160+
saved_config = await saver.aput(
161+
config, checkpoint, {"source": "test", "step": 1}, {}
162+
)
163+
164+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
165+
166+
# Verify TTL is set
167+
ttl = await saver._redis.ttl(checkpoint_key)
168+
assert 50 <= ttl <= 60, f"TTL should be around 60 seconds, got {ttl}"
169+
170+
# Remove TTL using -1
171+
await saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)
172+
173+
# Verify TTL is removed
174+
ttl_after = await saver._redis.ttl(checkpoint_key)
175+
assert ttl_after == -1, "Key should be persistent after setting ttl_minutes=-1"
176+
177+
178+
def test_pin_thread_use_case(redis_url: str) -> None:
179+
"""Test the 'pin thread' use case from issue #66.
180+
181+
This simulates pinning a specific thread by removing its TTL,
182+
making it persistent while other threads expire.
183+
"""
184+
saver = RedisSaver(
185+
redis_url, ttl={"default_ttl": 0.1}
186+
) # 6 seconds TTL for quick test
187+
saver.setup()
188+
189+
# Create two threads
190+
thread_to_pin = str(uuid4())
191+
thread_to_expire = str(uuid4())
192+
193+
# Store checkpoint IDs to avoid using wildcards (more efficient and precise)
194+
checkpoint_ids = {}
195+
196+
for thread_id in [thread_to_pin, thread_to_expire]:
197+
checkpoint = create_checkpoint(
198+
checkpoint=empty_checkpoint(),
199+
channels={"messages": [f"Thread {thread_id}"]},
200+
step=1,
201+
)
202+
checkpoint["channel_values"]["messages"] = [f"Thread {thread_id}"]
203+
204+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
205+
206+
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
207+
checkpoint_ids[thread_id] = saved_config["configurable"]["checkpoint_id"]
208+
209+
# Pin the first thread by removing its TTL using exact key
210+
pinned_checkpoint_key = (
211+
f"checkpoint:{thread_to_pin}:__empty__:{checkpoint_ids[thread_to_pin]}"
212+
)
213+
saver._apply_ttl_to_keys(pinned_checkpoint_key, ttl_minutes=-1)
214+
215+
# Verify pinned thread has no TTL
216+
assert saver._redis.exists(pinned_checkpoint_key) == 1
217+
assert saver._redis.ttl(pinned_checkpoint_key) == -1
218+
219+
# Verify other thread still has TTL
220+
expiring_checkpoint_key = (
221+
f"checkpoint:{thread_to_expire}:__empty__:{checkpoint_ids[thread_to_expire]}"
222+
)
223+
assert saver._redis.exists(expiring_checkpoint_key) == 1
224+
ttl = saver._redis.ttl(expiring_checkpoint_key)
225+
assert 0 < ttl <= 6
226+
227+
# Wait for expiring thread to expire
228+
time.sleep(7)
229+
230+
# Pinned thread should still exist
231+
assert saver._redis.exists(pinned_checkpoint_key) == 1
232+
233+
# Expiring thread should be gone
234+
assert saver._redis.exists(expiring_checkpoint_key) == 0

0 commit comments

Comments
 (0)