Commit 5a6ded6

Sync up implementations of ttl stuff; test fixes
1 parent 3205eb4 commit 5a6ded6

File tree

4 files changed: +158 −94 lines changed

langgraph/store/redis/__init__.py

Lines changed: 41 additions & 26 deletions

```diff
@@ -435,33 +435,26 @@ def _batch_search_ops(
             )
             vector_results = self.vector_index.query(vector_query)

-            # Get matching store docs: direct JSON GET for cluster, batch for non-cluster
+            # Get matching store docs
             result_map = {}  # Map store key to vector result with distances
-            store_docs = []

             if self.cluster_mode:
+                store_docs = []
                 # Direct JSON GET for cluster mode
-                json_client = self._redis.json()
-                # Monkey-patch json method to always return this instance for consistent call tracking
-                try:
-                    self._redis.json = lambda: json_client
-                except Exception:
-                    pass
                 for doc in vector_results:
                     doc_id = (
                         doc.get("id")
                         if isinstance(doc, dict)
                         else getattr(doc, "id", None)
                     )
-                    if not doc_id:
-                        continue
-                    doc_uuid = doc_id.split(":")[1]
-                    store_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
-                    result_map[store_key] = doc
-                    # Record JSON GET call for testing
-                    if hasattr(self._redis, "json_get_calls"):
-                        self._redis.json_get_calls.append({"key": store_key})
-                    store_docs.append(json_client.get(store_key))
+                    if doc_id:
+                        doc_uuid = doc_id.split(":")[1]
+                        store_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
+                        result_map[store_key] = doc
+                        # Fetch individually in cluster mode
+                        store_doc_item = self._redis.json().get(store_key)
+                        store_docs.append(store_doc_item)
+                store_docs_raw = store_docs
             else:
                 pipe = self._redis.pipeline(transaction=True)
                 for doc in vector_results:
```
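
The cluster branch no longer monkey-patches `self._redis.json()` for call tracking; it issues one `JSON.GET` per key, while the standalone branch keeps batching through a pipeline. A minimal sketch of the two paths, assuming a redis-py client with RedisJSON loaded (`fetch_store_docs` is an illustrative name, not the store's API):

```python
from typing import Any, List

from redis import Redis


def fetch_store_docs(client: Redis, keys: List[str], cluster_mode: bool) -> List[Any]:
    """Fetch JSON documents for `keys`, preserving input order."""
    if cluster_mode:
        # Cluster: keys may live on different nodes, so issue one JSON.GET
        # per key instead of a cross-slot pipeline.
        return [client.json().get(key) for key in keys]
    # Standalone: queue all JSON.GET calls and execute them in one round trip.
    pipe = client.pipeline(transaction=True)
    for key in keys:
        pipe.json().get(key)
    return pipe.execute()
```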

```diff
@@ -477,13 +470,15 @@ def _batch_search_ops(
                     result_map[store_key] = doc
                     pipe.json().get(store_key)
                 # Execute all lookups in one batch
-                store_docs = pipe.execute()
+                store_docs_raw = pipe.execute()

         # Process results maintaining order and applying filters
         items = []
         refresh_keys = []  # Track keys that need TTL refreshed
+        store_docs_iter = iter(store_docs_raw)

-        for store_key, store_doc in zip(result_map.keys(), store_docs):
+        for store_key in result_map.keys():
+            store_doc = next(store_docs_iter, None)
             if store_doc:
                 vector_result = result_map[store_key]
                 # Get vector_distance from original search result
```
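
Switching from `zip()` to an explicit iterator matters when the results list comes back shorter than the key list: `zip()` silently drops the trailing keys, while `next(it, None)` visits every key and fills the gaps with `None`. A self-contained illustration:

```python
# Hypothetical data: three store keys, but only two documents came back.
keys = ["store:a", "store:b", "store:c"]
docs = [{"val": 1}, {"val": 2}]

paired_zip = list(zip(keys, docs))
assert paired_zip == [("store:a", {"val": 1}), ("store:b", {"val": 2})]  # "store:c" dropped

it = iter(docs)
paired_iter = [(k, next(it, None)) for k in keys]
assert paired_iter[-1] == ("store:c", None)  # every key still visited
```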

```diff
@@ -494,7 +489,25 @@ def _batch_search_ops(
                 )
                 # Convert to similarity score
                 score = (1.0 - float(dist)) if dist is not None else 0.0
-                store_doc["vector_distance"] = dist
+                if not isinstance(store_doc, dict):
+                    try:
+                        store_doc = json.loads(
+                            store_doc
+                        )  # Attempt to parse if it's a JSON string
+                    except (json.JSONDecodeError, TypeError):
+                        logger.error(f"Failed to parse store_doc: {store_doc}")
+                        continue  # Skip this problematic document
+
+                if isinstance(
+                    store_doc, dict
+                ):  # Check again after potential parsing
+                    store_doc["vector_distance"] = dist
+                else:
+                    # if still not a dict, this means it's a problematic entry
+                    logger.error(
+                        f"store_doc is not a dict after parsing attempt: {store_doc}"
+                    )
+                    continue

                 # Apply value filters if needed
                 if op.filter:
```
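
The new defensive block tolerates `JSON.GET` replies that arrive as raw JSON strings instead of dicts. The same normalization, sketched as a standalone helper (`coerce_store_doc` is a hypothetical name):

```python
import json
import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)


def coerce_store_doc(store_doc: Any) -> Optional[dict]:
    """Return the doc as a dict, or None if it cannot be recovered."""
    if isinstance(store_doc, dict):
        return store_doc
    try:
        parsed = json.loads(store_doc)  # attempt to parse a JSON string or bytes
    except (json.JSONDecodeError, TypeError):
        logger.error(f"Failed to parse store_doc: {store_doc}")
        return None
    return parsed if isinstance(parsed, dict) else None
```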

```diff
@@ -542,14 +555,16 @@ def _batch_search_ops(
                 if self.cluster_mode:
                     for key in refresh_keys:
                         ttl = self._redis.ttl(key)
-                        if ttl > 0:
+                        if ttl > 0:  # type: ignore
                             self._redis.expire(key, ttl_seconds)
                 else:
                     pipeline = self._redis.pipeline(transaction=True)
                     for key in refresh_keys:
                         # Only refresh TTL if the key exists and has a TTL
                         ttl = self._redis.ttl(key)
-                        if ttl > 0:  # Only refresh if key exists and has TTL
+                        if (
+                            ttl > 0
+                        ):  # Only refresh if key exists and has TTL  # type: ignore
                             pipeline.expire(key, ttl_seconds)
                     if pipeline.command_stack:
                         pipeline.execute()
@@ -595,8 +610,6 @@ def _batch_search_ops(

             items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))

-            # Note: Pagination is now handled by Redis, no need to slice items manually
-
             # Refresh TTL if requested
             if op.refresh_ttl and refresh_keys and self.ttl_config:
                 # Get default TTL from config
@@ -609,14 +622,16 @@ def _batch_search_ops(
                 if self.cluster_mode:
                     for key in refresh_keys:
                         ttl = self._redis.ttl(key)
-                        if ttl > 0:
+                        if ttl > 0:  # type: ignore
                             self._redis.expire(key, ttl_seconds)
                 else:
                     pipeline = self._redis.pipeline(transaction=True)
                     for key in refresh_keys:
                         # Only refresh TTL if the key exists and has a TTL
                         ttl = self._redis.ttl(key)
-                        if ttl > 0:  # Only refresh if key exists and has TTL
+                        if (
+                            ttl > 0
+                        ):  # Only refresh if key exists and has TTL  # type: ignore
                             pipeline.expire(key, ttl_seconds)
                     if pipeline.command_stack:
                         pipeline.execute()
```
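
Both refresh hunks follow one pattern: probe `TTL` per key, and only re-`EXPIRE` keys that exist and already carry a TTL; cluster mode sends commands directly, standalone mode batches them. A sketch of that pattern against a redis-py client (`refresh_ttls` is an illustrative helper):

```python
from typing import Iterable

from redis import Redis


def refresh_ttls(
    client: Redis, keys: Iterable[str], ttl_seconds: int, cluster_mode: bool
) -> None:
    if cluster_mode:
        # Cluster: issue TTL/EXPIRE per key to avoid cross-slot pipelines.
        for key in keys:
            if client.ttl(key) > 0:  # only keys that exist and already expire
                client.expire(key, ttl_seconds)
    else:
        # Standalone: batch the EXPIREs; the TTL probes still go out one by one.
        pipe = client.pipeline(transaction=True)
        for key in keys:
            if client.ttl(key) > 0:
                pipe.expire(key, ttl_seconds)
        if pipe.command_stack:  # skip the round trip when nothing was queued
            pipe.execute()
```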

langgraph/store/redis/aio.py

Lines changed: 98 additions & 40 deletions

```diff
@@ -244,19 +244,12 @@ async def _apply_ttl_to_keys(
             pipeline = self._redis.pipeline(transaction=True)

             # Set TTL for main key
-            # Use MagicMock in tests to avoid coroutine warning
-            expire_result = pipeline.expire(main_key, ttl_seconds)
-            # If expire returns a coroutine (in tests), await it
-            if hasattr(expire_result, "__await__"):
-                await expire_result
+            pipeline.expire(main_key, ttl_seconds)

             # Set TTL for related keys
             if related_keys:  # Check if related_keys is not None
                 for key in related_keys:
-                    expire_result = pipeline.expire(key, ttl_seconds)
-                    # If expire returns a coroutine (in tests), await it
-                    if hasattr(expire_result, "__await__"):
-                        await expire_result
+                    pipeline.expire(key, ttl_seconds)

             await pipeline.execute()
```
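
The removed `__await__` dance was compensating for test mocks. On a real `redis.asyncio` pipeline, command methods such as `expire()` enqueue synchronously and return the pipeline itself, so only `execute()` needs an `await`. A minimal sketch, assuming that queuing behavior:

```python
from redis.asyncio import Redis


async def apply_ttl(
    client: Redis, main_key: str, related_keys: list[str], ttl_seconds: int
) -> None:
    pipe = client.pipeline(transaction=True)
    pipe.expire(main_key, ttl_seconds)  # enqueues; returns the pipeline, not a coroutine
    for key in related_keys:
        pipe.expire(key, ttl_seconds)
    await pipe.execute()  # the single await: one round trip for all EXPIREs
```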


```diff
@@ -739,50 +732,43 @@ async def _batch_search_ops(
                 result_map = {}

                 if self.cluster_mode:
-                    store_docs_list = []
-                    for (
-                        doc_vr
-                    ) in (
-                        vector_results_docs
-                    ):  # doc_vr is now an individual doc from the list
-                        doc_id_vr = (
-                            doc_vr.get("id")
-                            if isinstance(doc_vr, dict)
-                            else getattr(doc_vr, "id", None)
+                    store_docs = []
+                    for doc in vector_results_docs:
+                        doc_id = (
+                            doc.get("id")
+                            if isinstance(doc, dict)
+                            else getattr(doc, "id", None)
                         )
-                        if doc_id_vr:
-                            doc_uuid_vr = doc_id_vr.split(":")[1]
-                            store_key_vr = (
-                                f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid_vr}"
-                            )
-                            result_map[store_key_vr] = doc_vr
+                        if doc_id:
+                            doc_uuid = doc_id.split(":")[1]
+                            store_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
+                            result_map[store_key] = doc
                             # Fetch individually in cluster mode
-                            store_doc_item = await self._redis.json().get(store_key_vr)
-                            store_docs_list.append(store_doc_item)
-                    store_docs_raw = store_docs_list
+                            store_doc_item = await self._redis.json().get(store_key)
+                            store_docs.append(store_doc_item)
+                    store_docs_raw = store_docs
                 else:
                     pipeline = self._redis.pipeline(transaction=False)
                     for (
-                        doc_vr
+                        doc
                     ) in (
                         vector_results_docs
                     ):  # doc_vr is now an individual doc from the list
-                        doc_id_vr = (
-                            doc_vr.get("id")
-                            if isinstance(doc_vr, dict)
-                            else getattr(doc_vr, "id", None)
+                        doc_id = (
+                            doc.get("id")
+                            if isinstance(doc, dict)
+                            else getattr(doc, "id", None)
                         )
-                        if doc_id_vr:
-                            doc_uuid_vr = doc_id_vr.split(":")[1]
-                            store_key_vr = (
-                                f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid_vr}"
-                            )
-                            result_map[store_key_vr] = doc_vr
-                            pipeline.json().get(store_key_vr)
+                        if doc_id:
+                            doc_uuid = doc_id.split(":")[1]
+                            store_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
+                            result_map[store_key] = doc
+                            pipeline.json().get(store_key)
                     store_docs_raw = await pipeline.execute()

                 # Process results maintaining order and applying filters
                 items = []
+                refresh_keys = []  # Track keys that need TTL refreshed
                 store_docs_iter = iter(store_docs_raw)

                 for store_key in result_map.keys():
```
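
The async fetch now mirrors the sync store: in cluster mode each `JSON.GET` is awaited individually, otherwise the gets queue on a non-transactional pipeline and resolve in one `execute()`. A sketch with `redis.asyncio` (`fetch_store_docs_async` is an illustrative name):

```python
from typing import Any, List

from redis.asyncio import Redis


async def fetch_store_docs_async(
    client: Redis, keys: List[str], cluster_mode: bool
) -> List[Any]:
    if cluster_mode:
        # One awaited JSON.GET per key; safe across cluster slots.
        return [await client.json().get(key) for key in keys]
    pipe = client.pipeline(transaction=False)  # non-transactional batch
    for key in keys:
        pipe.json().get(key)  # enqueues synchronously
    return await pipe.execute()
```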

```diff
@@ -834,13 +820,48 @@ async def _batch_search_ops(
                         if not matches:
                             continue

+                    # If refresh_ttl is true, add to list for refreshing
+                    if op.refresh_ttl:
+                        refresh_keys.append(store_key)
+                        # Also find associated vector keys with same ID
+                        doc_id = store_key.split(":")[-1]
+                        vector_key = (
+                            f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
+                        )
+                        refresh_keys.append(vector_key)
+
                     items.append(
                         _row_to_search_item(
                             _decode_ns(store_doc["prefix"]),
                             store_doc,
                             score=score,
                         )
                     )
+
+                # Refresh TTL if requested
+                if op.refresh_ttl and refresh_keys and self.ttl_config:
+                    # Get default TTL from config
+                    ttl_minutes = None
+                    if "default_ttl" in self.ttl_config:
+                        ttl_minutes = self.ttl_config.get("default_ttl")
+
+                    if ttl_minutes is not None:
+                        ttl_seconds = int(ttl_minutes * 60)
+                        if self.cluster_mode:
+                            for key in refresh_keys:
+                                ttl = await self._redis.ttl(key)
+                                if ttl > 0:
+                                    await self._redis.expire(key, ttl_seconds)
+                        else:
+                            pipeline = self._redis.pipeline(transaction=True)
+                            for key in refresh_keys:
+                                # Only refresh TTL if the key exists and has a TTL
+                                ttl = await self._redis.ttl(key)
+                                if ttl > 0:  # Only refresh if key exists and has TTL
+                                    pipeline.expire(key, ttl_seconds)
+                            if pipeline.command_stack:
+                                await pipeline.execute()
+
                 results[idx] = items

             else:
```
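
The async refresh block awaits each `TTL` probe in both branches; only the `EXPIRE`s batch in the non-cluster path. Sketched as a standalone coroutine, assuming `redis.asyncio` and the config's minutes-based TTL:

```python
from typing import Iterable

from redis.asyncio import Redis


async def refresh_ttls_async(
    client: Redis, keys: Iterable[str], ttl_minutes: float, cluster_mode: bool
) -> None:
    ttl_seconds = int(ttl_minutes * 60)  # config stores TTLs in minutes
    if cluster_mode:
        for key in keys:
            if await client.ttl(key) > 0:
                await client.expire(key, ttl_seconds)
    else:
        pipe = client.pipeline(transaction=True)
        for key in keys:
            if await client.ttl(key) > 0:  # probe per key, batch the EXPIREs
                pipe.expire(key, ttl_seconds)
        if pipe.command_stack:
            await pipe.execute()
```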

```diff
@@ -851,6 +872,7 @@ async def _batch_search_ops(
                 # Execute search with limit and offset applied by Redis
                 res = await self.store_index.search(query)
                 items = []
+                refresh_keys = []  # Track keys that need TTL refreshed

                 for doc in res.docs:
                     data = json.loads(doc.json)
@@ -869,7 +891,43 @@ async def _batch_search_ops(
                                 break
                         if not matches:
                             continue
+
+                    # If refresh_ttl is true, add the key to refresh list
+                    if op.refresh_ttl:
+                        refresh_keys.append(doc.id)
+                        # Also find associated vector keys with same ID
+                        doc_id = doc.id.split(":")[-1]
+                        vector_key = (
+                            f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
+                        )
+                        refresh_keys.append(vector_key)
+
                     items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
+
+                # Refresh TTL if requested
+                if op.refresh_ttl and refresh_keys and self.ttl_config:
+                    # Get default TTL from config
+                    ttl_minutes = None
+                    if "default_ttl" in self.ttl_config:
+                        ttl_minutes = self.ttl_config.get("default_ttl")
+
+                    if ttl_minutes is not None:
+                        ttl_seconds = int(ttl_minutes * 60)
+                        if self.cluster_mode:
+                            for key in refresh_keys:
+                                ttl = await self._redis.ttl(key)
+                                if ttl > 0:
+                                    await self._redis.expire(key, ttl_seconds)
+                        else:
+                            pipeline = self._redis.pipeline(transaction=True)
+                            for key in refresh_keys:
+                                # Only refresh TTL if the key exists and has a TTL
+                                ttl = await self._redis.ttl(key)
+                                if ttl > 0:  # Only refresh if key exists and has TTL
+                                    pipeline.expire(key, ttl_seconds)
+                            if pipeline.command_stack:
+                                await pipeline.execute()
+
                 results[idx] = items

     async def _batch_list_namespaces_ops(
```
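
Both new refresh blocks pair each store document with its vector entry via the trailing doc id. A small sketch of that derivation; the prefix values here are illustrative stand-ins for the module's constants:

```python
# Assumed values, for illustration only; the real constants live in the module.
STORE_VECTOR_PREFIX = "store_vectors"
REDIS_KEY_SEPARATOR = ":"


def paired_refresh_keys(store_key: str) -> list[str]:
    """Both the JSON doc and its vector entry share the trailing doc id."""
    doc_id = store_key.split(":")[-1]
    vector_key = f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
    return [store_key, vector_key]


assert paired_refresh_keys("store:1234") == ["store:1234", "store_vectors:1234"]
```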

tests/test_async_cluster_mode.py

Lines changed: 9 additions & 17 deletions

```diff
@@ -39,18 +39,14 @@ def pipeline(self, transaction=True):
         # print(f"AsyncMockRedis.pipeline called with transaction={transaction}")
         self.pipeline_calls.append({"transaction": transaction})
         mock_pipeline = AsyncMock()  # Use AsyncMock for awaitable methods
+        mock_pipeline.expire = MagicMock(return_value=True)
+        mock_pipeline.delete = MagicMock(return_value=1)
         mock_pipeline.execute = AsyncMock(return_value=[])
-        mock_pipeline.expire = AsyncMock(return_value=True)
-        mock_pipeline.delete = AsyncMock(return_value=1)

         # Mock json().get() behavior within pipeline
         mock_json_pipeline = AsyncMock()
-        mock_json_pipeline.get = AsyncMock(
-            return_value={"key": "mock_key", "value": {"data": "mock_data"}}
-        )
-        mock_pipeline.json = MagicMock(
-            return_value=mock_json_pipeline
-        )  # json() returns a mock that has async get
+        mock_json_pipeline.get = MagicMock()
+        mock_pipeline.json = MagicMock(return_value=mock_json_pipeline)
         return mock_pipeline

     async def expire(self, key, ttl):
```
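
The test fixes converge on one mock shape: pipeline command methods become `MagicMock` because a real pipeline queues them synchronously, while `execute()` stays `AsyncMock` because it is the only call that gets awaited. A sketch of that construction, assuming only `unittest.mock`:

```python
from unittest.mock import AsyncMock, MagicMock


def make_mock_pipeline() -> MagicMock:
    pipe = MagicMock()
    pipe.expire = MagicMock(return_value=True)   # sync enqueue, like redis-py
    pipe.delete = MagicMock(return_value=1)
    json_cmds = MagicMock()
    json_cmds.get = MagicMock()                  # JSON.GET also queues synchronously
    pipe.json = MagicMock(return_value=json_cmds)
    pipe.execute = AsyncMock(return_value=[])    # only execute() is awaited
    return pipe
```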

```diff
@@ -101,15 +97,13 @@ def __init__(self, *args, **kwargs):
     def pipeline(self, transaction=True):
         # print(f"AsyncMockRedisCluster.pipeline called with transaction={transaction}")
         self.pipeline_calls.append({"transaction": transaction})
-        mock_pipeline = AsyncMock()
+        mock_pipeline = MagicMock()
         mock_pipeline.execute = AsyncMock(return_value=[])
-        mock_pipeline.expire = AsyncMock(return_value=True)
-        mock_pipeline.delete = AsyncMock(return_value=1)
+        mock_pipeline.expire = MagicMock(return_value=True)
+        mock_pipeline.delete = MagicMock(return_value=1)

-        mock_json_pipeline = AsyncMock()
-        mock_json_pipeline.get = AsyncMock(
-            return_value={"key": "mock_key", "value": {"data": "mock_data"}}
-        )
+        mock_json_pipeline = MagicMock()
+        mock_json_pipeline.get = MagicMock()
         mock_pipeline.json = MagicMock(return_value=mock_json_pipeline)
         return mock_pipeline
```


```diff
@@ -166,9 +160,7 @@ async def test_async_cluster_mode_behavior_differs(
 ):
     """Test that AsyncRedisStore behavior differs for cluster vs. non-cluster clients."""

-    # --- Test with AsyncMockRedisCluster (simulates cluster) ---
     async_cluster_store = AsyncRedisStore(redis_client=mock_async_redis_cluster_client)
-    # Mock indices for async_cluster_store
     mock_index_cluster = AsyncMock()
     mock_index_cluster.search = AsyncMock(return_value=MagicMock(docs=[]))
     mock_index_cluster.load = AsyncMock(return_value=None)
```
