1414 Iterable ,
1515 List ,
1616 Optional ,
17+ Tuple ,
1718 Union ,
1819)
1920
5152 {"name" : "searchlight" , "ver" : 20810 },
5253]
5354
55+ SearchParams = Union [
56+ Tuple [
57+ Union [str , BaseQuery ],
58+ Union [Dict [str , Union [str , int , float , bytes ]], None ],
59+ ],
60+ Union [str , BaseQuery ],
61+ ]
62+
5463
5564def process_results (
5665 results : "Result" , query : BaseQuery , storage_type : StorageType
@@ -656,22 +665,23 @@ def aggregate(self, *args, **kwargs) -> "AggregateResult":
656665 raise RedisSearchError (f"Error while aggregating: { str (e )} " ) from e
657666
658667 def batch_search (
659- self , queries : List [str ], batch_size : int = 100 , ** query_params
660- ) -> List [List [Dict [str , Any ]]]:
668+ self ,
669+ queries : List [SearchParams ],
670+ batch_size : int = 10 ,
671+ ) -> List ["Result" ]:
661672 """Perform a search against the index for multiple queries.
662673
663- This method takes a list of queries and returns a list of search results.
664- The results are returned in the same order as the queries.
674+ This method takes a list of queries and optionally query params and
675+ returns a list of Result objects for each query. Results are
676+ returned in the same order as the queries.
665677
666678 Args:
667- queries (List[str]): The queries to search for.
668- batch_size (int, optional): The number of queries to search for at a time.
669- Defaults to 100.
670- query_params (dict, optional): The query parameters to pass to the search
671- for each query.
679+ queries (List[SearchParams]): The queries to search for. batch_size
680+ (int, optional): The number of queries to search for at a time.
681+ Defaults to 10.
672682
673683 Returns:
674- List[List[Dict[str, Any]]] : The search results.
684+ List[Result] : The search results for each query .
675685 """
676686 all_parsed = []
677687 search = self ._redis_client .ft (self .schema .index .name )
@@ -688,9 +698,14 @@ def batch_search(
688698 with self ._redis_client .pipeline (transaction = False ) as pipe :
689699 batch_built_queries = []
690700 for query in batch_queries :
691- query_args , q = search ._mk_query_args ( # type: ignore
692- query , query_params = query_params
693- )
701+ if isinstance (query , tuple ):
702+ query_args , q = search ._mk_query_args ( # type: ignore
703+ query [0 ], query_params = query [1 ]
704+ )
705+ else :
706+ query_args , q = search ._mk_query_args ( # type: ignore
707+ query , query_params = None
708+ )
694709 batch_built_queries .append (q )
695710 pipe .execute_command (
696711 "FT.SEARCH" ,
@@ -707,20 +722,13 @@ def batch_search(
707722
708723 for i , query_results in enumerate (results ):
709724 _built_query = batch_built_queries [i ]
710- parsed_raw = search ._parse_search ( # type: ignore
725+ parsed_result = search ._parse_search ( # type: ignore
711726 query_results ,
712727 query = _built_query ,
713728 duration = duration ,
714729 )
715- parsed = process_results (
716- parsed_raw ,
717- query = _built_query ,
718- storage_type = self .schema .index .storage_type ,
719- )
720- # Create separate lists of parsed results for each query
721- # passed in to the batch_search method, so that callers can
722- # access the results for each query individually
723- all_parsed .append (parsed )
730+ # Return a parsed Result object for each query
731+ all_parsed .append (parsed_result )
724732 return all_parsed
725733
726734 def search (self , * args , ** kwargs ) -> "Result" :
@@ -740,6 +748,26 @@ def search(self, *args, **kwargs) -> "Result":
740748 except Exception as e :
741749 raise RedisSearchError (f"Error while searching: { str (e )} " ) from e
742750
751+ def batch_query (
752+ self , queries : List [BaseQuery ], batch_size : int = 10
753+ ) -> List [List [Dict [str , Any ]]]:
754+ """Execute a batch of queries and process results."""
755+ results = self .batch_search (
756+ [(query .query , query .params ) for query in queries ], batch_size = batch_size
757+ )
758+ all_parsed = []
759+ for query , batch_results in zip (queries , results ):
760+ parsed = process_results (
761+ batch_results ,
762+ query = query ,
763+ storage_type = self .schema .index .storage_type ,
764+ )
765+ # Create separate lists of parsed results for each query
766+ # passed in to the batch_search method, so that callers can
767+ # access the results for each query individually
768+ all_parsed .append (parsed )
769+ return all_parsed
770+
743771 def _query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
744772 """Execute a query and process results."""
745773 results = self .search (query .query , query_params = query .params )
@@ -1283,22 +1311,20 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
12831311 raise RedisSearchError (f"Error while aggregating: { str (e )} " ) from e
12841312
12851313 async def batch_search (
1286- self , queries : List [BaseQuery ], batch_size : int = 100 , ** query_params
1314+ self , queries : List [SearchParams ], batch_size : int = 10
12871315 ) -> List ["Result" ]:
12881316 """Perform a search against the index for multiple queries.
12891317
1290- This method takes a list of queries and returns a list of search results.
1291- The results are returned in the same order as the queries.
1318+ This method takes a list of queries and returns a list of Result objects
1319+ for each query. Results are returned in the same order as the queries.
12921320
12931321 Args:
1294- queries (List[str]): The queries to search for.
1295- batch_size (int, optional): The number of queries to search for at a time.
1296- Defaults to 100.
1297- query_params (dict, optional): The query parameters to pass to the search
1298- for each query.
1322+ queries (List[SearchParams]): The queries to search for. batch_size
1323+ (int, optional): The number of queries to search for at a time.
1324+ Defaults to 10.
12991325
13001326 Returns:
1301- List[Result]: The search results.
1327+ List[Result]: The search results for each query .
13021328 """
13031329 all_results = []
13041330 client = await self ._get_client ()
@@ -1316,9 +1342,14 @@ async def batch_search(
13161342 async with client .pipeline (transaction = False ) as pipe :
13171343 batch_built_queries = []
13181344 for query in batch_queries :
1319- query_args , q = search ._mk_query_args ( # type: ignore
1320- query , query_params = query_params
1321- )
1345+ if isinstance (query , tuple ):
1346+ query_args , q = search ._mk_query_args ( # type: ignore
1347+ query [0 ], query_params = query [1 ]
1348+ )
1349+ else :
1350+ query_args , q = search ._mk_query_args ( # type: ignore
1351+ query , query_params = None
1352+ )
13221353 batch_built_queries .append (q )
13231354 pipe .execute_command (
13241355 "FT.SEARCH" ,
@@ -1335,12 +1366,13 @@ async def batch_search(
13351366
13361367 for i , query_results in enumerate (results ):
13371368 _built_query = batch_built_queries [i ]
1338- parsed_raw = search ._parse_search ( # type: ignore
1369+ parsed_result = search ._parse_search ( # type: ignore
13391370 query_results ,
13401371 query = _built_query ,
13411372 duration = duration ,
13421373 )
1343- all_results .append (parsed_raw )
1374+ # Return a parsed Result object for each query
1375+ all_results .append (parsed_result )
13441376 return all_results
13451377
13461378 async def search (self , * args , ** kwargs ) -> "Result" :
@@ -1359,11 +1391,13 @@ async def search(self, *args, **kwargs) -> "Result":
13591391 except Exception as e :
13601392 raise RedisSearchError (f"Error while searching: { str (e )} " ) from e
13611393
1362- async def _batch_query (
1363- self , queries : List [BaseQuery ], batch_size : int = 100
1394+ async def batch_query (
1395+ self , queries : List [BaseQuery ], batch_size : int = 10
13641396 ) -> List [List [Dict [str , Any ]]]:
13651397 """Asynchronously execute a batch of queries and process results."""
1366- results = await self .batch_search (queries , batch_size = batch_size )
1398+ results = await self .batch_search (
1399+ [(query .query , query .params ) for query in queries ], batch_size = batch_size
1400+ )
13671401 all_parsed = []
13681402 for query , batch_results in zip (queries , results ):
13691403 parsed = process_results (
@@ -1378,15 +1412,6 @@ async def _batch_query(
13781412
13791413 return all_parsed
13801414
1381- async def batch_query (
1382- self , queries : List [BaseQuery ], batch_size : int = 100
1383- ) -> List [List [Dict [str , Any ]]]:
1384- """Asynchronously execute a batch of queries and process results."""
1385- return await self ._batch_query (
1386- [query .query for query in queries ],
1387- batch_size = batch_size ,
1388- )
1389-
13901415 async def _query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
13911416 """Asynchronously execute a query and process results."""
13921417 results = await self .search (query .query , query_params = query .params )
0 commit comments