Skip to content

Commit 248f447

Browse files
committed
WIP on supporting Redis Cluster
1 parent 6fed6ad commit 248f447

File tree

5 files changed

+461
-27
lines changed

5 files changed

+461
-27
lines changed

langgraph/store/redis/__init__.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
_namespace_to_text,
4141
_row_to_item,
4242
_row_to_search_item,
43+
get_key_with_hash_tag,
4344
)
4445

4546
from .token_unescaper import TokenUnescaper
@@ -245,7 +246,8 @@ def _batch_get_ops(
245246

246247
if ttl_minutes is not None:
247248
ttl_seconds = int(ttl_minutes * 60)
248-
pipeline = self._redis.pipeline()
249+
# In cluster mode, we must use transaction=False
250+
pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
249251

250252
for keys in refresh_keys_by_idx.values():
251253
for key in keys:
@@ -291,7 +293,7 @@ def _batch_put_ops(
291293
doc_ids[(namespace, op.key)] = generated_doc_id
292294
# Track TTL for this document if specified
293295
if hasattr(op, "ttl") and op.ttl is not None:
294-
main_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{generated_doc_id}"
296+
main_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, generated_doc_id, self.cluster_mode)
295297
ttl_tracking[main_key] = ([], op.ttl)
296298

297299
# Load store docs with explicit keys
@@ -305,7 +307,7 @@ def _batch_put_ops(
305307
doc.pop("expires_at", None)
306308

307309
store_docs.append(doc)
308-
redis_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
310+
redis_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
309311
store_keys.append(redis_key)
310312

311313
if store_docs:
@@ -335,11 +337,11 @@ def _batch_put_ops(
335337
"updated_at": datetime.now(timezone.utc).timestamp(),
336338
}
337339
)
338-
vector_key = f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
340+
vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
339341
vector_keys.append(vector_key)
340342

341343
# Add this vector key to the related keys list for TTL
342-
main_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
344+
main_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
343345
if main_key in ttl_tracking:
344346
ttl_tracking[main_key][0].append(vector_key)
345347

@@ -381,7 +383,8 @@ def _batch_search_ops(
381383
vector_results = self.vector_index.query(vector_query)
382384

383385
# Get matching store docs in pipeline
384-
pipe = self._redis.pipeline()
386+
# In cluster mode, we must use transaction=False
387+
pipe = self._redis.pipeline(transaction=not self.cluster_mode)
385388
result_map = {} # Map store key to vector result with distances
386389

387390
for doc in vector_results:
@@ -391,7 +394,9 @@ def _batch_search_ops(
391394
else getattr(doc, "id", None)
392395
)
393396
if doc_id:
394-
store_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id.split(':')[1]}" # Convert vector:ID to store:ID
397+
# Convert vector:ID to store:ID
398+
doc_uuid = doc_id.split(':')[1]
399+
store_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_uuid, self.cluster_mode)
395400
result_map[store_key] = doc
396401
pipe.json().get(store_key)
397402

