77from datetime import datetime , timezone
88from types import TracebackType
99from typing import Any , AsyncIterator , Iterable , Optional , Sequence , cast
10- from ulid import ULID
1110
1211from langgraph .store .base import (
1312 BaseStore ,
2928from redisvl .query import FilterQuery , VectorQuery
3029from redisvl .redis .connection import RedisConnectionFactory
3130from redisvl .utils .token_escaper import TokenEscaper
31+ from ulid import ULID
3232
3333from langgraph .store .redis .base import (
3434 REDIS_KEY_SEPARATOR ,
@@ -57,14 +57,15 @@ class AsyncRedisStore(
5757
5858 store_index : AsyncSearchIndex
5959 vector_index : AsyncSearchIndex
60- _owns_client : bool
60+ _owns_its_client : bool
6161
6262 def __init__ (
6363 self ,
6464 redis_url : Optional [str ] = None ,
6565 * ,
6666 redis_client : Optional [AsyncRedis ] = None ,
6767 index : Optional [IndexConfig ] = None ,
68+ connection_args : Optional [dict [str , Any ]] = None ,
6869 ) -> None :
6970 """Initialize store with Redis connection and optional index config."""
7071 if redis_url is None and redis_client is None :
@@ -94,10 +95,16 @@ def __init__(
9495 ]
9596
9697 # Configure client
97- self .configure_client (redis_url = redis_url , redis_client = redis_client )
98+ self .configure_client (
99+ redis_url = redis_url ,
100+ redis_client = redis_client ,
101+ connection_args = connection_args or {},
102+ )
98103
99104 # Create store index
100- self .store_index = AsyncSearchIndex .from_dict (self .SCHEMAS [0 ])
105+ self .store_index = AsyncSearchIndex .from_dict (
106+ self .SCHEMAS [0 ], redis_client = self ._redis
107+ )
101108
102109 # Configure vector index if needed
103110 if self .index_config :
@@ -131,7 +138,9 @@ def __init__(
131138 vector_field ["attrs" ].update (self .index_config ["ann_index_config" ])
132139
133140 try :
134- self .vector_index = AsyncSearchIndex .from_dict (vector_schema )
141+ self .vector_index = AsyncSearchIndex .from_dict (
142+ vector_schema , redis_client = self ._redis
143+ )
135144 except Exception as e :
136145 raise ValueError (
137146 f"Failed to create vector index with schema: { vector_schema } . Error: { str (e )} "
@@ -145,11 +154,12 @@ def configure_client(
145154 self ,
146155 redis_url : Optional [str ] = None ,
147156 redis_client : Optional [AsyncRedis ] = None ,
157+ connection_args : Optional [dict [str , Any ]] = None ,
148158 ) -> None :
149159 """Configure the Redis client."""
150- self ._owns_client = redis_client is None
160+ self ._owns_its_client = redis_client is None
151161 self ._redis = redis_client or RedisConnectionFactory .get_async_redis_connection (
152- redis_url
162+ redis_url , ** connection_args
153163 )
154164
155165 async def setup (self ) -> None :
@@ -160,11 +170,6 @@ async def setup(self) -> None:
160170 self .index_config .get ("embed" ),
161171 )
162172
163- # Now connect Redis client to indices
164- await self .store_index .set_client (self ._redis )
165- if self .index_config :
166- await self .vector_index .set_client (self ._redis )
167-
168173 # Create indices in Redis
169174 await self .store_index .create (overwrite = False )
170175 if self .index_config :
@@ -188,9 +193,13 @@ async def from_conn_string(
188193
189194 def create_indexes (self ) -> None :
190195 """Create async indices."""
191- self .store_index = AsyncSearchIndex .from_dict (self .SCHEMAS [0 ])
196+ self .store_index = AsyncSearchIndex .from_dict (
197+ self .SCHEMAS [0 ], redis_client = self ._redis
198+ )
192199 if self .index_config :
193- self .vector_index = AsyncSearchIndex .from_dict (self .SCHEMAS [1 ])
200+ self .vector_index = AsyncSearchIndex .from_dict (
201+ self .SCHEMAS [1 ], redis_client = self ._redis
202+ )
194203
195204 async def __aenter__ (self ) -> AsyncRedisStore :
196205 """Async context manager enter."""
@@ -210,7 +219,7 @@ async def __aexit__(
210219 except asyncio .CancelledError :
211220 pass
212221
213- if self ._owns_client :
222+ if self ._owns_its_client :
214223 await self ._redis .aclose () # type: ignore[attr-defined]
215224 await self ._redis .connection_pool .disconnect ()
216225
0 commit comments