Skip to content

Commit 039f259

Browse files
committed
Refactor batch search and add batch_query
1 parent 64a1ba9 commit 039f259

File tree

3 files changed

+223
-89
lines changed

3 files changed

+223
-89
lines changed

redisvl/index/index.py

Lines changed: 74 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Iterable,
1515
List,
1616
Optional,
17+
Tuple,
1718
Union,
1819
)
1920

@@ -51,6 +52,14 @@
5152
{"name": "searchlight", "ver": 20810},
5253
]
5354

55+
SearchParams = Union[
56+
Tuple[
57+
Union[str, BaseQuery],
58+
Union[Dict[str, Union[str, int, float, bytes]], None],
59+
],
60+
Union[str, BaseQuery],
61+
]
62+
5463

5564
def process_results(
5665
results: "Result", query: BaseQuery, storage_type: StorageType
@@ -656,22 +665,23 @@ def aggregate(self, *args, **kwargs) -> "AggregateResult":
656665
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
657666

658667
def batch_search(
659-
self, queries: List[str], batch_size: int = 100, **query_params
660-
) -> List[List[Dict[str, Any]]]:
668+
self,
669+
queries: List[SearchParams],
670+
batch_size: int = 10,
671+
) -> List["Result"]:
661672
"""Perform a search against the index for multiple queries.
662673
663-
This method takes a list of queries and returns a list of search results.
664-
The results are returned in the same order as the queries.
674+
This method takes a list of queries and optionally query params and
675+
returns a list of Result objects for each query. Results are
676+
returned in the same order as the queries.
665677
666678
Args:
667-
queries (List[str]): The queries to search for.
668-
batch_size (int, optional): The number of queries to search for at a time.
669-
Defaults to 100.
670-
query_params (dict, optional): The query parameters to pass to the search
671-
for each query.
679+
queries (List[SearchParams]): The queries to search for.
680+
batch_size (int, optional): The number of queries to search for at a time.
681+
Defaults to 10.
672682
673683
Returns:
674-
List[List[Dict[str, Any]]]: The search results.
684+
List[Result]: The search results for each query.
675685
"""
676686
all_parsed = []
677687
search = self._redis_client.ft(self.schema.index.name)
@@ -688,9 +698,14 @@ def batch_search(
688698
with self._redis_client.pipeline(transaction=False) as pipe:
689699
batch_built_queries = []
690700
for query in batch_queries:
691-
query_args, q = search._mk_query_args( # type: ignore
692-
query, query_params=query_params
693-
)
701+
if isinstance(query, tuple):
702+
query_args, q = search._mk_query_args( # type: ignore
703+
query[0], query_params=query[1]
704+
)
705+
else:
706+
query_args, q = search._mk_query_args( # type: ignore
707+
query, query_params=None
708+
)
694709
batch_built_queries.append(q)
695710
pipe.execute_command(
696711
"FT.SEARCH",
@@ -707,20 +722,13 @@ def batch_search(
707722

708723
for i, query_results in enumerate(results):
709724
_built_query = batch_built_queries[i]
710-
parsed_raw = search._parse_search( # type: ignore
725+
parsed_result = search._parse_search( # type: ignore
711726
query_results,
712727
query=_built_query,
713728
duration=duration,
714729
)
715-
parsed = process_results(
716-
parsed_raw,
717-
query=_built_query,
718-
storage_type=self.schema.index.storage_type,
719-
)
720-
# Create separate lists of parsed results for each query
721-
# passed in to the batch_search method, so that callers can
722-
# access the results for each query individually
723-
all_parsed.append(parsed)
730+
# Return a parsed Result object for each query
731+
all_parsed.append(parsed_result)
724732
return all_parsed
725733

726734
def search(self, *args, **kwargs) -> "Result":
@@ -740,6 +748,26 @@ def search(self, *args, **kwargs) -> "Result":
740748
except Exception as e:
741749
raise RedisSearchError(f"Error while searching: {str(e)}") from e
742750

751+
def batch_query(
752+
self, queries: List[BaseQuery], batch_size: int = 10
753+
) -> List[List[Dict[str, Any]]]:
754+
"""Execute a batch of queries and process results."""
755+
results = self.batch_search(
756+
[(query.query, query.params) for query in queries], batch_size=batch_size
757+
)
758+
all_parsed = []
759+
for query, batch_results in zip(queries, results):
760+
parsed = process_results(
761+
batch_results,
762+
query=query,
763+
storage_type=self.schema.index.storage_type,
764+
)
765+
# Create separate lists of parsed results for each query
766+
# passed in to the batch_search method, so that callers can
767+
# access the results for each query individually
768+
all_parsed.append(parsed)
769+
return all_parsed
770+
743771
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
744772
"""Execute a query and process results."""
745773
results = self.search(query.query, query_params=query.params)
@@ -1283,22 +1311,20 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
12831311
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
12841312

12851313
async def batch_search(
1286-
self, queries: List[BaseQuery], batch_size: int = 100, **query_params
1314+
self, queries: List[SearchParams], batch_size: int = 10
12871315
) -> List["Result"]:
12881316
"""Perform a search against the index for multiple queries.
12891317
1290-
This method takes a list of queries and returns a list of search results.
1291-
The results are returned in the same order as the queries.
1318+
This method takes a list of queries and returns a list of Result objects
1319+
for each query. Results are returned in the same order as the queries.
12921320
12931321
Args:
1294-
queries (List[str]): The queries to search for.
1295-
batch_size (int, optional): The number of queries to search for at a time.
1296-
Defaults to 100.
1297-
query_params (dict, optional): The query parameters to pass to the search
1298-
for each query.
1322+
queries (List[SearchParams]): The queries to search for.
1323+
batch_size (int, optional): The number of queries to search for at a time.
1324+
Defaults to 10.
12991325
13001326
Returns:
1301-
List[Result]: The search results.
1327+
List[Result]: The search results for each query.
13021328
"""
13031329
all_results = []
13041330
client = await self._get_client()
@@ -1316,9 +1342,14 @@ async def batch_search(
13161342
async with client.pipeline(transaction=False) as pipe:
13171343
batch_built_queries = []
13181344
for query in batch_queries:
1319-
query_args, q = search._mk_query_args( # type: ignore
1320-
query, query_params=query_params
1321-
)
1345+
if isinstance(query, tuple):
1346+
query_args, q = search._mk_query_args( # type: ignore
1347+
query[0], query_params=query[1]
1348+
)
1349+
else:
1350+
query_args, q = search._mk_query_args( # type: ignore
1351+
query, query_params=None
1352+
)
13221353
batch_built_queries.append(q)
13231354
pipe.execute_command(
13241355
"FT.SEARCH",
@@ -1335,12 +1366,13 @@ async def batch_search(
13351366

13361367
for i, query_results in enumerate(results):
13371368
_built_query = batch_built_queries[i]
1338-
parsed_raw = search._parse_search( # type: ignore
1369+
parsed_result = search._parse_search( # type: ignore
13391370
query_results,
13401371
query=_built_query,
13411372
duration=duration,
13421373
)
1343-
all_results.append(parsed_raw)
1374+
# Return a parsed Result object for each query
1375+
all_results.append(parsed_result)
13441376
return all_results
13451377

13461378
async def search(self, *args, **kwargs) -> "Result":
@@ -1359,11 +1391,13 @@ async def search(self, *args, **kwargs) -> "Result":
13591391
except Exception as e:
13601392
raise RedisSearchError(f"Error while searching: {str(e)}") from e
13611393

1362-
async def _batch_query(
1363-
self, queries: List[BaseQuery], batch_size: int = 100
1394+
async def batch_query(
1395+
self, queries: List[BaseQuery], batch_size: int = 10
13641396
) -> List[List[Dict[str, Any]]]:
13651397
"""Asynchronously execute a batch of queries and process results."""
1366-
results = await self.batch_search(queries, batch_size=batch_size)
1398+
results = await self.batch_search(
1399+
[(query.query, query.params) for query in queries], batch_size=batch_size
1400+
)
13671401
all_parsed = []
13681402
for query, batch_results in zip(queries, results):
13691403
parsed = process_results(
@@ -1378,15 +1412,6 @@ async def _batch_query(
13781412

13791413
return all_parsed
13801414

1381-
async def batch_query(
1382-
self, queries: List[BaseQuery], batch_size: int = 100
1383-
) -> List[List[Dict[str, Any]]]:
1384-
"""Asynchronously execute a batch of queries and process results."""
1385-
return await self._batch_query(
1386-
[query.query for query in queries],
1387-
batch_size=batch_size,
1388-
)
1389-
13901415
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
13911416
"""Asynchronously execute a query and process results."""
13921417
results = await self.search(query.query, query_params=query.params)

tests/integration/test_async_search_index.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -447,47 +447,81 @@ async def test_batch_search(async_index):
447447

448448
results = await async_index.batch_search(["@test:{foo}", "@test:{bar}"])
449449
assert len(results) == 2
450-
assert results[0][0]["id"] == "rvl:1"
451-
assert results[1][0]["id"] == "rvl:2"
450+
assert results[0].total == 1
451+
assert results[0].docs[0]["id"] == "rvl:1"
452+
assert results[1].total == 1
453+
assert results[1].docs[0]["id"] == "rvl:2"
452454

453455

456+
@pytest.mark.parametrize(
457+
"queries",
458+
[
459+
[
460+
[
461+
FilterQuery(filter_expression="@test:{foo}"),
462+
FilterQuery(filter_expression="@test:{bar}"),
463+
],
464+
[
465+
FilterQuery(filter_expression="@test:{foo}"),
466+
FilterQuery(filter_expression="@test:{bar}"),
467+
FilterQuery(filter_expression="@test:{baz}"),
468+
FilterQuery(filter_expression="@test:{foo}"),
469+
FilterQuery(filter_expression="@test:{bar}"),
470+
FilterQuery(filter_expression="@test:{baz}"),
471+
],
472+
],
473+
[
474+
[
475+
"@test:{foo}",
476+
"@test:{bar}",
477+
],
478+
[
479+
"@test:{foo}",
480+
"@test:{bar}",
481+
"@test:{baz}",
482+
"@test:{foo}",
483+
"@test:{bar}",
484+
"@test:{baz}",
485+
],
486+
],
487+
],
488+
)
454489
@pytest.mark.asyncio
455-
async def test_batch_search_with_multiple_batches(async_index):
490+
async def test_batch_search_with_multiple_batches(async_index, queries):
456491
await async_index.create(overwrite=True, drop=True)
457492
data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
458493
await async_index.load(data, id_field="id")
459494

460-
results = await async_index.batch_search(["@test:{foo}", "@test:{bar}"])
495+
results = await async_index.batch_search(queries[0])
461496
assert len(results) == 2
462-
assert len(results[0]) == 1
463-
assert len(results[1]) == 1
497+
assert results[0].total == 1
498+
assert results[0].docs[0]["id"] == "rvl:1"
499+
assert results[1].total == 1
500+
assert results[1].docs[0]["id"] == "rvl:2"
464501

465502
results = await async_index.batch_search(
466-
[
467-
"@test:{foo}",
468-
"@test:{bar}",
469-
"@test:{baz}",
470-
"@test:{foo}",
471-
"@test:{bar}",
472-
"@test:{baz}",
473-
],
503+
queries[1],
474504
batch_size=2,
475505
)
476506
assert len(results) == 6
477507

478508
# First (and only) result for the first query
479-
assert results[0][0]["id"] == "rvl:1"
509+
assert results[0].total == 1
510+
assert results[0].docs[0]["id"] == "rvl:1"
480511

481512
# Second (and only) result for the second query
482-
assert results[1][0]["id"] == "rvl:2"
513+
assert results[1].total == 1
514+
assert results[1].docs[0]["id"] == "rvl:2"
483515

484516
# Third query should have zero results because there is no baz
485-
assert len(results[2]) == 0
517+
assert results[2].total == 0
486518

487519
# Then the pattern repeats
488-
assert results[3][0]["id"] == "rvl:1"
489-
assert results[4][0]["id"] == "rvl:2"
490-
assert len(results[5]) == 0
520+
assert results[3].total == 1
521+
assert results[3].docs[0]["id"] == "rvl:1"
522+
assert results[4].total == 1
523+
assert results[4].docs[0]["id"] == "rvl:2"
524+
assert results[5].total == 0
491525

492526

493527
@pytest.mark.asyncio
@@ -500,3 +534,20 @@ async def test_batch_query(async_index):
500534
results = await async_index.batch_query([query])
501535

502536
assert len(results) == 1
537+
assert results[0][0]["id"] == "rvl:1"
538+
539+
540+
@pytest.mark.asyncio
541+
async def test_batch_query_with_multiple_batches(async_index):
542+
await async_index.create(overwrite=True, drop=True)
543+
data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
544+
await async_index.load(data, id_field="id")
545+
546+
queries = [
547+
FilterQuery(filter_expression="@test:{foo}"),
548+
FilterQuery(filter_expression="@test:{bar}"),
549+
]
550+
results = await async_index.batch_query(queries, batch_size=1)
551+
assert len(results) == 2
552+
assert results[0][0]["id"] == "rvl:1"
553+
assert results[1][0]["id"] == "rvl:2"

0 commit comments

Comments
 (0)