
Commit d2bff7f

Stop using aclose
1 parent 06c41b4 commit d2bff7f

5 files changed: +130 additions, −32 deletions

langgraph/store/redis/__init__.py

Lines changed: 40 additions & 10 deletions

@@ -293,7 +293,12 @@ def _batch_put_ops(
             doc_ids[(namespace, op.key)] = generated_doc_id
             # Track TTL for this document if specified
             if hasattr(op, "ttl") and op.ttl is not None:
-                main_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, generated_doc_id, self.cluster_mode)
+                main_key = get_key_with_hash_tag(
+                    STORE_PREFIX,
+                    REDIS_KEY_SEPARATOR,
+                    generated_doc_id,
+                    self.cluster_mode,
+                )
                 ttl_tracking[main_key] = ([], op.ttl)

         # Load store docs with explicit keys
@@ -307,7 +312,9 @@ def _batch_put_ops(
             doc.pop("expires_at", None)

             store_docs.append(doc)
-            redis_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
+            redis_key = get_key_with_hash_tag(
+                STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode
+            )
             store_keys.append(redis_key)

         if store_docs:
@@ -337,11 +344,15 @@ def _batch_put_ops(
                     "updated_at": datetime.now(timezone.utc).timestamp(),
                 }
             )
-            vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
+            vector_key = get_key_with_hash_tag(
+                STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode
+            )
             vector_keys.append(vector_key)

             # Add this vector key to the related keys list for TTL
-            main_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
+            main_key = get_key_with_hash_tag(
+                STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode
+            )
             if main_key in ttl_tracking:
                 ttl_tracking[main_key][0].append(vector_key)

@@ -395,8 +406,13 @@ def _batch_search_ops(
                 )
                 if doc_id:
                     # Convert vector:ID to store:ID
-                    doc_uuid = doc_id.split(':')[1]
-                    store_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_uuid, self.cluster_mode)
+                    doc_uuid = doc_id.split(":")[1]
+                    store_key = get_key_with_hash_tag(
+                        STORE_PREFIX,
+                        REDIS_KEY_SEPARATOR,
+                        doc_uuid,
+                        self.cluster_mode,
+                    )
                     result_map[store_key] = doc
                     pipe.json().get(store_key)

@@ -441,7 +457,12 @@ def _batch_search_ops(
                         refresh_keys.append(store_key)
                         # Also find associated vector keys with same ID
                         doc_id = store_key.split(":")[-1]
-                        vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
+                        vector_key = get_key_with_hash_tag(
+                            STORE_VECTOR_PREFIX,
+                            REDIS_KEY_SEPARATOR,
+                            doc_id,
+                            self.cluster_mode,
+                        )
                         refresh_keys.append(vector_key)

                 items.append(
@@ -462,7 +483,9 @@ def _batch_search_ops(
                 if ttl_minutes is not None:
                     ttl_seconds = int(ttl_minutes * 60)
                     # In cluster mode, we must use transaction=False
-                    pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
+                    pipeline = self._redis.pipeline(
+                        transaction=not self.cluster_mode
+                    )
                     for key in refresh_keys:
                         # Only refresh TTL if the key exists and has a TTL
                         ttl = self._redis.ttl(key)
@@ -504,7 +527,12 @@ def _batch_search_ops(
                         refresh_keys.append(doc.id)
                         # Also find associated vector keys with same ID
                         doc_id = doc.id.split(":")[-1]
-                        vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
+                        vector_key = get_key_with_hash_tag(
+                            STORE_VECTOR_PREFIX,
+                            REDIS_KEY_SEPARATOR,
+                            doc_id,
+                            self.cluster_mode,
+                        )
                         refresh_keys.append(vector_key)

                 items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
@@ -521,7 +549,9 @@ def _batch_search_ops(
             if ttl_minutes is not None:
                 ttl_seconds = int(ttl_minutes * 60)
                 # In cluster mode, we must use transaction=False
-                pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
+                pipeline = self._redis.pipeline(
+                    transaction=not self.cluster_mode
+                )
                 for key in refresh_keys:
                     # Only refresh TTL if the key exists and has a TTL
                     ttl = self._redis.ttl(key)
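
For reference, a minimal standalone sketch of the TTL-refresh pattern these hunks reformat: a non-transactional pipeline (Redis Cluster does not allow a MULTI/EXEC transaction to span hash slots, hence transaction=False there) that re-applies an expiry only to keys that already carry one. The client setup and key names below are illustrative, not taken from the store:

from redis import Redis

client = Redis()  # illustrative: assumes a reachable Redis instance
refresh_keys = ["store:{abc}", "store_vectors:{abc}"]  # hypothetical keys
ttl_seconds = 60

pipeline = client.pipeline(transaction=False)
for key in refresh_keys:
    # TTL returns -2 for a missing key and -1 for a key without an expiry,
    # so only keys that already have a TTL are refreshed.
    if client.ttl(key) >= 0:
        pipeline.expire(key, ttl_seconds)
pipeline.execute()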

langgraph/store/redis/aio.py

Lines changed: 28 additions & 7 deletions

@@ -430,7 +430,12 @@ async def _batch_get_ops(

                     # Also add vector keys for the same document
                     doc_uuid = doc_id.split(":")[-1]
-                    vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_uuid, self.cluster_mode)
+                    vector_key = get_key_with_hash_tag(
+                        STORE_VECTOR_PREFIX,
+                        REDIS_KEY_SEPARATOR,
+                        doc_uuid,
+                        self.cluster_mode,
+                    )
                     refresh_keys_by_idx[idx].append(vector_key)

         # Now refresh TTLs for any keys that need it
@@ -574,7 +579,12 @@ async def _batch_put_ops(
             doc_ids[(namespace, op.key)] = generated_doc_id
             # Track TTL for this document if specified
             if hasattr(op, "ttl") and op.ttl is not None:
-                main_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, generated_doc_id, self.cluster_mode)
+                main_key = get_key_with_hash_tag(
+                    STORE_PREFIX,
+                    REDIS_KEY_SEPARATOR,
+                    generated_doc_id,
+                    self.cluster_mode,
+                )
                 ttl_tracking[main_key] = ([], op.ttl)

         # Load store docs with explicit keys
@@ -588,7 +598,9 @@ async def _batch_put_ops(
             doc.pop("expires_at", None)

             store_docs.append(doc)
-            redis_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
+            redis_key = get_key_with_hash_tag(
+                STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode
+            )
             store_keys.append(redis_key)

         if store_docs:
@@ -618,11 +630,15 @@ async def _batch_put_ops(
                     "updated_at": datetime.now(timezone.utc).timestamp(),
                 }
             )
-            vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
+            vector_key = get_key_with_hash_tag(
+                STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode
+            )
             vector_keys.append(vector_key)

             # Add this vector key to the related keys list for TTL
-            main_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
+            main_key = get_key_with_hash_tag(
+                STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode
+            )
             if main_key in ttl_tracking:
                 ttl_tracking[main_key][0].append(vector_key)

@@ -676,8 +692,13 @@ async def _batch_search_ops(
                 )
                 if doc_id:
                     # Convert vector:ID to store:ID
-                    doc_uuid = doc_id.split(':')[1]
-                    store_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_uuid, self.cluster_mode)
+                    doc_uuid = doc_id.split(":")[1]
+                    store_key = get_key_with_hash_tag(
+                        STORE_PREFIX,
+                        REDIS_KEY_SEPARATOR,
+                        doc_uuid,
+                        self.cluster_mode,
+                    )
                     result_map[store_key] = doc
                     pipeline.json().get(store_key)
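
The async store mirrors the same TTL-refresh pattern with redis-py's asyncio client. A hedged sketch with illustrative names; async pipeline commands are queued synchronously and only sent when execute() is awaited:

import asyncio

from redis.asyncio import Redis


async def refresh_ttls() -> None:
    client = Redis()  # illustrative: assumes a reachable Redis instance
    refresh_keys = ["store:{abc}", "store_vectors:{abc}"]  # hypothetical keys
    ttl_seconds = 60

    pipeline = client.pipeline(transaction=False)
    for key in refresh_keys:
        # Skip missing keys (-2) and keys without an expiry (-1).
        if await client.ttl(key) >= 0:
            pipeline.expire(key, ttl_seconds)
    await pipeline.execute()


asyncio.run(refresh_ttls())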

langgraph/store/redis/base.py

Lines changed: 7 additions & 2 deletions

@@ -41,7 +41,10 @@
 STORE_PREFIX = "store"
 STORE_VECTOR_PREFIX = "store_vectors"

-def get_key_with_hash_tag(prefix: str, separator: str, id_value: str, use_hash_tag: bool = False) -> str:
+
+def get_key_with_hash_tag(
+    prefix: str, separator: str, id_value: str, use_hash_tag: bool = False
+) -> str:
     """Create a Redis key with optional hash tag for cluster mode.

     In Redis Cluster, keys with hash tags ensure they're stored in the same hash slot.
@@ -53,6 +56,7 @@ def get_key_with_hash_tag(prefix: str, separator: str, id_value: str, use_hash_t
     else:
         return f"{prefix}{separator}{id_value}"

+
 # Schemas for Redis Search indices
 SCHEMAS = [
     {
@@ -209,10 +213,11 @@ def __init__(

         # Detect if Redis client is a cluster client
         from redis.exceptions import ResponseError
+
         try:
             # Try to run a cluster command
             # This will succeed for cluster clients and fail for non-cluster clients
-            self._redis.cluster("info")  # type: ignore
+            self._redis.cluster("info")
             self.cluster_mode = True
             logger.info("Redis cluster mode detected")
         except (ResponseError, AttributeError):
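
The hidden body of get_key_with_hash_tag is the interesting part: presumably the cluster branch wraps the id in braces, e.g. store:{<id>}, since Redis Cluster hashes only the text inside {...} when computing a key's slot. A sketch of why that matters, assuming that brace format and redis-py's slot helper (redis.crc.key_slot, present since cluster support landed in redis-py 4.x):

from redis.crc import key_slot

doc_id = "01JD2X9Y8KQWERTY"  # hypothetical document id

# Untagged keys are hashed over the whole key, so the document key and
# its vector key generally land in different slots:
print(key_slot(f"store:{doc_id}".encode()))
print(key_slot(f"store_vectors:{doc_id}".encode()))

# With a {doc_id} hash tag only the braced part is hashed, so both keys
# share one slot and can be touched in a single cluster pipeline:
print(key_slot(("store:{" + doc_id + "}").encode()))
print(key_slot(("store_vectors:{" + doc_id + "}").encode()))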

tests/test_async_cluster_mode.py

Lines changed: 31 additions & 7 deletions

@@ -3,14 +3,20 @@
 from __future__ import annotations

 import asyncio
-import pytest
 from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
 from redis.asyncio import Redis as AsyncRedis
 from redis.exceptions import ResponseError
 from ulid import ULID

 from langgraph.store.redis import AsyncRedisStore
-from langgraph.store.redis.base import get_key_with_hash_tag, STORE_PREFIX, REDIS_KEY_SEPARATOR, STORE_VECTOR_PREFIX
+from langgraph.store.redis.base import (
+    REDIS_KEY_SEPARATOR,
+    STORE_PREFIX,
+    STORE_VECTOR_PREFIX,
+    get_key_with_hash_tag,
+)


 class MockAsyncRedisCluster(AsyncRedis):
@@ -49,28 +55,34 @@ def pipeline(self, transaction=True):
         # Mock the pipeline's execute method
         async def execute():
             return []
+
         mock_pipeline.execute = execute

         # Mock the pipeline's expire method
         async def expire(key, ttl):
             self.expire_calls.append({"key": key, "ttl": ttl})
             return True
+
         mock_pipeline.expire = expire

         # Mock the pipeline's delete method
         async def delete(key):
             self.delete_calls.append({"key": key})
             return 1
+
         mock_pipeline.delete = delete

         # Mock the pipeline's json method
         def json():
             mock_json = AsyncMock()
+
             async def get(key):
                 self.json_get_calls.append({"key": key})
                 return {"key": key.split(":")[-1], "value": {"test": "data"}}
+
             mock_json.get = get
             return mock_json
+
         mock_pipeline.json = json

         return mock_pipeline
@@ -92,9 +104,11 @@ async def ttl(self, key):
     def json(self):
         """Mock json method."""
         mock = AsyncMock()
+
         async def get(key):
             self.json_get_calls.append({"key": key})
             return {"key": key.split(":")[-1], "value": {"test": "data"}}
+
         mock.get = get
         return mock

@@ -166,27 +180,35 @@ async def test_async_hash_tag_in_keys(async_cluster_store, mock_async_redis_clus


 @pytest.mark.asyncio
-async def test_async_pipeline_transaction_false(async_cluster_store, mock_async_redis_cluster):
+async def test_async_pipeline_transaction_false(
+    async_cluster_store, mock_async_redis_cluster
+):
     """Test that pipeline is created with transaction=False in cluster mode."""
     # Apply TTL to trigger pipeline creation
     await async_cluster_store._apply_ttl_to_keys("test_key", ["related_key"], 1.0)

     # Check that pipeline was created with transaction=False
     assert len(mock_async_redis_cluster.pipeline_calls) > 0
     for call in mock_async_redis_cluster.pipeline_calls:
-        assert call["transaction"] is False, "Pipeline should be created with transaction=False in cluster mode"
+        assert (
+            call["transaction"] is False
+        ), "Pipeline should be created with transaction=False in cluster mode"

     # Put a value to trigger more pipeline usage
     await async_cluster_store.aput(("test",), "key1", {"data": "value1"})

     # Check again
     assert len(mock_async_redis_cluster.pipeline_calls) > 0
     for call in mock_async_redis_cluster.pipeline_calls:
-        assert call["transaction"] is False, "Pipeline should be created with transaction=False in cluster mode"
+        assert (
+            call["transaction"] is False
+        ), "Pipeline should be created with transaction=False in cluster mode"


 @pytest.mark.asyncio
-async def test_async_ttl_refresh_in_search(async_cluster_store, mock_async_redis_cluster):
+async def test_async_ttl_refresh_in_search(
+    async_cluster_store, mock_async_redis_cluster
+):
     """Test that TTL refresh in search uses transaction=False for pipeline in cluster mode."""
     # Clear the pipeline calls to start fresh
     mock_async_redis_cluster.pipeline_calls = []
@@ -211,7 +233,9 @@ async def test_async_ttl_refresh_in_search(async_cluster_store, mock_async_redis
         # Check that pipeline was created with transaction=False
         assert len(mock_async_redis_cluster.pipeline_calls) > 0
         for call in mock_async_redis_cluster.pipeline_calls:
-            assert call["transaction"] is False, "Pipeline should be created with transaction=False in cluster mode"
+            assert (
+                call["transaction"] is False
+            ), "Pipeline should be created with transaction=False in cluster mode"
     finally:
         # Restore the original ttl method
         mock_async_redis_cluster.ttl = original_ttl
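
The assertion style above boils down to recording the arguments pipeline() receives and checking the transaction flag afterwards. A condensed, self-contained sketch of that pattern (names are illustrative, not the suite's fixtures):

from unittest.mock import MagicMock


def test_pipeline_is_non_transactional_in_cluster_mode() -> None:
    client = MagicMock()
    cluster_mode = True  # what the store detects via CLUSTER INFO

    # The code under test builds its pipeline like this:
    client.pipeline(transaction=not cluster_mode)

    # The mock records the call, so the flag can be asserted directly:
    client.pipeline.assert_called_once_with(transaction=False)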
