|
7 | 7 | import logging |
8 | 8 | import os |
9 | 9 | from contextlib import asynccontextmanager |
10 | | -from functools import partial |
11 | 10 | from types import TracebackType |
12 | 11 | from typing import ( |
13 | 12 | Any, |
|
33 | 32 | get_checkpoint_id, |
34 | 33 | ) |
35 | 34 | from langgraph.constants import TASKS |
36 | | -from redis.asyncio import Redis as AsyncRedis |
37 | | -from redis.asyncio.client import Pipeline |
38 | | -from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster |
39 | | -from redisvl.index import AsyncSearchIndex |
40 | | -from redisvl.query import FilterQuery |
41 | | -from redisvl.query.filter import Num, Tag |
42 | | -from redisvl.redis.connection import RedisConnectionFactory |
43 | | - |
44 | 35 | from langgraph.checkpoint.redis.base import BaseRedisSaver |
45 | 36 | from langgraph.checkpoint.redis.util import ( |
46 | 37 | EMPTY_ID_SENTINEL, |
|
50 | 41 | to_storage_safe_id, |
51 | 42 | to_storage_safe_str, |
52 | 43 | ) |
| 44 | +from redis.asyncio import Redis as AsyncRedis |
| 45 | +from redis.asyncio.client import Pipeline |
| 46 | +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster |
| 47 | +from redisvl.index import AsyncSearchIndex |
| 48 | +from redisvl.query import FilterQuery |
| 49 | +from redisvl.query.filter import Num, Tag |
53 | 50 |
|
54 | 51 | logger = logging.getLogger(__name__) |
55 | 52 |
|
@@ -587,11 +584,11 @@ async def aput( |
587 | 584 |
|
588 | 585 | if self.cluster_mode: |
589 | 586 | # For cluster mode, execute operations individually |
590 | | - await self._redis.json().set(checkpoint_key, "$", checkpoint_data) |
| 587 | + await self._redis.json().set(checkpoint_key, "$", checkpoint_data) # type: ignore[misc] |
591 | 588 |
|
592 | 589 | if blobs: |
593 | 590 | for key, data in blobs: |
594 | | - await self._redis.json().set(key, "$", data) |
| 591 | + await self._redis.json().set(key, "$", data) # type: ignore[misc] |
595 | 592 |
|
596 | 593 | # Apply TTL if configured |
597 | 594 | if self.ttl_config and "default_ttl" in self.ttl_config: |
@@ -654,7 +651,7 @@ async def aput( |
654 | 651 |
|
655 | 652 | if self.cluster_mode: |
656 | 653 | # For cluster mode, execute operation directly |
657 | | - await self._redis.json().set( |
| 654 | + await self._redis.json().set( # type: ignore[misc] |
658 | 655 | checkpoint_key, "$", checkpoint_data |
659 | 656 | ) |
660 | 657 | else: |
@@ -739,24 +736,19 @@ async def aput_writes( |
739 | 736 | exists = await self._redis.exists(key) |
740 | 737 | if exists: |
741 | 738 | # Update existing key |
742 | | - await self._redis.json().set( |
743 | | - key, "$.channel", write_obj["channel"] |
744 | | - ) |
745 | | - await self._redis.json().set( |
746 | | - key, "$.type", write_obj["type"] |
747 | | - ) |
748 | | - await self._redis.json().set( |
749 | | - key, "$.blob", write_obj["blob"] |
750 | | - ) |
| 739 | + pipeline = self._redis.pipeline(transaction=True) |
| 740 | + pipeline.json().set(key, "$.channel", write_obj["channel"]) # type: ignore[arg-type] |
| 741 | + pipeline.json().set(key, "$.type", write_obj["type"]) # type: ignore[arg-type] |
| 742 | + pipeline.json().set(key, "$.blob", write_obj["blob"]) # type: ignore[arg-type] |
751 | 743 | else: |
752 | 744 | # Create new key |
753 | | - await self._redis.json().set(key, "$", write_obj) |
| 745 | + pipeline.json().set(key, "$", write_obj) |
754 | 746 | created_keys.append(key) |
755 | 747 | else: |
756 | 748 | # For non-upsert case, only set if key doesn't exist |
757 | 749 | exists = await self._redis.exists(key) |
758 | 750 | if not exists: |
759 | | - await self._redis.json().set(key, "$", write_obj) |
| 751 | + pipeline.json().set(key, "$", write_obj) |
760 | 752 | created_keys.append(key) |
761 | 753 |
|
762 | 754 | # Apply TTL to newly created keys |
@@ -788,9 +780,9 @@ async def aput_writes( |
788 | 780 | exists = await self._redis.exists(key) |
789 | 781 | if exists: |
790 | 782 | # Update existing key |
791 | | - pipeline.json().set(key, "$.channel", write_obj["channel"]) |
792 | | - pipeline.json().set(key, "$.type", write_obj["type"]) |
793 | | - pipeline.json().set(key, "$.blob", write_obj["blob"]) |
| 783 | + pipeline.json().set(key, "$.channel", write_obj["channel"]) # type: ignore[arg-type] |
| 784 | + pipeline.json().set(key, "$.type", write_obj["type"]) # type: ignore[arg-type] |
| 785 | + pipeline.json().set(key, "$.blob", write_obj["blob"]) # type: ignore[arg-type] |
794 | 786 | else: |
795 | 787 | # Create new key |
796 | 788 | pipeline.json().set(key, "$", write_obj) |
|
0 commit comments