77import logging
88import os
99from contextlib import asynccontextmanager
10- from functools import partial
1110from types import TracebackType
1211from typing import (
1312 Any ,
3433)
3534from langgraph .constants import TASKS
3635from redis .asyncio import Redis as AsyncRedis
37- from redis .asyncio .client import Pipeline
3836from redis .asyncio .cluster import RedisCluster as AsyncRedisCluster
3937from redisvl .index import AsyncSearchIndex
4038from redisvl .query import FilterQuery
4139from redisvl .query .filter import Num , Tag
42- from redisvl .redis .connection import RedisConnectionFactory
4340
4441from langgraph .checkpoint .redis .base import BaseRedisSaver
4542from langgraph .checkpoint .redis .util import (
5451logger = logging .getLogger (__name__ )
5552
5653
57- async def _write_obj_tx (
58- pipe : Pipeline ,
59- key : str ,
60- write_obj : Dict [str , Any ],
61- upsert_case : bool ,
62- ) -> None :
63- exists : int = await pipe .exists (key )
64- if upsert_case :
65- if exists :
66- await pipe .json ().set (key , "$.channel" , write_obj ["channel" ])
67- await pipe .json ().set (key , "$.type" , write_obj ["type" ])
68- await pipe .json ().set (key , "$.blob" , write_obj ["blob" ])
69- else :
70- await pipe .json ().set (key , "$" , write_obj )
71- else :
72- if not exists :
73- await pipe .json ().set (key , "$" , write_obj )
74-
75-
7654class AsyncRedisSaver (
7755 BaseRedisSaver [Union [AsyncRedis , AsyncRedisCluster ], AsyncSearchIndex ]
7856):
@@ -568,7 +546,7 @@ async def aput(
568546 # store at top-level for filters in list()
569547 if all (key in metadata for key in ["source" , "step" ]):
570548 checkpoint_data ["source" ] = metadata ["source" ]
571- checkpoint_data ["step" ] = metadata ["step" ] # type: ignore
549+ checkpoint_data ["step" ] = metadata ["step" ]
572550
573551 # Prepare checkpoint key
574552 checkpoint_key = BaseRedisSaver ._make_redis_checkpoint_key (
@@ -587,11 +565,11 @@ async def aput(
587565
588566 if self .cluster_mode :
589567 # For cluster mode, execute operations individually
590- await self ._redis .json ().set (checkpoint_key , "$" , checkpoint_data )
568+ await self ._redis .json ().set (checkpoint_key , "$" , checkpoint_data ) # type: ignore[misc]
591569
592570 if blobs :
593571 for key , data in blobs :
594- await self ._redis .json ().set (key , "$" , data )
572+ await self ._redis .json ().set (key , "$" , data ) # type: ignore[misc]
595573
596574 # Apply TTL if configured
597575 if self .ttl_config and "default_ttl" in self .ttl_config :
@@ -604,12 +582,12 @@ async def aput(
604582 pipeline = self ._redis .pipeline (transaction = True )
605583
606584 # Add checkpoint data to pipeline
607- await pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
585+ pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
608586
609587 if blobs :
610588 # Add all blob operations to the pipeline
611589 for key , data in blobs :
612- await pipeline .json ().set (key , "$" , data )
590+ pipeline .json ().set (key , "$" , data )
613591
614592 # Execute all operations atomically
615593 await pipeline .execute ()
@@ -654,13 +632,13 @@ async def aput(
654632
655633 if self .cluster_mode :
656634 # For cluster mode, execute operation directly
657- await self ._redis .json ().set (
635+ await self ._redis .json ().set ( # type: ignore[misc]
658636 checkpoint_key , "$" , checkpoint_data
659637 )
660638 else :
661639 # For non-cluster mode, use pipeline
662640 pipeline = self ._redis .pipeline (transaction = True )
663- await pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
641+ pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
664642 await pipeline .execute ()
665643 except Exception :
666644 # If this also fails, we just propagate the original cancellation
@@ -739,24 +717,18 @@ async def aput_writes(
739717 exists = await self ._redis .exists (key )
740718 if exists :
741719 # Update existing key
742- await self ._redis .json ().set (
743- key , "$.channel" , write_obj ["channel" ]
744- )
745- await self ._redis .json ().set (
746- key , "$.type" , write_obj ["type" ]
747- )
748- await self ._redis .json ().set (
749- key , "$.blob" , write_obj ["blob" ]
750- )
720+ await self ._redis .json ().set (key , "$.channel" , write_obj ["channel" ]) # type: ignore[misc, arg-type]
721+ await self ._redis .json ().set (key , "$.type" , write_obj ["type" ]) # type: ignore[misc, arg-type]
722+ await self ._redis .json ().set (key , "$.blob" , write_obj ["blob" ]) # type: ignore[misc, arg-type]
751723 else :
752724 # Create new key
753- await self ._redis .json ().set (key , "$" , write_obj )
725+ await self ._redis .json ().set (key , "$" , write_obj ) # type: ignore[misc]
754726 created_keys .append (key )
755727 else :
756728 # For non-upsert case, only set if key doesn't exist
757729 exists = await self ._redis .exists (key )
758730 if not exists :
759- await self ._redis .json ().set (key , "$" , write_obj )
731+ await self ._redis .json ().set (key , "$" , write_obj ) # type: ignore[misc]
760732 created_keys .append (key )
761733
762734 # Apply TTL to newly created keys
@@ -788,20 +760,30 @@ async def aput_writes(
788760 exists = await self ._redis .exists (key )
789761 if exists :
790762 # Update existing key
791- await pipeline .json ().set (
792- key , "$.channel" , write_obj ["channel" ]
763+ pipeline .json ().set (
764+ key ,
765+ "$.channel" ,
766+ write_obj ["channel" ], # type: ignore[arg-type]
767+ )
768+ pipeline .json ().set (
769+ key ,
770+ "$.type" ,
771+ write_obj ["type" ], # type: ignore[arg-type]
772+ )
773+ pipeline .json ().set (
774+ key ,
775+ "$.blob" ,
776+ write_obj ["blob" ], # type: ignore[arg-type]
793777 )
794- await pipeline .json ().set (key , "$.type" , write_obj ["type" ])
795- await pipeline .json ().set (key , "$.blob" , write_obj ["blob" ])
796778 else :
797779 # Create new key
798- await pipeline .json ().set (key , "$" , write_obj )
780+ pipeline .json ().set (key , "$" , write_obj )
799781 created_keys .append (key )
800782 else :
801783 # For non-upsert case, only set if key doesn't exist
802784 exists = await self ._redis .exists (key )
803785 if not exists :
804- await pipeline .json ().set (key , "$" , write_obj )
786+ pipeline .json ().set (key , "$" , write_obj )
805787 created_keys .append (key )
806788
807789 # Execute all operations atomically
0 commit comments