Skip to content

Commit 4736ba0

Browse files
committed
feat: add TTL removal support for pinning checkpoints (#66)
Add support for removing TTL from Redis checkpoints to make them persistent. This enables "pinning" specific threads that should never expire while allowing others to be cleaned up automatically. Changes: - Add support for `ttl_minutes=-1` parameter to trigger Redis PERSIST command - Implement TTL removal in both sync and async checkpoint savers - Apply PERSIST to main key and all related keys (blobs, writes) - Add comprehensive test coverage for TTL removal functionality - Update README with documentation for the pinning feature
1 parent cf6a202 commit 4736ba0

File tree

4 files changed

+324
-3
lines changed

4 files changed

+324
-3
lines changed

README.md

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,29 @@ with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as ch
193193
# Use the checkpointer...
194194
```
195195

196-
This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up.
196+
### Removing TTL (Pinning Threads)
197+
198+
You can make specific checkpoints persistent by removing their TTL. This is useful for "pinning" important threads that should never expire:
199+
200+
```python
201+
from langgraph.checkpoint.redis import RedisSaver
202+
203+
# Create saver with default TTL
204+
saver = RedisSaver.from_conn_string("redis://localhost:6379", ttl={"default_ttl": 60})
205+
saver.setup()
206+
207+
# ... save some checkpoints ...
208+
209+
# Remove TTL from a specific checkpoint to make it persistent
210+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{checkpoint_id}"
211+
saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)
212+
213+
# The checkpoint is now persistent and won't expire
214+
```
215+
216+
When no TTL configuration is provided, checkpoints are persistent by default (no expiration).
217+
218+
This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up while keeping important data persistent.
197219

198220
## Redis Stores
199221

@@ -314,11 +336,13 @@ For Redis Stores with vector search:
314336

315337
Both Redis checkpoint savers and stores leverage Redis's native key expiration:
316338

317-
- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command
339+
- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command for setting TTL
340+
- **TTL Removal**: Uses Redis's `PERSIST` command to remove TTL (with `ttl_minutes=-1`)
318341
- **Automatic Cleanup**: Redis automatically removes expired keys
319342
- **Configurable Default TTL**: Set a default TTL for all keys in minutes
320343
- **TTL Refresh on Read**: Optionally refresh TTL when keys are accessed
321344
- **Applied to All Related Keys**: TTL is applied to all related keys (checkpoint, blobs, writes)
345+
- **Persistent by Default**: When no TTL is configured, keys are persistent (no expiration)
322346

323347
## Contributing
324348

langgraph/checkpoint/redis/aio.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ async def _apply_ttl_to_keys(
295295
main_key: The primary Redis key
296296
related_keys: Additional Redis keys that should expire at the same time
297297
ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
298+
Use -1 to remove TTL (make keys persistent)
298299
299300
Returns:
300301
Result of the Redis operation
@@ -305,6 +306,32 @@ async def _apply_ttl_to_keys(
305306
ttl_minutes = self.ttl_config.get("default_ttl")
306307

307308
if ttl_minutes is not None:
309+
# Special case: -1 means remove TTL (make persistent)
310+
if ttl_minutes == -1:
311+
if self.cluster_mode:
312+
# For cluster mode, execute PERSIST operations individually
313+
await self._redis.persist(main_key)
314+
315+
if related_keys:
316+
for key in related_keys:
317+
await self._redis.persist(key)
318+
319+
return True
320+
else:
321+
# For non-cluster mode, use pipeline for efficiency
322+
pipeline = self._redis.pipeline()
323+
324+
# Remove TTL for main key
325+
pipeline.persist(main_key)
326+
327+
# Remove TTL for related keys
328+
if related_keys:
329+
for key in related_keys:
330+
pipeline.persist(key)
331+
332+
return await pipeline.execute()
333+
334+
# Regular TTL setting
308335
ttl_seconds = int(ttl_minutes * 60)
309336

310337
if self.cluster_mode:
@@ -709,7 +736,9 @@ async def alist(
709736
if isinstance(checkpoint_data, dict)
710737
else orjson.loads(checkpoint_data)
711738
)
712-
channel_values = self._recursive_deserialize(checkpoint_dict.get("channel_values", {}))
739+
channel_values = self._recursive_deserialize(
740+
checkpoint_dict.get("channel_values", {})
741+
)
713742
else:
714743
# If checkpoint data is missing, the document is corrupted
715744
# Set empty channel values rather than attempting a fallback

langgraph/checkpoint/redis/base.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def _apply_ttl_to_keys(
238238
main_key: The primary Redis key
239239
related_keys: Additional Redis keys that should expire at the same time
240240
ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
241+
Use -1 to remove TTL (make keys persistent)
241242
242243
Returns:
243244
Result of the Redis operation
@@ -248,6 +249,35 @@ def _apply_ttl_to_keys(
248249
ttl_minutes = self.ttl_config.get("default_ttl")
249250

250251
if ttl_minutes is not None:
252+
# Special case: -1 means remove TTL (make persistent)
253+
if ttl_minutes == -1:
254+
# Check if cluster mode is detected (for sync checkpoint savers)
255+
cluster_mode = getattr(self, "cluster_mode", False)
256+
257+
if cluster_mode:
258+
# For cluster mode, execute PERSIST operations individually
259+
self._redis.persist(main_key)
260+
261+
if related_keys:
262+
for key in related_keys:
263+
self._redis.persist(key)
264+
265+
return True
266+
else:
267+
# For non-cluster mode, use pipeline for efficiency
268+
pipeline = self._redis.pipeline()
269+
270+
# Remove TTL for main key
271+
pipeline.persist(main_key)
272+
273+
# Remove TTL for related keys
274+
if related_keys:
275+
for key in related_keys:
276+
pipeline.persist(key)
277+
278+
return pipeline.execute()
279+
280+
# Regular TTL setting
251281
ttl_seconds = int(ttl_minutes * 60)
252282

253283
# Check if cluster mode is detected (for sync checkpoint savers)

tests/test_ttl_removal.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""Tests for TTL removal feature (issue #66)."""
2+
3+
import time
4+
from uuid import uuid4
5+
6+
import pytest
7+
from langgraph.checkpoint.base import create_checkpoint, empty_checkpoint
8+
9+
from langgraph.checkpoint.redis import AsyncRedisSaver, RedisSaver
10+
11+
12+
def test_ttl_removal_with_negative_one(redis_url: str) -> None:
13+
"""Test that ttl_minutes=-1 removes TTL from keys."""
14+
saver = RedisSaver(redis_url, ttl={"default_ttl": 1}) # 1 minute default TTL
15+
saver.setup()
16+
17+
thread_id = str(uuid4())
18+
checkpoint = create_checkpoint(
19+
checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1
20+
)
21+
checkpoint["channel_values"]["messages"] = ["test"]
22+
23+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
24+
25+
# Save checkpoint (will have TTL)
26+
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
27+
28+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
29+
30+
# Verify TTL is set
31+
ttl = saver._redis.ttl(checkpoint_key)
32+
assert 50 <= ttl <= 60, f"TTL should be around 60 seconds, got {ttl}"
33+
34+
# Remove TTL using -1
35+
saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)
36+
37+
# Verify TTL is removed
38+
ttl_after = saver._redis.ttl(checkpoint_key)
39+
assert ttl_after == -1, "Key should be persistent after setting ttl_minutes=-1"
40+
41+
42+
def test_ttl_removal_with_related_keys(redis_url: str) -> None:
43+
"""Test that TTL removal works for main key and related keys."""
44+
saver = RedisSaver(redis_url, ttl={"default_ttl": 1})
45+
saver.setup()
46+
47+
thread_id = str(uuid4())
48+
49+
# Create a checkpoint with writes (to have related keys)
50+
checkpoint = create_checkpoint(
51+
checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1
52+
)
53+
checkpoint["channel_values"]["messages"] = ["test"]
54+
55+
config = {
56+
"configurable": {
57+
"thread_id": thread_id,
58+
"checkpoint_ns": "",
59+
"checkpoint_id": "test-checkpoint",
60+
}
61+
}
62+
63+
# Save checkpoint and writes
64+
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
65+
saver.put_writes(
66+
saved_config, [("channel1", "value1"), ("channel2", "value2")], "task-1"
67+
)
68+
69+
# Get the keys
70+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
71+
write_key1 = f"checkpoint_write:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}:task-1:0"
72+
write_key2 = f"checkpoint_write:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}:task-1:1"
73+
74+
# All keys should have TTL
75+
assert 50 <= saver._redis.ttl(checkpoint_key) <= 60
76+
assert 50 <= saver._redis.ttl(write_key1) <= 60
77+
assert 50 <= saver._redis.ttl(write_key2) <= 60
78+
79+
# Remove TTL from all keys
80+
saver._apply_ttl_to_keys(checkpoint_key, [write_key1, write_key2], ttl_minutes=-1)
81+
82+
# All keys should be persistent
83+
assert saver._redis.ttl(checkpoint_key) == -1
84+
assert saver._redis.ttl(write_key1) == -1
85+
assert saver._redis.ttl(write_key2) == -1
86+
87+
88+
def test_no_ttl_means_persistent(redis_url: str) -> None:
89+
"""Test that no TTL configuration means keys are persistent."""
90+
# Create saver with no TTL config
91+
saver = RedisSaver(redis_url) # No TTL config
92+
saver.setup()
93+
94+
thread_id = str(uuid4())
95+
checkpoint = create_checkpoint(
96+
checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1
97+
)
98+
checkpoint["channel_values"]["messages"] = ["test"]
99+
100+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
101+
102+
# Save checkpoint
103+
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
104+
105+
# Check TTL
106+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
107+
ttl = saver._redis.ttl(checkpoint_key)
108+
109+
# Should be -1 (persistent) when no TTL configured
110+
assert ttl == -1, "Key should be persistent when no TTL configured"
111+
112+
113+
def test_ttl_removal_preserves_data(redis_url: str) -> None:
114+
"""Test that removing TTL doesn't affect the data."""
115+
saver = RedisSaver(redis_url, ttl={"default_ttl": 1})
116+
saver.setup()
117+
118+
thread_id = str(uuid4())
119+
checkpoint = create_checkpoint(
120+
checkpoint=empty_checkpoint(), channels={"messages": ["original data"]}, step=1
121+
)
122+
checkpoint["channel_values"]["messages"] = ["original data"]
123+
124+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
125+
126+
# Save checkpoint
127+
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
128+
129+
# Load data before TTL removal
130+
loaded_before = saver.get_tuple(saved_config)
131+
assert loaded_before.checkpoint["channel_values"]["messages"] == ["original data"]
132+
133+
# Remove TTL
134+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
135+
saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)
136+
137+
# Load data after TTL removal
138+
loaded_after = saver.get_tuple(saved_config)
139+
assert loaded_after.checkpoint["channel_values"]["messages"] == ["original data"]
140+
141+
# Verify TTL is removed
142+
assert saver._redis.ttl(checkpoint_key) == -1
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_async_ttl_removal(redis_url: str) -> None:
147+
"""Test TTL removal with async saver."""
148+
async with AsyncRedisSaver.from_conn_string(
149+
redis_url, ttl={"default_ttl": 1}
150+
) as saver:
151+
thread_id = str(uuid4())
152+
checkpoint = create_checkpoint(
153+
checkpoint=empty_checkpoint(), channels={"messages": ["async test"]}, step=1
154+
)
155+
checkpoint["channel_values"]["messages"] = ["async test"]
156+
157+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
158+
159+
# Save checkpoint
160+
saved_config = await saver.aput(
161+
config, checkpoint, {"source": "test", "step": 1}, {}
162+
)
163+
164+
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
165+
166+
# Verify TTL is set
167+
ttl = await saver._redis.ttl(checkpoint_key)
168+
assert 50 <= ttl <= 60, f"TTL should be around 60 seconds, got {ttl}"
169+
170+
# Remove TTL using -1
171+
await saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)
172+
173+
# Verify TTL is removed
174+
ttl_after = await saver._redis.ttl(checkpoint_key)
175+
assert ttl_after == -1, "Key should be persistent after setting ttl_minutes=-1"
176+
177+
178+
def test_pin_thread_use_case(redis_url: str) -> None:
179+
"""Test the 'pin thread' use case from issue #66.
180+
181+
This simulates pinning a specific thread by removing its TTL,
182+
making it persistent while other threads expire.
183+
"""
184+
saver = RedisSaver(
185+
redis_url, ttl={"default_ttl": 0.1}
186+
) # 6 seconds TTL for quick test
187+
saver.setup()
188+
189+
# Create two threads
190+
thread_to_pin = str(uuid4())
191+
thread_to_expire = str(uuid4())
192+
193+
for thread_id in [thread_to_pin, thread_to_expire]:
194+
checkpoint = create_checkpoint(
195+
checkpoint=empty_checkpoint(),
196+
channels={"messages": [f"Thread {thread_id}"]},
197+
step=1,
198+
)
199+
checkpoint["channel_values"]["messages"] = [f"Thread {thread_id}"]
200+
201+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
202+
203+
saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
204+
205+
# Pin the first thread by removing its TTL
206+
pinned_checkpoint_key = f"checkpoint:{thread_to_pin}:__empty__:*"
207+
# Get actual key
208+
keys = list(saver._redis.keys(pinned_checkpoint_key))
209+
if keys:
210+
saver._apply_ttl_to_keys(keys[0], ttl_minutes=-1)
211+
212+
# Verify pinned thread has no TTL
213+
pinned_keys = list(saver._redis.keys(f"checkpoint:{thread_to_pin}:__empty__:*"))
214+
assert len(pinned_keys) > 0
215+
assert saver._redis.ttl(pinned_keys[0]) == -1
216+
217+
# Verify other thread still has TTL
218+
expiring_keys = list(
219+
saver._redis.keys(f"checkpoint:{thread_to_expire}:__empty__:*")
220+
)
221+
assert len(expiring_keys) > 0
222+
ttl = saver._redis.ttl(expiring_keys[0])
223+
assert 0 < ttl <= 6
224+
225+
# Wait for expiring thread to expire
226+
time.sleep(7)
227+
228+
# Pinned thread should still exist
229+
pinned_keys_after = list(
230+
saver._redis.keys(f"checkpoint:{thread_to_pin}:__empty__:*")
231+
)
232+
assert len(pinned_keys_after) > 0
233+
234+
# Expiring thread should be gone
235+
expiring_keys_after = list(
236+
saver._redis.keys(f"checkpoint:{thread_to_expire}:__empty__:*")
237+
)
238+
assert len(expiring_keys_after) == 0

0 commit comments

Comments
 (0)