77from datetime import datetime , timezone
88from types import TracebackType
99from typing import Any , AsyncIterator , Iterable , Optional , Sequence , cast
10- from ulid import ULID
1110
1211from langgraph .store .base import (
1312 BaseStore ,
2928from redisvl .query import FilterQuery , VectorQuery
3029from redisvl .redis .connection import RedisConnectionFactory
3130from redisvl .utils .token_escaper import TokenEscaper
31+ from ulid import ULID
3232
3333from langgraph .store .redis .base import (
3434 REDIS_KEY_SEPARATOR ,
@@ -57,7 +57,7 @@ class AsyncRedisStore(
5757
5858 store_index : AsyncSearchIndex
5959 vector_index : AsyncSearchIndex
60- _owns_client : bool
60+ _owns_its_client : bool
6161
6262 def __init__ (
6363 self ,
@@ -97,7 +97,9 @@ def __init__(
9797 self .configure_client (redis_url = redis_url , redis_client = redis_client )
9898
9999 # Create store index
100- self .store_index = AsyncSearchIndex .from_dict (self .SCHEMAS [0 ])
100+ self .store_index = AsyncSearchIndex .from_dict (
101+ self .SCHEMAS [0 ], redis_client = self ._redis
102+ )
101103
102104 # Configure vector index if needed
103105 if self .index_config :
@@ -131,7 +133,9 @@ def __init__(
131133 vector_field ["attrs" ].update (self .index_config ["ann_index_config" ])
132134
133135 try :
134- self .vector_index = AsyncSearchIndex .from_dict (vector_schema )
136+ self .vector_index = AsyncSearchIndex .from_dict (
137+ vector_schema , redis_client = self ._redis
138+ )
135139 except Exception as e :
136140 raise ValueError (
137141 f"Failed to create vector index with schema: { vector_schema } . Error: { str (e )} "
@@ -147,7 +151,7 @@ def configure_client(
147151 redis_client : Optional [AsyncRedis ] = None ,
148152 ) -> None :
149153 """Configure the Redis client."""
150- self ._owns_client = redis_client is None
154+ self ._owns_its_client = redis_client is None
151155 self ._redis = redis_client or RedisConnectionFactory .get_async_redis_connection (
152156 redis_url
153157 )
@@ -160,11 +164,6 @@ async def setup(self) -> None:
160164 self .index_config .get ("embed" ),
161165 )
162166
163- # Now connect Redis client to indices
164- await self .store_index .set_client (self ._redis )
165- if self .index_config :
166- await self .vector_index .set_client (self ._redis )
167-
168167 # Create indices in Redis
169168 await self .store_index .create (overwrite = False )
170169 if self .index_config :
@@ -188,9 +187,13 @@ async def from_conn_string(
188187
189188 def create_indexes (self ) -> None :
190189 """Create async indices."""
191- self .store_index = AsyncSearchIndex .from_dict (self .SCHEMAS [0 ])
190+ self .store_index = AsyncSearchIndex .from_dict (
191+ self .SCHEMAS [0 ], redis_client = self ._redis
192+ )
192193 if self .index_config :
193- self .vector_index = AsyncSearchIndex .from_dict (self .SCHEMAS [1 ])
194+ self .vector_index = AsyncSearchIndex .from_dict (
195+ self .SCHEMAS [1 ], redis_client = self ._redis
196+ )
194197
195198 async def __aenter__ (self ) -> AsyncRedisStore :
196199 """Async context manager enter."""
@@ -210,7 +213,7 @@ async def __aexit__(
210213 except asyncio .CancelledError :
211214 pass
212215
213- if self ._owns_client :
216+ if self ._owns_its_client :
214217 await self ._redis .aclose () # type: ignore[attr-defined]
215218 await self ._redis .connection_pool .disconnect ()
216219
0 commit comments