Skip to content

Commit bcec7fc

Browse files
committed
Avoid using "transactions" when they are not needed
1 parent 1cb4c0e commit bcec7fc

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

langgraph/store/redis/aio.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -179,23 +179,12 @@ async def from_conn_string(
179179
index: Optional[IndexConfig] = None,
180180
) -> AsyncIterator[AsyncRedisStore]:
181181
"""Create store from Redis connection string."""
182-
store = cls(redis_url=conn_string, index=index)
183-
try:
182+
async with cls(redis_url=conn_string, index=index) as store:
184183
store._task = store.loop.create_task(
185184
store._run_background_tasks(store._aqueue, weakref.ref(store))
186185
)
187186
await store.setup()
188187
yield store
189-
finally:
190-
if hasattr(store, "_task"):
191-
store._task.cancel()
192-
try:
193-
await store._task
194-
except asyncio.CancelledError:
195-
pass
196-
if store._owns_client:
197-
await store._redis.aclose() # type: ignore[attr-defined]
198-
await store._redis.connection_pool.disconnect()
199188

200189
def create_indexes(self) -> None:
201190
"""Create async indices."""
@@ -221,8 +210,9 @@ async def __aexit__(
221210
except asyncio.CancelledError:
222211
pass
223212

224-
# if self._owns_client:
225-
await self._redis.aclose() # type: ignore[attr-defined]
213+
if self._owns_client:
214+
await self._redis.aclose() # type: ignore[attr-defined]
215+
await self._redis.connection_pool.disconnect()
226216

227217
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
228218
"""Execute batch of operations asynchronously."""
@@ -301,11 +291,27 @@ async def _batch_get_ops(
301291
) -> None:
302292
"""Execute GET operations in batch asynchronously."""
303293
for query, _, namespace, items in self._get_batch_GET_ops_queries(get_ops):
304-
res = await self.store_index.search(Query(query))
305-
# Parse JSON from each document
306-
key_to_row = {
307-
json.loads(doc.json)["key"]: json.loads(doc.json) for doc in res.docs
308-
}
294+
# Use RedisVL AsyncSearchIndex search
295+
search_query = FilterQuery(
296+
filter_expression=query,
297+
return_fields=["id"], # Just need the document id
298+
num_results=len(items),
299+
)
300+
res = await self.store_index.search(search_query)
301+
302+
# Use pipeline to get the actual JSON documents
303+
pipeline = self._redis.pipeline(transaction=False)
304+
doc_ids = []
305+
for doc in res.docs:
306+
# The id is already in the correct format (store:prefix:key)
307+
pipeline.json().get(doc.id)
308+
doc_ids.append(doc.id)
309+
310+
json_docs = await pipeline.execute()
311+
312+
# Convert to dictionary format
313+
key_to_row = {doc["key"]: doc for doc in json_docs if doc}
314+
309315
for idx, key in items:
310316
if key in key_to_row:
311317
results[idx] = _row_to_item(namespace, key_to_row[key])
@@ -482,7 +488,7 @@ async def _batch_search_ops(
482488
)
483489

484490
# Get matching store docs in pipeline
485-
pipeline = self._redis.pipeline()
491+
pipeline = self._redis.pipeline(transaction=False)
486492
result_map = {} # Map store key to vector result with distances
487493

488494
for doc in vector_results:

0 commit comments

Comments
 (0)