8 | 8 | from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, cast |
9 | 9 |
10 | 10 | from langchain_core.runnables import RunnableConfig |
| 11 | +from redis import WatchError |
11 | 12 | from redisvl.index import AsyncSearchIndex |
12 | 13 | from redisvl.query import FilterQuery |
13 | 14 | from redisvl.query.filter import Num, Tag |
@@ -317,9 +318,9 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: |
317 | 318 |
318 | 319 | # Ensure metadata matches CheckpointMetadata type |
319 | 320 | sanitized_metadata = { |
320 | | - k.replace("\u0000", ""): v.replace("\u0000", "") |
321 | | - if isinstance(v, str) |
322 | | - else v |
| 321 | + k.replace("\u0000", ""): ( |
| 322 | + v.replace("\u0000", "") if isinstance(v, str) else v |
| 323 | + ) |
323 | 324 | for k, v in metadata_dict.items() |
324 | 325 | } |
325 | 326 | metadata = cast(CheckpointMetadata, sanitized_metadata) |
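The comprehension above strips NUL characters (`\u0000`) from metadata keys and from string values, leaving non-string values untouched, before the result is cast to `CheckpointMetadata`. A minimal standalone sketch of that behavior, using hypothetical metadata values:

```python
# Hypothetical metadata illustrating the sanitization in the diff above:
# NULs are stripped from keys and from string values; other value types
# (ints, None, dicts) pass through unchanged.
metadata_dict = {
    "source": "inp\u0000ut",
    "step": 1,
    "writes\u0000": None,
}

sanitized_metadata = {
    k.replace("\u0000", ""): (
        v.replace("\u0000", "") if isinstance(v, str) else v
    )
    for k, v in metadata_dict.items()
}

assert sanitized_metadata == {"source": "input", "step": 1, "writes": None}
```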
@@ -386,37 +387,32 @@ async def aput_writes( |
386 | 387 | } |
387 | 388 | writes_objects.append(write_obj) |
388 | 389 |
389 | | - # For each write, check existence and then perform appropriate operation |
390 | | - async with self.checkpoints_index.client.pipeline( |
391 | | - transaction=False |
392 | | - ) as pipeline: |
393 | | - for write_obj in writes_objects: |
394 | | - key = self._make_redis_checkpoint_writes_key( |
395 | | - thread_id, |
396 | | - checkpoint_ns, |
397 | | - checkpoint_id, |
398 | | - task_id, |
399 | | - write_obj["idx"], |
400 | | - ) |
401 | | - |
402 | | - # First check if key exists |
403 | | - key_exists = await self._redis.exists(key) == 1 |
404 | | - |
405 | | - if all(w[0] in WRITES_IDX_MAP for w in writes): |
406 | | - # UPSERT case - only update specific fields |
407 | | - if key_exists: |
408 | | - # Update only channel, type, and blob fields |
409 | | - pipeline.json().set(key, "$.channel", write_obj["channel"]) |
410 | | - pipeline.json().set(key, "$.type", write_obj["type"]) |
411 | | - pipeline.json().set(key, "$.blob", write_obj["blob"]) |
| 390 | + upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes) |
| 391 | + for write_obj in writes_objects: |
| 392 | + key = self._make_redis_checkpoint_writes_key( |
| 393 | + thread_id, |
| 394 | + checkpoint_ns, |
| 395 | + checkpoint_id, |
| 396 | + task_id, |
| 397 | + write_obj["idx"], |
| 398 | + ) |
| 399 | + if upsert_case: |
| 400 | + async def tx(pipe, key=key, write_obj=write_obj): |
| 401 | + exists = await pipe.exists(key) |
| 402 | + if exists: |
| 403 | + await pipe.json().set( |
| 404 | + key, "$.channel", write_obj["channel"] |
| 405 | + ) |
| 406 | + await pipe.json().set(key, "$.type", write_obj["type"]) |
| 407 | + await pipe.json().set(key, "$.blob", write_obj["blob"]) |
412 | 408 | else: |
413 | | - # For new records, set the complete object |
414 | | - pipeline.json().set(key, "$", write_obj) |
415 | | - else: |
416 | | - # INSERT case - only insert if doesn't exist |
417 | | - pipeline.json().set(key, "$", write_obj) |
| 409 | + await pipe.json().set(key, "$", write_obj) |
418 | 410 |
419 | | - await pipeline.execute() |
| 411 | + await self._redis.transaction(tx, key) |
| 412 | + else: |
| 413 | + # Unlike AsyncRedisSaver, the shallow implementation always overwrites |
| 414 | + # the full object, so we don't check for key existence here. |
| 415 | + await self._redis.json().set(key, "$", write_obj) |
420 | 416 |
421 | 417 | async def aget_channel_values( |
422 | 418 | self, thread_id: str, checkpoint_ns: str, checkpoint_id: str |
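For readers unfamiliar with redis-py's optimistic-locking helper, here is a minimal, self-contained sketch of the WATCH/MULTI pattern that `Redis.transaction()` wraps. The `WatchError` added to the imports above is the exception redis-py raises when a watched key is modified before EXEC; the helper catches it and re-runs the callable. The key name, payload, and connection details are hypothetical, and a local Redis with the RedisJSON module plus redis-py 5.x is assumed:

```python
import asyncio

from redis.asyncio import Redis


async def main() -> None:
    r = Redis()  # assumes a local Redis with RedisJSON loaded

    # Hypothetical key and payload, mirroring the shape of a checkpoint write.
    key = "checkpoint_write:thread-1:ns:ckpt-1:task-1:0"
    write_obj = {"channel": "messages", "type": "json", "blob": "..."}

    async def tx(pipe) -> None:
        # While the key is WATCHed, pipeline commands execute immediately,
        # so this read returns a real value rather than being queued.
        exists = await pipe.exists(key)
        # Switch the pipeline to queueing mode; everything below is
        # buffered and committed atomically by a single EXEC.
        pipe.multi()
        if exists:
            pipe.json().set(key, "$.channel", write_obj["channel"])
            pipe.json().set(key, "$.type", write_obj["type"])
            pipe.json().set(key, "$.blob", write_obj["blob"])
        else:
            pipe.json().set(key, "$", write_obj)

    # transaction() WATCHes `key`, runs tx, and retries it if a concurrent
    # writer touches the key before EXEC (signalled by WatchError).
    await r.transaction(tx, key)
    await r.aclose()


asyncio.run(main())
```

One design point worth noting: commands issued on a pipeline while a key is WATCHed run immediately, which is what makes the `exists` check readable inside the callable; calling `pipe.multi()` afterwards switches the pipeline to queueing so the subsequent writes land in one atomic EXEC block.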