Skip to content

Commit 64a1ba9

Browse files
committed
WIP on async batch_query
1 parent 980f51a commit 64a1ba9

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

redisvl/index/index.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,8 +1283,8 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
12831283
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
12841284

12851285
async def batch_search(
1286-
self, queries: List[str], batch_size: int = 100, **query_params
1287-
) -> List[List[Dict[str, Any]]]:
1286+
self, queries: List[BaseQuery], batch_size: int = 100, **query_params
1287+
) -> List["Result"]:
12881288
"""Perform a search against the index for multiple queries.
12891289
12901290
This method takes a list of queries and returns a list of search results.
@@ -1298,9 +1298,9 @@ async def batch_search(
12981298
for each query.
12991299
13001300
Returns:
1301-
List[List[Dict[str, Any]]]: The search results.
1301+
List[Result]: The search results.
13021302
"""
1303-
all_parsed = []
1303+
all_results = []
13041304
client = await self._get_client()
13051305
search = client.ft(self.schema.index.name)
13061306
options = {}
@@ -1340,16 +1340,8 @@ async def batch_search(
13401340
query=_built_query,
13411341
duration=duration,
13421342
)
1343-
parsed = process_results(
1344-
parsed_raw,
1345-
query=_built_query,
1346-
storage_type=self.schema.index.storage_type,
1347-
)
1348-
# Create separate lists of parsed results for each query
1349-
# passed in to the batch_search method, so that callers can
1350-
# access the results for each query individually
1351-
all_parsed.append(parsed)
1352-
return all_parsed
1343+
all_results.append(parsed_raw)
1344+
return all_results
13531345

13541346
async def search(self, *args, **kwargs) -> "Result":
13551347
"""Perform a search on this index.
@@ -1367,6 +1359,34 @@ async def search(self, *args, **kwargs) -> "Result":
13671359
except Exception as e:
13681360
raise RedisSearchError(f"Error while searching: {str(e)}") from e
13691361

1362+
async def _batch_query(
1363+
self, queries: List[BaseQuery], batch_size: int = 100
1364+
) -> List[List[Dict[str, Any]]]:
1365+
"""Asynchronously execute a batch of queries and process results."""
1366+
results = await self.batch_search(queries, batch_size=batch_size)
1367+
all_parsed = []
1368+
for query, batch_results in zip(queries, results):
1369+
parsed = process_results(
1370+
batch_results,
1371+
query=query,
1372+
storage_type=self.schema.index.storage_type,
1373+
)
1374+
# Create separate lists of parsed results for each query
1375+
# passed in to the batch_search method, so that callers can
1376+
# access the results for each query individually
1377+
all_parsed.append(parsed)
1378+
1379+
return all_parsed
1380+
1381+
async def batch_query(
1382+
self, queries: List[BaseQuery], batch_size: int = 100
1383+
) -> List[List[Dict[str, Any]]]:
1384+
"""Asynchronously execute a batch of queries and process results."""
1385+
return await self._batch_query(
1386+
[query.query for query in queries],
1387+
batch_size=batch_size,
1388+
)
1389+
13701390
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
13711391
"""Asynchronously execute a query and process results."""
13721392
results = await self.search(query.query, query_params=query.params)

tests/integration/test_async_search_index.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
99
from redisvl.index import AsyncSearchIndex
1010
from redisvl.query import VectorQuery
11+
from redisvl.query.query import FilterQuery
1112
from redisvl.redis.utils import convert_bytes
1213
from redisvl.schema import IndexSchema, StorageType
1314

@@ -487,3 +488,15 @@ async def test_batch_search_with_multiple_batches(async_index):
487488
assert results[3][0]["id"] == "rvl:1"
488489
assert results[4][0]["id"] == "rvl:2"
489490
assert len(results[5]) == 0
491+
492+
493+
@pytest.mark.asyncio
494+
async def test_batch_query(async_index):
495+
await async_index.create(overwrite=True, drop=True)
496+
data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
497+
await async_index.load(data, id_field="id")
498+
499+
query = FilterQuery(filter_expression="@test:{foo}")
500+
results = await async_index.batch_query([query])
501+
502+
assert len(results) == 1

0 commit comments

Comments
 (0)