9090logger = utils .get_logger ('singlestoredb.functions.ext.asgi' )
9191
9292
93- # If a number of processes is specified, create a pool of workers
94- num_processes = max (0 , int (os .environ .get ('SINGLESTOREDB_EXT_NUM_PROCESSES' , 0 )))
95- if num_processes > 1 :
96- try :
97- from ray .util .multiprocessing import Pool
98- except ImportError :
99- from multiprocessing import Pool
100- func_map = Pool (num_processes ).starmap
101- else :
102- func_map = itertools .starmap
103-
104-
10593async def to_thread (
10694 func : Any , / , * args : Any , ** kwargs : Dict [str , Any ],
10795) -> Any :
@@ -293,6 +281,53 @@ def cancel_on_event(
293281 )
294282
295283
def identity(x: Any) -> Any:
    """Return the argument unchanged (default result transformer)."""
    return x
288+
async def run_in_parallel(
    func: Callable[..., Any],
    params_list: Sequence[Sequence[Any]],
    cancel_event: threading.Event,
    limit: int = 10,
    transformer: Callable[[Any], Any] = identity,
) -> List[Any]:
    """
    Run a function in parallel with a limit on the number of concurrent tasks.

    Parameters
    ----------
    func : Callable
        The function to call in parallel; may be sync or async. Sync
        functions are dispatched to a thread via ``to_thread``.
    params_list : Sequence[Sequence[Any]]
        One sequence of positional arguments per call
    cancel_event : threading.Event
        The event to check for cancellation
    limit : int
        The maximum number of concurrent tasks to run
    transformer : Callable[[Any], Any]
        Applied to each call's result before it is collected;
        defaults to the identity function

    Returns
    -------
    List[Any]
        The (transformed) results of the function calls, in the same
        order as ``params_list``

    """
    semaphore = asyncio.Semaphore(limit)

    # Loop-invariant: decide sync vs. async once, not on every call.
    is_async = asyncio.iscoroutinefunction(func)

    async def worker(params: Sequence[Any]) -> Any:
        async with semaphore:
            # Abort promptly if the request has been cancelled.
            cancel_on_event(cancel_event)
            if is_async:
                return transformer(await func(*params))
            return transformer(await to_thread(func, *params))

    tasks = [worker(p) for p in params_list]

    # asyncio.gather preserves the order of its awaitables in the result.
    return await asyncio.gather(*tasks)
330+
296331def build_udf_endpoint (
297332 func : Callable [..., Any ],
298333 returns_data_format : str ,
@@ -315,23 +350,15 @@ def build_udf_endpoint(
315350 """
316351 if returns_data_format in ['scalar' , 'list' ]:
317352
318- is_async = asyncio .iscoroutinefunction (func )
319-
320353 async def do_func (
321354 cancel_event : threading .Event ,
322355 timer : Timer ,
323356 row_ids : Sequence [int ],
324357 rows : Sequence [Sequence [Any ]],
325358 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
326359 '''Call function on given rows of data.'''
327- out = []
328360 async with timer ('call_function' ):
329- for row in rows :
330- cancel_on_event (cancel_event )
331- if is_async :
332- out .append (await func (* row ))
333- else :
334- out .append (func (* row ))
361+ out = await run_in_parallel (func , rows , cancel_event )
335362 return row_ids , list (zip (out ))
336363
337364 return do_func
@@ -426,28 +453,20 @@ def build_tvf_endpoint(
426453 """
427454 if returns_data_format in ['scalar' , 'list' ]:
428455
429- is_async = asyncio .iscoroutinefunction (func )
430-
431456 async def do_func (
432457 cancel_event : threading .Event ,
433458 timer : Timer ,
434459 row_ids : Sequence [int ],
435460 rows : Sequence [Sequence [Any ]],
436461 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
437462 '''Call function on given rows of data.'''
438- out_ids : List [int ] = []
439- out = []
440- # Call function on each row of data
441463 async with timer ('call_function' ):
442- for i , row in zip (row_ids , rows ):
443- cancel_on_event (cancel_event )
444- if is_async :
445- res = await func (* row )
446- else :
447- res = func (* row )
448- out .extend (as_list_of_tuples (res ))
449- out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
450- return out_ids , out
464+ items = await run_in_parallel (
465+ func , rows , cancel_event ,
466+ transformer = as_list_of_tuples ,
467+ )
468+ out = list (itertools .chain .from_iterable (items ))
469+ return [row_ids [0 ]] * len (out ), out
451470
452471 return do_func
453472
0 commit comments