Skip to content

Commit c8b9571

Browse files
refactor to use redisvl 0.4.1+
1 parent fed21c6 commit c8b9571

File tree

12 files changed

+230
-84
lines changed

12 files changed

+230
-84
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,20 @@ def configure_client(
5151
) -> None:
5252
"""Configure the Redis client."""
5353
self._owns_its_client = redis_client is None
54-
5554
self._redis = redis_client or RedisConnectionFactory.get_redis_connection(
5655
redis_url, **connection_args
5756
)
5857

5958
def create_indexes(self) -> None:
60-
self.checkpoints_index = SearchIndex.from_dict(self.SCHEMAS[0])
61-
self.checkpoint_blobs_index = SearchIndex.from_dict(self.SCHEMAS[1])
62-
self.checkpoint_writes_index = SearchIndex.from_dict(self.SCHEMAS[2])
63-
64-
# Connect Redis client to indices
65-
self.checkpoints_index.set_client(self._redis)
66-
self.checkpoint_blobs_index.set_client(self._redis)
67-
self.checkpoint_writes_index.set_client(self._redis)
59+
self.checkpoints_index = SearchIndex.from_dict(
60+
self.SCHEMAS[0], redis_client=self._redis
61+
)
62+
self.checkpoint_blobs_index = SearchIndex.from_dict(
63+
self.SCHEMAS[1], redis_client=self._redis
64+
)
65+
self.checkpoint_writes_index = SearchIndex.from_dict(
66+
self.SCHEMAS[2], redis_client=self._redis
67+
)
6868

