5353from typing import Callable
5454from typing import Dict
5555from typing import Iterable
56+ from typing import Iterator
5657from typing import List
5758from typing import Optional
5859from typing import Sequence
@@ -286,11 +287,21 @@ def identity(x: Any) -> Any:
286287 return x
287288
288289
def chunked(seq: Sequence[Any], max_chunks: int) -> Iterator[Sequence[Any]]:
    """
    Yield consecutive slices of ``seq``, producing at most ``max_chunks``.

    Every chunk except possibly the last holds
    ``ceil(len(seq) / max_chunks)`` items, so the last chunk may be shorter
    and fewer than ``max_chunks`` chunks may be produced. An empty ``seq``
    yields nothing.

    Parameters
    ----------
    seq : Sequence[Any]
        The sequence to split; it must support ``len()`` and slicing
    max_chunks : int
        Upper bound on the number of chunks. A value that is not positive,
        or that exceeds ``len(seq)``, results in one chunk per item

    Yields
    ------
    Sequence[Any]
        Slices of ``seq`` in their original order

    """
    n = len(seq)
    if n == 0:
        # Nothing to split; this also keeps the division below from
        # dividing by zero once max_chunks is clamped to n
        return
    if max_chunks <= 0 or max_chunks > n:
        max_chunks = n
    chunk_size = (n + max_chunks - 1) // max_chunks  # ceil division
    for start in range(0, n, chunk_size):
        yield seq[start:start + chunk_size]
298+
299+
async def run_in_parallel(
    func: Callable[..., Any],
    params_list: Sequence[Sequence[Any]],
    cancel_event: threading.Event,
    limit: Optional[int] = None,
    transformer: Callable[[Any], Any] = identity,
) -> List[Any]:
    """
    Call a function once per parameter set using a bounded number of workers.

    ``params_list`` is split into at most ``limit`` consecutive batches.
    The batches are run concurrently, while the calls inside one batch are
    made one after another. Coroutine functions are awaited directly;
    other callables are handed to ``to_thread`` one batch at a time.
    ``cancel_on_event(cancel_event)`` is invoked before every call.

    Parameters
    ----------
    func : Callable[..., Any]
        The function to call
    params_list : Sequence[Sequence[Any]]
        The positional arguments for each call
    cancel_event : threading.Event
        The event to check for cancellation
    limit : int, optional
        The maximum number of concurrent tasks to run. When omitted, the
        ``external_function.concurrency_limit`` option is read at call time
    transformer : Callable[[Any], Any]
        A function to transform the results

    Returns
    -------
    List[Any]
        The results of the function calls, in the order of ``params_list``

    """
    if limit is None:
        # Resolve the option per call; a default evaluated in the signature
        # would be frozen at the value seen when the module was imported
        limit = get_option('external_function.concurrency_limit')

    if len(params_list) == 0:
        # No work to schedule
        return []

    async def async_batch(batch: Sequence[Any]) -> List[Any]:
        """Await the coroutine function for each parameter set in order."""
        out = []
        for params in batch:
            cancel_on_event(cancel_event)
            out.append(transformer(await func(*params)))
        return out

    def sync_batch(batch: Sequence[Any]) -> List[Any]:
        """Call the plain function for each parameter set in order."""
        out = []
        for params in batch:
            cancel_on_event(cancel_event)
            out.append(transformer(func(*params)))
        return out

    # One task per batch keeps the number of concurrent tasks within limit
    batches = chunked(params_list, limit)
    if asyncio.iscoroutinefunction(func):
        tasks = [async_batch(batch) for batch in batches]
    else:
        # A plain function needs no event loop of its own, so the batch
        # loop is passed to to_thread directly
        tasks = [to_thread(sync_batch, batch) for batch in batches]

    results = await asyncio.gather(*tasks)

    # Flatten the per-batch result lists back into a single list
    return list(itertools.chain.from_iterable(results))
329353
330354
331355def build_udf_endpoint (
0 commit comments