1818 PutOp ,
1919 Result ,
2020 SearchOp ,
21+ TTLConfig ,
2122)
2223from redis import Redis
2324from redis .commands .search .query import Query
@@ -70,14 +71,19 @@ class RedisStore(BaseStore, BaseRedisStore[Redis, SearchIndex]):
7071 vector similarity search support.
7172 """
7273
74+ # Enable TTL support
75+ supports_ttl = True
76+ ttl_config : Optional [TTLConfig ] = None
77+
7378 def __init__ (
7479 self ,
7580 conn : Redis ,
7681 * ,
7782 index : Optional [IndexConfig ] = None ,
83+ ttl : Optional [dict [str , Any ]] = None ,
7884 ) -> None :
7985 BaseStore .__init__ (self )
80- BaseRedisStore .__init__ (self , conn , index = index )
86+ BaseRedisStore .__init__ (self , conn , index = index , ttl = ttl )
8187
8288 @classmethod
8389 @contextmanager
@@ -86,12 +92,13 @@ def from_conn_string(
8692 conn_string : str ,
8793 * ,
8894 index : Optional [IndexConfig ] = None ,
95+ ttl : Optional [dict [str , Any ]] = None ,
8996 ) -> Iterator [RedisStore ]:
9097 """Create store from Redis connection string."""
9198 client = None
9299 try :
93100 client = RedisConnectionFactory .get_redis_connection (conn_string )
94- yield cls (client , index = index )
101+ yield cls (client , index = index , ttl = ttl )
95102 finally :
96103 if client :
97104 client .close ()
@@ -186,15 +193,64 @@ def _batch_get_ops(
186193 results : list [Result ],
187194 ) -> None :
188195 """Execute GET operations in batch."""
196+ refresh_keys_by_idx : dict [int , list [str ]] = (
197+ {}
198+ ) # Track keys that need TTL refreshed by op index
199+
189200 for query , _ , namespace , items in self ._get_batch_GET_ops_queries (get_ops ):
190201 res = self .store_index .search (Query (query ))
191202 # Parse JSON from each document
192203 key_to_row = {
193- json .loads (doc .json )["key" ]: json .loads (doc .json ) for doc in res .docs
204+ json .loads (doc .json )["key" ]: (json .loads (doc .json ), doc .id )
205+ for doc in res .docs
194206 }
207+
195208 for idx , key in items :
196209 if key in key_to_row :
197- results [idx ] = _row_to_item (namespace , key_to_row [key ])
210+ data , doc_id = key_to_row [key ]
211+ results [idx ] = _row_to_item (namespace , data )
212+
213+ # Find the corresponding operation by looking it up in the operation list
214+ # This is needed because idx is the index in the overall operation list
215+ op_idx = None
216+ for i , (local_idx , op ) in enumerate (get_ops ):
217+ if local_idx == idx :
218+ op_idx = i
219+ break
220+
221+ if op_idx is not None :
222+ op = get_ops [op_idx ][1 ]
223+ if hasattr (op , "refresh_ttl" ) and op .refresh_ttl :
224+ if idx not in refresh_keys_by_idx :
225+ refresh_keys_by_idx [idx ] = []
226+ refresh_keys_by_idx [idx ].append (doc_id )
227+
228+ # Also add vector keys for the same document
229+ doc_uuid = doc_id .split (":" )[- 1 ]
230+ vector_key = (
231+ f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_uuid } "
232+ )
233+ refresh_keys_by_idx [idx ].append (vector_key )
234+
235+ # Now refresh TTLs for any keys that need it
236+ if refresh_keys_by_idx and self .ttl_config :
237+ # Get default TTL from config
238+ ttl_minutes = None
239+ if "default_ttl" in self .ttl_config :
240+ ttl_minutes = self .ttl_config .get ("default_ttl" )
241+
242+ if ttl_minutes is not None :
243+ ttl_seconds = int (ttl_minutes * 60 )
244+ pipeline = self ._redis .pipeline ()
245+
246+ for keys in refresh_keys_by_idx .values ():
247+ for key in keys :
248+ # Only refresh TTL if the key exists and has a TTL
249+ ttl = self ._redis .ttl (key )
250+ if ttl > 0 : # Only refresh if key exists and has TTL
251+ pipeline .expire (key , ttl_seconds )
252+
253+ pipeline .execute ()
198254
199255 def _batch_put_ops (
200256 self ,
@@ -219,20 +275,35 @@ def _batch_put_ops(
219275 doc_ids : dict [tuple [str , str ], str ] = {}
220276 store_docs : list [RedisDocument ] = []
221277 store_keys : list [str ] = []
278+ ttl_tracking : dict [str , tuple [list [str ], Optional [float ]]] = (
279+ {}
280+ ) # Tracks keys that need TTL + their TTL values
222281
223282 # Generate IDs for PUT operations
224283 for _ , op in put_ops :
225284 if op .value is not None :
226285 generated_doc_id = str (ULID ())
227286 namespace = _namespace_to_text (op .namespace )
228287 doc_ids [(namespace , op .key )] = generated_doc_id
288+ # Track TTL for this document if specified
289+ if hasattr (op , "ttl" ) and op .ttl is not None :
290+ main_key = f"{ STORE_PREFIX } { REDIS_KEY_SEPARATOR } { generated_doc_id } "
291+ ttl_tracking [main_key ] = ([], op .ttl )
229292
230293 # Load store docs with explicit keys
231294 for doc in operations :
232295 store_key = (doc ["prefix" ], doc ["key" ])
233296 doc_id = doc_ids [store_key ]
297+ # Remove TTL fields - they're not needed with Redis native TTL
298+ if "ttl_minutes" in doc :
299+ doc .pop ("ttl_minutes" , None )
300+ if "expires_at" in doc :
301+ doc .pop ("expires_at" , None )
302+
234303 store_docs .append (doc )
235- store_keys .append (f"{ STORE_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } " )
304+ redis_key = f"{ STORE_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
305+ store_keys .append (redis_key )
306+
236307 if store_docs :
237308 self .store_index .load (store_docs , keys = store_keys )
238309
@@ -260,12 +331,21 @@ def _batch_put_ops(
260331 "updated_at" : datetime .now (timezone .utc ).timestamp (),
261332 }
262333 )
263- vector_keys .append (
264- f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
265- )
334+ vector_key = f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
335+ vector_keys .append (vector_key )
336+
337+ # Add this vector key to the related keys list for TTL
338+ main_key = f"{ STORE_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
339+ if main_key in ttl_tracking :
340+ ttl_tracking [main_key ][0 ].append (vector_key )
341+
266342 if vector_docs :
267343 self .vector_index .load (vector_docs , keys = vector_keys )
268344
345+ # Now apply TTLs after all documents are loaded
346+ for main_key , (related_keys , ttl_minutes ) in ttl_tracking .items ():
347+ self ._apply_ttl_to_keys (main_key , related_keys , ttl_minutes )
348+
269349 def _batch_search_ops (
270350 self ,
271351 search_ops : list [tuple [int , SearchOp ]],
@@ -316,6 +396,8 @@ def _batch_search_ops(
316396
317397 # Process results maintaining order and applying filters
318398 items = []
399+ refresh_keys = [] # Track keys that need TTL refreshed
400+
319401 for store_key , store_doc in zip (result_map .keys (), store_docs ):
320402 if store_doc :
321403 vector_result = result_map [store_key ]
@@ -345,6 +427,16 @@ def _batch_search_ops(
345427 if not matches :
346428 continue
347429
430+ # If refresh_ttl is true, add to list for refreshing
431+ if op .refresh_ttl :
432+ refresh_keys .append (store_key )
433+ # Also find associated vector keys with same ID
434+ doc_id = store_key .split (":" )[- 1 ]
435+ vector_key = (
436+ f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
437+ )
438+ refresh_keys .append (vector_key )
439+
348440 items .append (
349441 _row_to_search_item (
350442 _decode_ns (store_doc ["prefix" ]),
@@ -353,13 +445,31 @@ def _batch_search_ops(
353445 )
354446 )
355447
448+ # Refresh TTL if requested
449+ if op .refresh_ttl and refresh_keys and self .ttl_config :
450+ # Get default TTL from config
451+ ttl_minutes = None
452+ if "default_ttl" in self .ttl_config :
453+ ttl_minutes = self .ttl_config .get ("default_ttl" )
454+
455+ if ttl_minutes is not None :
456+ ttl_seconds = int (ttl_minutes * 60 )
457+ pipeline = self ._redis .pipeline ()
458+ for key in refresh_keys :
459+ # Only refresh TTL if the key exists and has a TTL
460+ ttl = self ._redis .ttl (key )
461+ if ttl > 0 : # Only refresh if key exists and has TTL
462+ pipeline .expire (key , ttl_seconds )
463+ pipeline .execute ()
464+
356465 results [idx ] = items
357466 else :
358467 # Regular search
359468 query = Query (query_str )
360469 # Get all potential matches for filtering
361470 res = self .store_index .search (query )
362471 items = []
472+ refresh_keys = [] # Track keys that need TTL refreshed
363473
364474 for doc in res .docs :
365475 data = json .loads (doc .json )
@@ -378,13 +488,41 @@ def _batch_search_ops(
378488 break
379489 if not matches :
380490 continue
491+
492+ # If refresh_ttl is true, add the key to refresh list
493+ if op .refresh_ttl :
494+ refresh_keys .append (doc .id )
495+ # Also find associated vector keys with same ID
496+ doc_id = doc .id .split (":" )[- 1 ]
497+ vector_key = (
498+ f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
499+ )
500+ refresh_keys .append (vector_key )
501+
381502 items .append (_row_to_search_item (_decode_ns (data ["prefix" ]), data ))
382503
383504 # Apply pagination after filtering
384505 if params :
385506 limit , offset = params
386507 items = items [offset : offset + limit ]
387508
509+ # Refresh TTL if requested
510+ if op .refresh_ttl and refresh_keys and self .ttl_config :
511+ # Get default TTL from config
512+ ttl_minutes = None
513+ if "default_ttl" in self .ttl_config :
514+ ttl_minutes = self .ttl_config .get ("default_ttl" )
515+
516+ if ttl_minutes is not None :
517+ ttl_seconds = int (ttl_minutes * 60 )
518+ pipeline = self ._redis .pipeline ()
519+ for key in refresh_keys :
520+ # Only refresh TTL if the key exists and has a TTL
521+ ttl = self ._redis .ttl (key )
522+ if ttl > 0 : # Only refresh if key exists and has TTL
523+ pipeline .expire (key , ttl_seconds )
524+ pipeline .execute ()
525+
388526 results [idx ] = items
389527
390528 async def abatch (self , ops : Iterable [Op ]) -> list [Result ]:
0 commit comments