Skip to content

Commit 980f51a

Browse files
committed
Add async batch_search
1 parent 9b88f04 commit 980f51a

File tree

3 files changed

+121
-1
lines changed

3 files changed

+121
-1
lines changed

redisvl/index/index.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,75 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
12821282
except Exception as e:
12831283
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
12841284

1285+
async def batch_search(
    self, queries: List[str], batch_size: int = 100, **query_params
) -> List[List[Dict[str, Any]]]:
    """Perform a search against the index for multiple queries.

    This method takes a list of queries and returns a list of search results.
    The results are returned in the same order as the queries.

    Args:
        queries (List[str]): The queries to search for.
        batch_size (int, optional): The number of queries to search for at a time.
            Defaults to 100.
        query_params (dict, optional): The query parameters to pass to the search
            for each query.

    Returns:
        List[List[Dict[str, Any]]]: The search results.
    """
    all_parsed: List[List[Dict[str, Any]]] = []
    client = await self._get_client()
    search = client.ft(self.schema.index.name)
    options = {}
    # RESP3 responses are decoded differently; only force raw bytes on RESP2.
    if get_protocol_version(client) not in ["3", 3]:
        options[NEVER_DECODE] = True

    for offset in range(0, len(queries), batch_size):
        batch_queries = queries[offset : offset + batch_size]

        # redis-py doesn't support calling `search` in a pipeline,
        # so we need to manually execute each command in a pipeline
        # and parse the results
        async with client.pipeline(transaction=False) as pipe:
            batch_built_queries = []
            for query in batch_queries:
                query_args, q = search._mk_query_args(  # type: ignore
                    query, query_params=query_params
                )
                batch_built_queries.append(q)
                pipe.execute_command(
                    "FT.SEARCH",
                    *query_args,
                    **options,
                )

            st = time.time()
            results = await pipe.execute()

            # We don't know how long each query took, so we'll use the total time
            # for all queries in the batch as the duration for each query
            duration = (time.time() - st) * 1000.0

            # Pipeline results come back in submission order, so zip pairs each
            # raw result with the query that produced it. (The original code
            # reused the outer loop variable `i` here, shadowing the batch
            # index — harmless only because `range` reassigns it, but fragile.)
            for built_query, query_results in zip(batch_built_queries, results):
                parsed_raw = search._parse_search(  # type: ignore
                    query_results,
                    query=built_query,
                    duration=duration,
                )
                parsed = process_results(
                    parsed_raw,
                    query=built_query,
                    storage_type=self.schema.index.storage_type,
                )
                # Create separate lists of parsed results for each query
                # passed in to the batch_search method, so that callers can
                # access the results for each query individually
                all_parsed.append(parsed)
    return all_parsed
1353+
12851354
async def search(self, *args, **kwargs) -> "Result":
12861355
"""Perform a search on this index.
12871356

tests/integration/test_async_search_index.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,54 @@ async def test_async_search_index_validates_redis_modules(redis_url):
436436
await index.create(overwrite=True, drop=True)
437437

438438
mock_validate_async_redis.assert_called_once()
439+
440+
441+
@pytest.mark.asyncio
async def test_batch_search(async_index):
    """batch_search returns one result list per query, in query order."""
    await async_index.create(overwrite=True, drop=True)
    records = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
    await async_index.load(records, id_field="id")

    batch_results = await async_index.batch_search(["@test:{foo}", "@test:{bar}"])

    assert len(batch_results) == 2
    # Each query's first hit is the record matching its tag.
    for result_set, expected_id in zip(batch_results, ("rvl:1", "rvl:2")):
        assert result_set[0]["id"] == expected_id
451+
452+
453+
@pytest.mark.asyncio
async def test_batch_search_with_multiple_batches(async_index):
    """Query order is preserved even when queries span several pipeline batches."""
    await async_index.create(overwrite=True, drop=True)
    records = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
    await async_index.load(records, id_field="id")

    batch_results = await async_index.batch_search(["@test:{foo}", "@test:{bar}"])
    assert len(batch_results) == 2
    assert len(batch_results[0]) == 1
    assert len(batch_results[1]) == 1

    # Six queries with batch_size=2 forces multiple pipeline round-trips.
    queries = ["@test:{foo}", "@test:{bar}", "@test:{baz}"] * 2
    batch_results = await async_index.batch_search(queries, batch_size=2)
    assert len(batch_results) == 6

    # The foo/bar/baz pattern repeats twice: foo matches rvl:1, bar matches
    # rvl:2, and baz has zero results because no such tag was loaded.
    for start in (0, 3):
        assert batch_results[start][0]["id"] == "rvl:1"
        assert batch_results[start + 1][0]["id"] == "rvl:2"
        assert len(batch_results[start + 2]) == 0

tests/integration/test_search_index.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def test_batch_search_with_multiple_batches(index):
431431
# Second (and only) result for the second query
432432
assert results[1][0]["id"] == "rvl:2"
433433

434-
# Third query has no results
434+
# Third query should have zero results because there is no baz
435435
assert len(results[2]) == 0
436436

437437
# Then the pattern repeats

0 commit comments

Comments
 (0)