Skip to content

Commit 2fbc882

Browse files
committed
Disable multi-key commands if client is cluster-aware
1 parent 27237ff commit 2fbc882

File tree

8 files changed

+1195
-586
lines changed

8 files changed

+1195
-586
lines changed

langgraph/store/redis/__init__.py

Lines changed: 107 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
TTLConfig,
2222
)
2323
from redis import Redis
24+
from redis.cluster import RedisCluster
2425
from redis.commands.search.query import Query
26+
from redis.exceptions import ResponseError
2527
from redisvl.index import SearchIndex
2628
from redisvl.query import FilterQuery, VectorQuery
2729
from redisvl.redis.connection import RedisConnectionFactory
@@ -40,7 +42,7 @@
4042
_namespace_to_text,
4143
_row_to_item,
4244
_row_to_search_item,
43-
get_key_with_hash_tag,
45+
logger,
4446
)
4547

4648
from .token_unescaper import TokenUnescaper
@@ -81,10 +83,11 @@ def __init__(
8183
conn: Redis,
8284
*,
8385
index: Optional[IndexConfig] = None,
84-
ttl: Optional[dict[str, Any]] = None,
86+
ttl: Optional[TTLConfig] = None,
8587
) -> None:
8688
BaseStore.__init__(self)
8789
BaseRedisStore.__init__(self, conn, index=index, ttl=ttl)
90+
self._detect_cluster_mode()
8891

