Skip to content

Commit 98e80de

Browse files
committed
New and old pipeline, async fixes
1 parent 71df6d4 commit 98e80de

File tree

4 files changed

+53
-62
lines changed

4 files changed

+53
-62
lines changed

langgraph/checkpoint/redis/aio.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import Any, List, Optional, Sequence, Tuple, Type, cast
1111

1212
from langchain_core.runnables import RunnableConfig
13-
from redis import WatchError
1413
from redisvl.index import AsyncSearchIndex
1514
from redisvl.query import FilterQuery
1615
from redisvl.query.filter import Num, Tag
@@ -74,6 +73,7 @@ def create_indexes(self) -> None:
7473

7574
async def __aenter__(self) -> AsyncRedisSaver:
7675
"""Async context manager enter."""
76+
await self.asetup()
7777
return self
7878

7979
async def __aexit__(
@@ -83,15 +83,15 @@ async def __aexit__(
8383
exc_tb: Optional[TracebackType],
8484
) -> None:
8585
"""Async context manager exit."""
86-
# Close client connections
87-
if hasattr(self, "checkpoint_index") and hasattr(
88-
self.checkpoint_index, "client"
89-
):
90-
await self.checkpoint_index.client.aclose()
91-
if hasattr(self, "channel_index") and hasattr(self.channel_index, "client"):
92-
await self.channel_index.client.aclose()
93-
if hasattr(self, "writes_index") and hasattr(self.writes_index, "client"):
94-
await self.writes_index.client.aclose()
86+
if self._owns_its_client:
87+
await self._redis.aclose() # type: ignore[attr-defined]
88+
await self._redis.connection_pool.disconnect()
89+
90+
# Prevent RedisVL from attempting to close the client
91+
# on an event loop in a separate thread.
92+
self.checkpoints_index._redis_client = None
93+
self.checkpoint_blobs_index._redis_client = None
94+
self.checkpoint_writes_index._redis_client = None
9595

9696
async def asetup(self) -> None:
9797
"""Initialize Redis indexes asynchronously."""
@@ -428,18 +428,24 @@ async def aput_writes(
428428
task_id,
429429
write_obj["idx"],
430430
)
431-
async def tx(pipe, key=key, write_obj=write_obj, upsert_case=upsert_case):
431+
432+
async def tx(
433+
pipe, key=key, write_obj=write_obj, upsert_case=upsert_case
434+
):
432435
exists = await pipe.exists(key)
433436
if upsert_case:
434437
if exists:
435-
await pipe.json().set(key, "$.channel", write_obj["channel"])
438+
await pipe.json().set(
439+
key, "$.channel", write_obj["channel"]
440+
)
436441
await pipe.json().set(key, "$.type", write_obj["type"])
437442
await pipe.json().set(key, "$.blob", write_obj["blob"])
438443
else:
439444
await pipe.json().set(key, "$", write_obj)
440445
else:
441446
if not exists:
442447
await pipe.json().set(key, "$", write_obj)
448+
443449
await self._redis.transaction(tx, key)
444450

445451
def put_writes(
@@ -533,18 +539,12 @@ async def from_conn_string(
533539
redis_client: Optional[AsyncRedis] = None,
534540
connection_args: Optional[dict[str, Any]] = None,
535541
) -> AsyncIterator[AsyncRedisSaver]:
536-
saver: Optional[AsyncRedisSaver] = None
537-
try:
538-
saver = cls(
539-
redis_url=redis_url,
540-
redis_client=redis_client,
541-
connection_args=connection_args,
542-
)
542+
async with cls(
543+
redis_url=redis_url,
544+
redis_client=redis_client,
545+
connection_args=connection_args,
546+
) as saver:
543547
yield saver
544-
finally:
545-
if saver and saver._owns_its_client: # Ensure saver is not None
546-
await saver._redis.aclose() # type: ignore[attr-defined]
547-
await saver._redis.connection_pool.disconnect()
548548

549549
async def aget_channel_values(
550550
self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""

langgraph/checkpoint/redis/ashallow.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, cast
99

1010
from langchain_core.runnables import RunnableConfig
11-
from redis import WatchError
1211
from redisvl.index import AsyncSearchIndex
1312
from redisvl.query import FilterQuery
1413
from redisvl.query.filter import Num, Tag
@@ -100,9 +99,22 @@ def __init__(
10099
redis_client=redis_client,
101100
connection_args=connection_args,
102101
)
103-
# self.lock = asyncio.Lock()
104102
self.loop = asyncio.get_running_loop()
105103

104+
async def __aenter__(self) -> AsyncShallowRedisSaver:
105+
return self
106+
107+
async def __aexit__(self, exc_type, exc, tb) -> None:
108+
if self._owns_its_client:
109+
await self._redis.aclose() # type: ignore[attr-defined]
110+
await self._redis.connection_pool.disconnect()
111+
112+
# Prevent RedisVL from attempting to close the client
113+
# on an event loop in a separate thread.
114+
self.checkpoints_index._redis_client = None
115+
self.checkpoint_blobs_index._redis_client = None
116+
self.checkpoint_writes_index._redis_client = None
117+
106118
@classmethod
107119
@asynccontextmanager
108120
async def from_conn_string(
@@ -113,18 +125,12 @@ async def from_conn_string(
113125
connection_args: Optional[dict[str, Any]] = None,
114126
) -> AsyncIterator[AsyncShallowRedisSaver]:
115127
"""Create a new AsyncShallowRedisSaver instance."""
116-
saver: Optional[AsyncShallowRedisSaver] = None
117-
try:
118-
saver = cls(
119-
redis_url=redis_url,
120-
redis_client=redis_client,
121-
connection_args=connection_args,
122-
)
128+
async with cls(
129+
redis_url=redis_url,
130+
redis_client=redis_client,
131+
connection_args=connection_args,
132+
) as saver:
123133
yield saver
124-
finally:
125-
if saver and saver._owns_its_client:
126-
await saver._redis.aclose() # type: ignore[attr-defined]
127-
await saver._redis.connection_pool.disconnect()
128134

129135
async def asetup(self) -> None:
130136
"""Initialize Redis indexes asynchronously."""
@@ -397,6 +403,7 @@ async def aput_writes(
397403
write_obj["idx"],
398404
)
399405
if upsert_case:
406+
400407
async def tx(pipe, key=key, write_obj=write_obj):
401408
exists = await pipe.exists(key)
402409
if exists:

langgraph/store/redis/aio.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def create_indexes(self) -> None:
194194

195195
async def __aenter__(self) -> AsyncRedisStore:
196196
"""Async context manager enter."""
197+
await self.setup()
197198
return self
198199

199200
async def __aexit__(
@@ -291,26 +292,11 @@ async def _batch_get_ops(
291292
) -> None:
292293
"""Execute GET operations in batch asynchronously."""
293294
for query, _, namespace, items in self._get_batch_GET_ops_queries(get_ops):
294-
# Use RedisVL AsyncSearchIndex search
295-
search_query = FilterQuery(
296-
filter_expression=query,
297-
return_fields=["id"], # Just need the document id
298-
num_results=len(items),
299-
)
300-
res = await self.store_index.search(search_query)
301-
302-
# Use pipeline to get the actual JSON documents
303-
pipeline = self._redis.pipeline(transaction=False)
304-
doc_ids = []
305-
for doc in res.docs:
306-
# The id is already in the correct format (store:prefix:key)
307-
pipeline.json().get(doc.id)
308-
doc_ids.append(doc.id)
309-
310-
json_docs = await pipeline.execute()
311-
312-
# Convert to dictionary format
313-
key_to_row = {doc["key"]: doc for doc in json_docs if doc}
295+
res = await self.store_index.search(Query(query))
296+
# Parse JSON from each document
297+
key_to_row = {
298+
json.loads(doc.json)["key"]: json.loads(doc.json) for doc in res.docs
299+
}
314300

315301
for idx, key in items:
316302
if key in key_to_row:

tests/test_async_store.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,11 @@ async def test_async_store_with_memory_persistence(
501501
"distance_type": "cosine",
502502
}
503503

504-
async with AsyncRedisStore.from_conn_string(redis_url, index=index_config) as store:
504+
async with AsyncRedisStore.from_conn_string(
505+
redis_url, index=index_config
506+
) as store, AsyncRedisSaver.from_conn_string(redis_url) as checkpointer:
505507
await store.setup()
508+
await checkpointer.asetup()
506509

507510
model = ChatAnthropic(model="claude-3-5-sonnet-20240620") # type: ignore[call-arg]
508511

@@ -532,11 +535,6 @@ def call_model(
532535
builder.add_node("call_model", call_model) # type:ignore[arg-type]
533536
builder.add_edge(START, "call_model")
534537

535-
checkpointer = None
536-
async with AsyncRedisSaver.from_conn_string(redis_url) as cp:
537-
await cp.asetup()
538-
checkpointer = cp
539-
540538
# Compile graph with store and checkpointer
541539
graph = builder.compile(checkpointer=checkpointer, store=store)
542540

0 commit comments

Comments
 (0)