6969
def list(
7070
self,

langgraph/checkpoint/redis/aio.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,15 @@ def configure_client(
8888

8989
def create_indexes(self) -> None:
9090
"""Create indexes without connecting to Redis."""
91-
self.checkpoints_index = AsyncSearchIndex.from_dict(self.SCHEMAS[0])
92-
self.checkpoint_blobs_index = AsyncSearchIndex.from_dict(self.SCHEMAS[1])
93-
self.checkpoint_writes_index = AsyncSearchIndex.from_dict(self.SCHEMAS[2])
91+
self.checkpoints_index = AsyncSearchIndex.from_dict(
92+
self.SCHEMAS[0], redis_client=self._redis
93+
)
94+
self.checkpoint_blobs_index = AsyncSearchIndex.from_dict(
95+
self.SCHEMAS[1], redis_client=self._redis
96+
)
97+
self.checkpoint_writes_index = AsyncSearchIndex.from_dict(
98+
self.SCHEMAS[2], redis_client=self._redis
99+
)
94100

95101
async def __aenter__(self) -> AsyncRedisSaver:
96102
"""Async context manager enter."""
@@ -116,11 +122,6 @@ async def __aexit__(
116122

117123
async def asetup(self) -> None:
118124
"""Initialize Redis indexes asynchronously."""
119-
# Connect Redis client to indices asynchronously
120-
await self.checkpoints_index.set_client(self._redis)
121-
await self.checkpoint_blobs_index.set_client(self._redis)
122-
await self.checkpoint_writes_index.set_client(self._redis)
123-
124125
# Create indexes in Redis asynchronously
125126
await self.checkpoints_index.create(overwrite=False)
126127
await self.checkpoint_blobs_index.create(overwrite=False)

langgraph/checkpoint/redis/ashallow.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,6 @@ async def from_conn_string(
153153

154154
async def asetup(self) -> None:
155155
"""Initialize Redis indexes asynchronously."""
156-
# Connect Redis client to indices asynchronously
157-
await self.checkpoints_index.set_client(self._redis)
158-
await self.checkpoint_blobs_index.set_client(self._redis)
159-
await self.checkpoint_writes_index.set_client(self._redis)
160-
161156
# Create indexes in Redis asynchronously
162157
await self.checkpoints_index.create(overwrite=False)
163158
await self.checkpoint_blobs_index.create(overwrite=False)
@@ -557,9 +552,15 @@ def configure_client(
557552

558553
def create_indexes(self) -> None:
559554
"""Create indexes without connecting to Redis."""
560-
self.checkpoints_index = AsyncSearchIndex.from_dict(self.SCHEMAS[0])
561-
self.checkpoint_blobs_index = AsyncSearchIndex.from_dict(self.SCHEMAS[1])
562-
self.checkpoint_writes_index = AsyncSearchIndex.from_dict(self.SCHEMAS[2])
555+
self.checkpoints_index = AsyncSearchIndex.from_dict(
556+
self.SCHEMAS[0], redis_client=self._redis
557+
)
558+
self.checkpoint_blobs_index = AsyncSearchIndex.from_dict(
559+
self.SCHEMAS[1], redis_client=self._redis
560+
)
561+
self.checkpoint_writes_index = AsyncSearchIndex.from_dict(
562+
self.SCHEMAS[2], redis_client=self._redis
563+
)
563564

564565
def setup(self) -> None:
565566
"""Initialize the checkpoint_index in Redis."""

langgraph/checkpoint/redis/shallow.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -388,14 +388,15 @@ def configure_client(
388388
)
389389

390390
def create_indexes(self) -> None:
391-
self.checkpoints_index = SearchIndex.from_dict(self.SCHEMAS[0])
392-
self.checkpoint_blobs_index = SearchIndex.from_dict(self.SCHEMAS[1])
393-
self.checkpoint_writes_index = SearchIndex.from_dict(self.SCHEMAS[2])
394-
395-
# Connect Redis client to indices
396-
self.checkpoints_index.set_client(self._redis)
397-
self.checkpoint_blobs_index.set_client(self._redis)
398-
self.checkpoint_writes_index.set_client(self._redis)
391+
self.checkpoints_index = SearchIndex.from_dict(
392+
self.SCHEMAS[0], redis_client=self._redis
393+
)
394+
self.checkpoint_blobs_index = SearchIndex.from_dict(
395+
self.SCHEMAS[1], redis_client=self._redis
396+
)
397+
self.checkpoint_writes_index = SearchIndex.from_dict(
398+
self.SCHEMAS[2], redis_client=self._redis
399+
)
399400

400401
def put_writes(
401402
self,

langgraph/store/redis/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from contextlib import contextmanager
99
from datetime import datetime, timezone
1010
from typing import Any, Iterable, Iterator, Optional, Sequence, cast
11-
from ulid import ULID
1211

1312
from langgraph.store.base import (
1413
BaseStore,
@@ -26,6 +25,7 @@
2625
from redisvl.query import FilterQuery, VectorQuery
2726
from redisvl.redis.connection import RedisConnectionFactory
2827
from redisvl.utils.token_escaper import TokenEscaper
28+
from ulid import ULID
2929

3030
from langgraph.store.redis.aio import AsyncRedisStore
3131
from langgraph.store.redis.base import (

langgraph/store/redis/aio.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from datetime import datetime, timezone
88
from types import TracebackType
99
from typing import Any, AsyncIterator, Iterable, Optional, Sequence, cast
10-
from ulid import ULID
1110

1211
from langgraph.store.base import (
1312
BaseStore,
@@ -29,6 +28,7 @@
2928
from redisvl.query import FilterQuery, VectorQuery
3029
from redisvl.redis.connection import RedisConnectionFactory
3130
from redisvl.utils.token_escaper import TokenEscaper
31+
from ulid import ULID
3232

3333
from langgraph.store.redis.base import (
3434
REDIS_KEY_SEPARATOR,
@@ -57,7 +57,7 @@ class AsyncRedisStore(
5757

5858
store_index: AsyncSearchIndex
5959
vector_index: AsyncSearchIndex
60-
_owns_client: bool
60+
_owns_its_client: bool
6161

6262
def __init__(
6363
self,
@@ -97,7 +97,9 @@ def __init__(
9797
self.configure_client(redis_url=redis_url, redis_client=redis_client)
9898

9999
# Create store index
100-
self.store_index = AsyncSearchIndex.from_dict(self.SCHEMAS[0])
100+
self.store_index = AsyncSearchIndex.from_dict(
101+
self.SCHEMAS[0], redis_client=self._redis
102+
)
101103

102104
# Configure vector index if needed
103105
if self.index_config:
@@ -131,7 +133,9 @@ def __init__(
131133
vector_field["attrs"].update(self.index_config["ann_index_config"])
132134

133135
try:
134-
self.vector_index = AsyncSearchIndex.from_dict(vector_schema)
136+
self.vector_index = AsyncSearchIndex.from_dict(
137+
vector_schema, redis_client=self._redis
138+
)
135139
except Exception as e:
136140
raise ValueError(
137141
f"Failed to create vector index with schema: {vector_schema}. Error: {str(e)}"
@@ -147,7 +151,7 @@ def configure_client(
147151
redis_client: Optional[AsyncRedis] = None,
148152
) -> None:
149153
"""Configure the Redis client."""
150-
self._owns_client = redis_client is None
154+
self._owns_its_client = redis_client is None
151155
self._redis = redis_client or RedisConnectionFactory.get_async_redis_connection(
152156
redis_url
153157
)
@@ -160,11 +164,6 @@ async def setup(self) -> None:
160164
self.index_config.get("embed"),
161165
)
162166

163-
# Now connect Redis client to indices
164-
await self.store_index.set_client(self._redis)
165-
if self.index_config:
166-
await self.vector_index.set_client(self._redis)
167-
168167
# Create indices in Redis
169168
await self.store_index.create(overwrite=False)
170169
if self.index_config:
@@ -188,9 +187,13 @@ async def from_conn_string(
188187

189188
def create_indexes(self) -> None:
190189
"""Create async indices."""
191-
self.store_index = AsyncSearchIndex.from_dict(self.SCHEMAS[0])
190+
self.store_index = AsyncSearchIndex.from_dict(
191+
self.SCHEMAS[0], redis_client=self._redis
192+
)
192193
if self.index_config:
193-
self.vector_index = AsyncSearchIndex.from_dict(self.SCHEMAS[1])
194+
self.vector_index = AsyncSearchIndex.from_dict(
195+
self.SCHEMAS[1], redis_client=self._redis
196+
)
194197

195198
async def __aenter__(self) -> AsyncRedisStore:
196199
"""Async context manager enter."""
@@ -210,7 +213,7 @@ async def __aexit__(
210213
except asyncio.CancelledError:
211214
pass
212215

213-
if self._owns_client:
216+
if self._owns_its_client:
214217
await self._redis.aclose() # type: ignore[attr-defined]
215218
await self._redis.connection_pool.disconnect()
216219

langgraph/store/redis/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ def __init__(
121121
]
122122

123123
# Initialize search indices
124-
self.store_index = SearchIndex.from_dict(self.SCHEMAS[0])
125-
self.store_index.set_client(self._redis)
124+
self.store_index = SearchIndex.from_dict(
125+
self.SCHEMAS[0], redis_client=self._redis
126+
)
126127

127128
# Configure vector index if needed
128129
if self.index_config:
@@ -156,8 +157,9 @@ def __init__(
156157
if "ann_index_config" in self.index_config:
157158
vector_field["attrs"].update(self.index_config["ann_index_config"])
158159

159-
self.vector_index = SearchIndex.from_dict(vector_schema)
160-
self.vector_index.set_client(self._redis)
160+
self.vector_index = SearchIndex.from_dict(
161+
vector_schema, redis_client=self._redis
162+
)
161163

162164
def _get_batch_GET_ops_queries(
163165
self,

0 commit comments

Comments
 (0)