8992
@classmethod
9093
@contextmanager
@@ -93,7 +96,7 @@ def from_conn_string(
9396
conn_string: str,
9497
*,
9598
index: Optional[IndexConfig] = None,
96-
ttl: Optional[dict[str, Any]] = None,
99+
ttl: Optional[TTLConfig] = None,
97100
) -> Iterator[RedisStore]:
98101
"""Create store from Redis connection string."""
99102
client = None
@@ -111,6 +114,9 @@ def from_conn_string(
111114

112115
def setup(self) -> None:
113116
"""Initialize store indices."""
117+
# Detect if we're connected to a Redis cluster
118+
self._detect_cluster_mode()
119+
114120
self.store_index.create(overwrite=False)
115121
if self.index_config:
116122
self.vector_index.create(overwrite=False)
@@ -144,6 +150,19 @@ def batch(self, ops: Iterable[Op]) -> list[Result]:
144150

145151
return results
146152

153+
def _detect_cluster_mode(self) -> None:
154+
"""Detect if the Redis client is connected to a cluster."""
155+
try:
156+
# Try to run a cluster command
157+
# This will succeed for Redis clusters and fail for non-cluster servers
158+
self._redis.cluster("info")
159+
self.cluster_mode = True
160+
logger.info("Redis cluster mode detected for RedisStore.")
161+
except (ResponseError, AttributeError):
162+
self.cluster_mode = False
163+
logger.info("Redis standalone mode detected for RedisStore.")
164+
165+
147166
def _batch_list_namespaces_ops(
148167
self,
149168
list_ops: Sequence[tuple[int, ListNamespacesOp]],
@@ -246,17 +265,22 @@ def _batch_get_ops(
246265

247266
if ttl_minutes is not None:
248267
ttl_seconds = int(ttl_minutes * 60)
249-
# In cluster mode, we must use transaction=False
250-
pipeline = self._redis.pipeline(transaction=not self.cluster_mode)
251-
252-
for keys in refresh_keys_by_idx.values():
253-
for key in keys:
254-
# Only refresh TTL if the key exists and has a TTL
255-
ttl = self._redis.ttl(key)
256-
if ttl > 0: # Only refresh if key exists and has TTL
257-
pipeline.expire(key, ttl_seconds)
258-
259-
pipeline.execute()
268+
if self.cluster_mode:
269+
for keys_to_refresh in refresh_keys_by_idx.values():
270+
for key in keys_to_refresh:
271+
ttl = self._redis.ttl(key)
272+
if ttl > 0:
273+
self._redis.expire(key, ttl_seconds)
274+
else:
275+
pipeline = self._redis.pipeline(transaction=True)
276+
for keys in refresh_keys_by_idx.values():
277+
for key in keys:
278+
# Only refresh TTL if the key exists and has a TTL
279+
ttl = self._redis.ttl(key)
280+
if ttl > 0: # Only refresh if key exists and has TTL
281+
pipeline.expire(key, ttl_seconds)
282+
if pipeline.command_stack:
283+
pipeline.execute()
260284

261285
def _batch_put_ops(
262286
self,
@@ -270,12 +294,26 @@ def _batch_put_ops(
270294
namespace = _namespace_to_text(op.namespace)
271295
query = f"@prefix:{namespace} @key:{{{_token_escaper.escape(op.key)}}}"
272296
results = self.store_index.search(query)
273-
for doc in results.docs:
274-
self._redis.delete(doc.id)
275-
if self.index_config:
276-
results = self.vector_index.search(query)
297+
298+
if self.cluster_mode:
277299
for doc in results.docs:
278300
self._redis.delete(doc.id)
301+
if self.index_config:
302+
vector_results = self.vector_index.search(query)
303+
for doc_vec in vector_results.docs:
304+
self._redis.delete(doc_vec.id)
305+
else:
306+
pipeline = self._redis.pipeline(transaction=True)
307+
for doc in results.docs:
308+
pipeline.delete(doc.id)
309+
310+
if self.index_config:
311+
vector_results = self.vector_index.search(query)
312+
for doc_vec in vector_results.docs:
313+
pipeline.delete(doc_vec.id)
314+
315+
if pipeline.command_stack:
316+
pipeline.execute()
279317

280318
# Now handle new document creation
281319
doc_ids: dict[tuple[str, str], str] = {}
@@ -293,12 +331,7 @@ def _batch_put_ops(
293331
doc_ids[(namespace, op.key)] = generated_doc_id
294332
# Track TTL for this document if specified
295333
if hasattr(op, "ttl") and op.ttl is not None:
296-
main_key = get_key_with_hash_tag(
297-
STORE_PREFIX,
298-
REDIS_KEY_SEPARATOR,
299-
generated_doc_id,
300-
self.cluster_mode,
301-
)
334+
main_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{generated_doc_id}"
302335
ttl_tracking[main_key] = ([], op.ttl)
303336

304337
# Load store docs with explicit keys
@@ -312,13 +345,16 @@ def _batch_put_ops(
312345
doc.pop("expires_at", None)
313346

314347
store_docs.append(doc)
315-
redis_key = get_key_with_hash_tag(
316-
STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode
317-
)
348+
redis_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
318349
store_keys.append(redis_key)
319350

320351
if store_docs:
321-
self.store_index.load(store_docs, keys=store_keys)
352+
if self.cluster_mode:
353+
# Load individually if cluster
354+
for i, store_doc_item in enumerate(store_docs):
355+
self.store_index.load([store_doc_item], keys=[store_keys[i]])
356+
else:
357+
self.store_index.load(store_docs, keys=store_keys)
322358

323359
# Handle vector embeddings with same IDs
324360
if embedding_request and self.embeddings:
@@ -344,20 +380,21 @@ def _batch_put_ops(
344380
"updated_at": datetime.now(timezone.utc).timestamp(),
345381
}
346382
)
347-
vector_key = get_key_with_hash_tag(
348-
STORE_VECTOR_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode
349-
)
350-
vector_keys.append(vector_key)
383+
redis_vector_key = f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
384+
vector_keys.append(redis_vector_key)
351385

352386
# Add this vector key to the related keys list for TTL
353-
main_key = get_key_with_hash_tag(
354-
STORE_PREFIX, REDIS_KEY_SEPARATOR, doc_id, self.cluster_mode
355-
)
387+
main_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
356388
if main_key in ttl_tracking:
357-
ttl_tracking[main_key][0].append(vector_key)
389+
ttl_tracking[main_key][0].append(redis_vector_key)
358390

359391
if vector_docs:
360-
self.vector_index.load(vector_docs, keys=vector_keys)
392+
if self.cluster_mode:
393+
# Load individually if cluster
394+
for i, vector_doc_item in enumerate(vector_docs):
395+
self.vector_index.load([vector_doc_item], keys=[vector_keys[i]])
396+
else:
397+
self.vector_index.load(vector_docs, keys=vector_keys)
361398

362399
# Now apply TTLs after all documents are loaded
363400
for main_key, (related_keys, ttl_minutes) in ttl_tracking.items():
@@ -407,12 +444,7 @@ def _batch_search_ops(
407444
if doc_id:
408445
# Convert vector:ID to store:ID
409446
doc_uuid = doc_id.split(":")[1]
410-
store_key = get_key_with_hash_tag(
411-
STORE_PREFIX,
412-
REDIS_KEY_SEPARATOR,
413-
doc_uuid,
414-
self.cluster_mode,
415-
)
447+
store_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
416448
result_map[store_key] = doc
417449
pipe.json().get(store_key)
418450

@@ -457,11 +489,8 @@ def _batch_search_ops(
457489
refresh_keys.append(store_key)
458490
# Also find associated vector keys with same ID
459491
doc_id = store_key.split(":")[-1]
460-
vector_key = get_key_with_hash_tag(
461-
STORE_VECTOR_PREFIX,
462-
REDIS_KEY_SEPARATOR,
463-
doc_id,
464-
self.cluster_mode,
492+
vector_key = (
493+
f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
465494
)
466495
refresh_keys.append(vector_key)
467496

@@ -482,16 +511,20 @@ def _batch_search_ops(
482511

483512
if ttl_minutes is not None:
484513
ttl_seconds = int(ttl_minutes * 60)
485-
# In cluster mode, we must use transaction=False
486-
pipeline = self._redis.pipeline(
487-
transaction=not self.cluster_mode
488-
)
489-
for key in refresh_keys:
490-
# Only refresh TTL if the key exists and has a TTL
491-
ttl = self._redis.ttl(key)
492-
if ttl > 0: # Only refresh if key exists and has TTL
493-
pipeline.expire(key, ttl_seconds)
494-
pipeline.execute()
514+
if self.cluster_mode:
515+
for key in refresh_keys:
516+
ttl = self._redis.ttl(key)
517+
if ttl > 0:
518+
self._redis.expire(key, ttl_seconds)
519+
else:
520+
pipeline = self._redis.pipeline(transaction=True)
521+
for key in refresh_keys:
522+
# Only refresh TTL if the key exists and has a TTL
523+
ttl = self._redis.ttl(key)
524+
if ttl > 0: # Only refresh if key exists and has TTL
525+
pipeline.expire(key, ttl_seconds)
526+
if pipeline.command_stack:
527+
pipeline.execute()
495528

496529
results[idx] = items
497530
else:
@@ -527,11 +560,8 @@ def _batch_search_ops(
527560
refresh_keys.append(doc.id)
528561
# Also find associated vector keys with same ID
529562
doc_id = doc.id.split(":")[-1]
530-
vector_key = get_key_with_hash_tag(
531-
STORE_VECTOR_PREFIX,
532-
REDIS_KEY_SEPARATOR,
533-
doc_id,
534-
self.cluster_mode,
563+
vector_key = (
564+
f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
535565
)
536566
refresh_keys.append(vector_key)
537567

@@ -548,16 +578,20 @@ def _batch_search_ops(
548578

549579
if ttl_minutes is not None:
550580
ttl_seconds = int(ttl_minutes * 60)
551-
# In cluster mode, we must use transaction=False
552-
pipeline = self._redis.pipeline(
553-
transaction=not self.cluster_mode
554-
)
555-
for key in refresh_keys:
556-
# Only refresh TTL if the key exists and has a TTL
557-
ttl = self._redis.ttl(key)
558-
if ttl > 0: # Only refresh if key exists and has TTL
559-
pipeline.expire(key, ttl_seconds)
560-
pipeline.execute()
581+
if self.cluster_mode:
582+
for key in refresh_keys:
583+
ttl = self._redis.ttl(key)
584+
if ttl > 0:
585+
self._redis.expire(key, ttl_seconds)
586+
else:
587+
pipeline = self._redis.pipeline(transaction=True)
588+
for key in refresh_keys:
589+
# Only refresh TTL if the key exists and has a TTL
590+
ttl = self._redis.ttl(key)
591+
if ttl > 0: # Only refresh if key exists and has TTL
592+
pipeline.expire(key, ttl_seconds)
593+
if pipeline.command_stack:
594+
pipeline.execute()
561595

562596
results[idx] = items
563597

0 commit comments

Comments
 (0)