
Commit dbbbb45

Use watches on compare-and-set pipelines

1 parent e8b89e9
2 files changed: 49 additions, 62 deletions

2 files changed

+49
-62
lines changed

langgraph/checkpoint/redis/aio.py — 21 additions, 30 deletions

@@ -10,6 +10,7 @@
 from typing import Any, List, Optional, Sequence, Tuple, Type, cast

 from langchain_core.runnables import RunnableConfig
+from redis import WatchError
 from redisvl.index import AsyncSearchIndex
 from redisvl.query import FilterQuery
 from redisvl.query.filter import Num, Tag
@@ -418,38 +419,28 @@ async def aput_writes(
             }
             writes_objects.append(write_obj)

-        # For each write, check existence and then perform appropriate operation
-        async with self.checkpoints_index.client.pipeline(
-            transaction=False
-        ) as pipeline:
-            for write_obj in writes_objects:
-                key = self._make_redis_checkpoint_writes_key(
-                    thread_id,
-                    checkpoint_ns,
-                    checkpoint_id,
-                    task_id,
-                    write_obj["idx"],
-                )
-
-                # First check if key exists
-                key_exists = await self._redis.exists(key) == 1
-
-                if all(w[0] in WRITES_IDX_MAP for w in writes):
-                    # UPSERT case - only update specific fields
-                    if key_exists:
-                        # Update only channel, type, and blob fields
-                        pipeline.json().set(key, "$.channel", write_obj["channel"])
-                        pipeline.json().set(key, "$.type", write_obj["type"])
-                        pipeline.json().set(key, "$.blob", write_obj["blob"])
+        upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
+        for write_obj in writes_objects:
+            key = self._make_redis_checkpoint_writes_key(
+                thread_id,
+                checkpoint_ns,
+                checkpoint_id,
+                task_id,
+                write_obj["idx"],
+            )
+            async def tx(pipe, key=key, write_obj=write_obj, upsert_case=upsert_case):
+                exists = await pipe.exists(key)
+                if upsert_case:
+                    if exists:
+                        await pipe.json().set(key, "$.channel", write_obj["channel"])
+                        await pipe.json().set(key, "$.type", write_obj["type"])
+                        await pipe.json().set(key, "$.blob", write_obj["blob"])
                     else:
-                        # For new records, set the complete object
-                        pipeline.json().set(key, "$", write_obj)
+                        await pipe.json().set(key, "$", write_obj)
                 else:
-                    # INSERT case - only insert if doesn't exist
-                    if not key_exists:
-                        pipeline.json().set(key, "$", write_obj)
-
-            await pipeline.execute()
+                    if not exists:
+                        await pipe.json().set(key, "$", write_obj)
+            await self._redis.transaction(tx, key)

     def put_writes(
         self,
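
Note: redis-py's Redis.transaction() runs the supplied callback while the given keys are WATCHed and re-runs it whenever a watched key changes before EXEC, which is what gives the rewritten aput_writes its compare-and-set behaviour. Below is a minimal, self-contained sketch of that pattern with the redis-py asyncio client; the function name, key, and value are illustrative assumptions (not code from this commit) and a Redis server on localhost is assumed.

import asyncio

from redis.asyncio import Redis


async def insert_if_absent(r: Redis, key: str, value: str) -> bool:
    """Set `key` only if it does not already exist, guarded by WATCH."""

    async def tx(pipe) -> bool:
        # While `key` is WATCHed the pipeline runs commands immediately,
        # so we can read current state before deciding what to write.
        exists = await pipe.exists(key)
        # Switch to buffered mode: queued commands are sent atomically on
        # EXEC, and EXEC is aborted if `key` changed since the WATCH.
        pipe.multi()
        if not exists:
            pipe.set(key, value)
        return not exists

    # transaction() issues WATCH, runs tx(), calls EXEC, and re-runs tx()
    # on WatchError (i.e. when a concurrent writer touched `key`).
    return await r.transaction(tx, key, value_from_callable=True)


async def main() -> None:
    r = Redis()  # assumes redis://localhost:6379
    print(await insert_if_absent(r, "demo:key", "first"))   # True
    print(await insert_if_absent(r, "demo:key", "second"))  # False, value kept
    await r.aclose()  # redis-py >= 5


asyncio.run(main())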

langgraph/checkpoint/redis/ashallow.py — 28 additions, 32 deletions

@@ -8,6 +8,7 @@
 from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, cast

 from langchain_core.runnables import RunnableConfig
+from redis import WatchError
 from redisvl.index import AsyncSearchIndex
 from redisvl.query import FilterQuery
 from redisvl.query.filter import Num, Tag
@@ -317,9 +318,9 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:

         # Ensure metadata matches CheckpointMetadata type
         sanitized_metadata = {
-            k.replace("\u0000", ""): v.replace("\u0000", "")
-            if isinstance(v, str)
-            else v
+            k.replace("\u0000", ""): (
+                v.replace("\u0000", "") if isinstance(v, str) else v
+            )
             for k, v in metadata_dict.items()
         }
         metadata = cast(CheckpointMetadata, sanitized_metadata)
@@ -386,37 +387,32 @@ async def aput_writes(
             }
             writes_objects.append(write_obj)

-        # For each write, check existence and then perform appropriate operation
-        async with self.checkpoints_index.client.pipeline(
-            transaction=False
-        ) as pipeline:
-            for write_obj in writes_objects:
-                key = self._make_redis_checkpoint_writes_key(
-                    thread_id,
-                    checkpoint_ns,
-                    checkpoint_id,
-                    task_id,
-                    write_obj["idx"],
-                )
-
-                # First check if key exists
-                key_exists = await self._redis.exists(key) == 1
-
-                if all(w[0] in WRITES_IDX_MAP for w in writes):
-                    # UPSERT case - only update specific fields
-                    if key_exists:
-                        # Update only channel, type, and blob fields
-                        pipeline.json().set(key, "$.channel", write_obj["channel"])
-                        pipeline.json().set(key, "$.type", write_obj["type"])
-                        pipeline.json().set(key, "$.blob", write_obj["blob"])
+        upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
+        for write_obj in writes_objects:
+            key = self._make_redis_checkpoint_writes_key(
+                thread_id,
+                checkpoint_ns,
+                checkpoint_id,
+                task_id,
+                write_obj["idx"],
+            )
+            if upsert_case:
+                async def tx(pipe, key=key, write_obj=write_obj):
+                    exists = await pipe.exists(key)
+                    if exists:
+                        await pipe.json().set(
+                            key, "$.channel", write_obj["channel"]
+                        )
+                        await pipe.json().set(key, "$.type", write_obj["type"])
+                        await pipe.json().set(key, "$.blob", write_obj["blob"])
                     else:
-                        # For new records, set the complete object
-                        pipeline.json().set(key, "$", write_obj)
-                else:
-                    # INSERT case - only insert if doesn't exist
-                    pipeline.json().set(key, "$", write_obj)
+                        await pipe.json().set(key, "$", write_obj)

-            await pipeline.execute()
+                await self._redis.transaction(tx, key)
+            else:
+                # Unlike AsyncRedisSaver, the shallow implementation always overwrites
+                # the full object, so we don't check for key existence here.
+                await self._redis.json().set(key, "$", write_obj)

     async def aget_channel_values(
         self, thread_id: str, checkpoint_ns: str, checkpoint_id: str