@@ -436,9 +441,7 @@ def _batch_search_ops(
436441
refresh_keys.append(store_key)
437442
# Also find associated vector keys with same ID
438443
doc_id = store_key.split(":")[-1]
439-
vector_key = (
440-
f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
441-
)
444+
vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
442445
refresh_keys.append(vector_key)
443446

444447
items.append(
@@ -458,7 +461,8 @@ def _batch_search_ops(
458461

459462
if ttl_minutes is not None:
460463
ttl_seconds = int(ttl_minutes * 60)
461-
pipeline = self._redis.pipeline()
464+
# In cluster mode, we must use transaction=False
465+
pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
462466
for key in refresh_keys:
463467
# Only refresh TTL if the key exists and has a TTL
464468
ttl = self._redis.ttl(key)
@@ -500,9 +504,7 @@ def _batch_search_ops(
500504
refresh_keys.append(doc.id)
501505
# Also find associated vector keys with same ID
502506
doc_id = doc.id.split(":")[-1]
503-
vector_key = (
504-
f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
505-
)
507+
vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
506508
refresh_keys.append(vector_key)
507509

508510
items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
@@ -518,7 +520,8 @@ def _batch_search_ops(
518520

519521
if ttl_minutes is not None:
520522
ttl_seconds = int(ttl_minutes * 60)
521-
pipeline = self._redis.pipeline()
523+
# In cluster mode, we must use transaction=False
524+
pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
522525
for key in refresh_keys:
523526
# Only refresh TTL if the key exists and has a TTL
524527
ttl = self._redis.ttl(key)

langgraph/store/redis/aio.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
_namespace_to_text,
4545
_row_to_item,
4646
_row_to_search_item,
47+
get_key_with_hash_tag,
4748
)
4849

4950
from .token_unescaper import TokenUnescaper
@@ -209,7 +210,8 @@ async def _apply_ttl_to_keys(
209210

210211
if ttl_minutes is not None:
211212
ttl_seconds = int(ttl_minutes * 60)
212-
pipeline = self._redis.pipeline()
213+
# In cluster mode, we must use transaction=False
214+
pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
213215

214216
# Set TTL for main key
215217
await pipeline.expire(main_key, ttl_seconds)
@@ -428,9 +430,7 @@ async def _batch_get_ops(
428430

429431
# Also add vector keys for the same document
430432
doc_uuid = doc_id.split(":")[-1]
431-
vector_key = (
432-
f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
433-
)
433+
vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_uuid, self.cluster_mode)
434434
refresh_keys_by_idx[idx].append(vector_key)
435435

436436
# Now refresh TTLs for any keys that need it
@@ -442,7 +442,8 @@ async def _batch_get_ops(
442442

443443
if ttl_minutes is not None:
444444
ttl_seconds = int(ttl_minutes * 60)
445-
pipeline = self._redis.pipeline()
445+
# In cluster mode, we must use transaction=False
446+
pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
446447

447448
for keys in refresh_keys_by_idx.values():
448449
for key in keys:
@@ -544,7 +545,8 @@ async def _batch_put_ops(
544545
namespace = _namespace_to_text(op.namespace)
545546
query = f"@prefix:{namespace} @key:{{{_token_escaper.escape(op.key)}}}"
546547
results = await self.store_index.search(query)
547-
pipeline = self._redis.pipeline()
548+
# In cluster mode, we must use transaction=False
549+
pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
548550
for doc in results.docs:
549551
pipeline.delete(doc.id)
550552

@@ -572,7 +574,7 @@ async def _batch_put_ops(
572574
doc_ids[(namespace, op.key)] = generated_doc_id
573575
# Track TTL for this document if specified
574576
if hasattr(op, "ttl") and op.ttl is not None:
575-
main_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{generated_doc_id}"
577+
main_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, generated_doc_id, self.cluster_mode)
576578
ttl_tracking[main_key] = ([], op.ttl)
577579

578580
# Load store docs with explicit keys
@@ -586,7 +588,7 @@ async def _batch_put_ops(
586588
doc.pop("expires_at", None)
587589

588590
store_docs.append(doc)
589-
redis_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
591+
redis_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
590592
store_keys.append(redis_key)
591593

592594
if store_docs:
@@ -616,11 +618,11 @@ async def _batch_put_ops(
616618
"updated_at": datetime.now(timezone.utc).timestamp(),
617619
}
618620
)
619-
vector_key = f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
621+
vector_key = get_key_with_hash_tag(STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
620622
vector_keys.append(vector_key)
621623

622624
# Add this vector key to the related keys list for TTL
623-
main_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
625+
main_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode)
624626
if main_key in ttl_tracking:
625627
ttl_tracking[main_key][0].append(vector_key)
626628

@@ -673,7 +675,9 @@ async def _batch_search_ops(
673675
else getattr(doc, "id", None)
674676
)
675677
if doc_id:
676-
store_key = f"store:{doc_id.split(':')[1]}" # Convert vector:ID to store:ID
678+
# Convert vector:ID to store:ID
679+
doc_uuid = doc_id.split(':')[1]
680+
store_key = get_key_with_hash_tag(STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_uuid, self.cluster_mode)
677681
result_map[store_key] = doc
678682
pipeline.json().get(store_key)
679683

langgraph/store/redis/base.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,18 @@
4141
STORE_PREFIX = "store"
4242
STORE_VECTOR_PREFIX = "store_vectors"
4343

44+
def get_key_with_hash_tag(prefix: str, separator: str, id_value: str, use_hash_tag: bool = False) -> str:
45+
"""Create a Redis key with optional hash tag for cluster mode.
46+
47+
In Redis Cluster, keys with hash tags ensure they're stored in the same hash slot.
48+
Hash tags are substrings enclosed in curly braces {}.
49+
"""
50+
if use_hash_tag:
51+
# Use hash tag to ensure related keys are in the same slot
52+
return f"{prefix}{separator}{{{id_value}}}"
53+
else:
54+
return f"{prefix}{separator}{id_value}"
55+
4456
# Schemas for Redis Search indices
4557
SCHEMAS = [
4658
{
@@ -106,6 +118,7 @@ class BaseRedisStore(Generic[RedisClientType, IndexType]):
106118
vector_index: IndexType
107119
_ttl_sweeper_thread: Optional[threading.Thread] = None
108120
_ttl_stop_event: threading.Event | None = None
121+
cluster_mode: bool = False
109122

110123
SCHEMAS = SCHEMAS
111124

@@ -132,7 +145,8 @@ def _apply_ttl_to_keys(
132145

133146
if ttl_minutes is not None:
134147
ttl_seconds = int(ttl_minutes * 60)
135-
pipeline = self._redis.pipeline()
148+
# In cluster mode, we must use transaction=False
149+
pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
136150

137151
# Set TTL for main key
138152
pipeline.expire(main_key, ttl_seconds)
@@ -192,6 +206,18 @@ def __init__(
192206
self.index_config = index
193207
self.ttl_config = ttl # type: ignore
194208
self.embeddings: Optional[Embeddings] = None
209+
210+
# Detect if Redis client is a cluster client
211+
from redis.exceptions import ResponseError
212+
try:
213+
# Try to run a cluster command
214+
# This will succeed for cluster clients and fail for non-cluster clients
215+
self._redis.cluster("info") # type: ignore
216+
self.cluster_mode = True
217+
logger.info("Redis cluster mode detected")
218+
except (ResponseError, AttributeError):
219+
self.cluster_mode = False
220+
logger.debug("Redis standalone mode detected")
195221
if self.index_config:
196222
self.index_config = self.index_config.copy()
197223
self.embeddings = ensure_embeddings(

0 commit comments

Comments
 (0)