Skip to content

Commit 60d78d4

Browse files
committed
Initial implementation of parallel UDFs using async
1 parent 9c649cf commit 60d78d4

File tree

1 file changed

+54
-35
lines changed

1 file changed

+54
-35
lines changed

singlestoredb/functions/ext/asgi.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,6 @@
9090
logger = utils.get_logger('singlestoredb.functions.ext.asgi')
9191

9292

93-
# If a number of processes is specified, create a pool of workers
94-
num_processes = max(0, int(os.environ.get('SINGLESTOREDB_EXT_NUM_PROCESSES', 0)))
95-
if num_processes > 1:
96-
try:
97-
from ray.util.multiprocessing import Pool
98-
except ImportError:
99-
from multiprocessing import Pool
100-
func_map = Pool(num_processes).starmap
101-
else:
102-
func_map = itertools.starmap
103-
104-
10593
async def to_thread(
10694
func: Any, /, *args: Any, **kwargs: Dict[str, Any],
10795
) -> Any:
@@ -293,6 +281,53 @@ def cancel_on_event(
293281
)
294282

295283

284+
def identity(x: Any) -> Any:
285+
"""Identity function."""
286+
return x
287+
288+
289+
async def run_in_parallel(
290+
func: Callable[..., Any],
291+
params_list: Sequence[Sequence[Any]],
292+
cancel_event: threading.Event,
293+
limit: int = 10,
294+
transformer: Callable[[Any], Any] = identity,
295+
) -> List[Any]:
296+
""""
297+
Run a function in parallel with a limit on the number of concurrent tasks.
298+
299+
Parameters
300+
----------
301+
func : Callable
302+
The function to call in parallel
303+
params_list : Sequence[Sequence[Any]]
304+
The parameters to pass to the function
305+
cancel_event : threading.Event
306+
The event to check for cancellation
307+
limit : int
308+
The maximum number of concurrent tasks to run
309+
310+
Returns
311+
-------
312+
List[Any]
313+
The results of the function calls
314+
315+
"""
316+
semaphore = asyncio.Semaphore(limit)
317+
318+
async def worker(params: Sequence[Any]) -> Any:
319+
async with semaphore:
320+
cancel_on_event(cancel_event)
321+
if asyncio.iscoroutinefunction(func):
322+
return transformer(await func(*params))
323+
else:
324+
return transformer(await to_thread(func, *params))
325+
326+
tasks = [worker(p) for p in params_list]
327+
328+
return await asyncio.gather(*tasks)
329+
330+
296331
def build_udf_endpoint(
297332
func: Callable[..., Any],
298333
returns_data_format: str,
@@ -315,23 +350,15 @@ def build_udf_endpoint(
315350
"""
316351
if returns_data_format in ['scalar', 'list']:
317352

318-
is_async = asyncio.iscoroutinefunction(func)
319-
320353
async def do_func(
321354
cancel_event: threading.Event,
322355
timer: Timer,
323356
row_ids: Sequence[int],
324357
rows: Sequence[Sequence[Any]],
325358
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
326359
'''Call function on given rows of data.'''
327-
out = []
328360
async with timer('call_function'):
329-
for row in rows:
330-
cancel_on_event(cancel_event)
331-
if is_async:
332-
out.append(await func(*row))
333-
else:
334-
out.append(func(*row))
361+
out = await run_in_parallel(func, rows, cancel_event)
335362
return row_ids, list(zip(out))
336363

337364
return do_func
@@ -426,28 +453,20 @@ def build_tvf_endpoint(
426453
"""
427454
if returns_data_format in ['scalar', 'list']:
428455

429-
is_async = asyncio.iscoroutinefunction(func)
430-
431456
async def do_func(
432457
cancel_event: threading.Event,
433458
timer: Timer,
434459
row_ids: Sequence[int],
435460
rows: Sequence[Sequence[Any]],
436461
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
437462
'''Call function on given rows of data.'''
438-
out_ids: List[int] = []
439-
out = []
440-
# Call function on each row of data
441463
async with timer('call_function'):
442-
for i, row in zip(row_ids, rows):
443-
cancel_on_event(cancel_event)
444-
if is_async:
445-
res = await func(*row)
446-
else:
447-
res = func(*row)
448-
out.extend(as_list_of_tuples(res))
449-
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
450-
return out_ids, out
464+
items = await run_in_parallel(
465+
func, rows, cancel_event,
466+
transformer=as_list_of_tuples,
467+
)
468+
out = list(itertools.chain.from_iterable(items))
469+
return [row_ids[0]] * len(out), out
451470

452471
return do_func
453472

0 commit comments

Comments
 (0)