Commit 9d78722

Only use a thread for a batch of rows, not each row
1 parent 60d78d4 commit 9d78722
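
In effect, the parameter list is now split into at most "limit" evenly sized batches and one worker handles each batch, instead of one to_thread() call per row. A standalone sketch of the splitting arithmetic (illustration only, mirroring the chunked() helper added in asgi.py below):

# Standalone illustration (not code from this commit): splitting 1,000 rows
# across a concurrency limit of 10 yields 10 batches of 100 rows each, so
# only 10 worker threads are created instead of 1,000 to_thread() calls.
rows = list(range(1000))
limit = 10
chunk_size = (len(rows) + limit - 1) // limit          # ceil(1000 / 10) = 100
batches = [rows[i:i + chunk_size] for i in range(0, len(rows), chunk_size)]
print(len(batches), len(batches[0]))                    # 10 100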

2 files changed: +40 -9

singlestoredb/config.py

Lines changed: 7 additions & 0 deletions
@@ -444,6 +444,13 @@
     environ=['SINGLESTOREDB_EXT_FUNC_TIMEOUT'],
 )

+register_option(
+    'external_function.concurrency_limit', 'int', check_int, 1,
+    'Specifies the maximum number of subsets of a batch of rows '
+    'to process simultaneously.',
+    environ=['SINGLESTOREDB_EXT_FUNC_CONCURRENCY_LIMIT'],
+)
+
 #
 # Debugging options
 #

singlestoredb/functions/ext/asgi.py

Lines changed: 33 additions & 9 deletions
@@ -53,6 +53,7 @@
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import Iterator
 from typing import List
 from typing import Optional
 from typing import Sequence
@@ -286,11 +287,21 @@ def identity(x: Any) -> Any:
     return x


+def chunked(seq: Sequence[Any], max_chunks: int) -> Iterator[Sequence[Any]]:
+    """Yield up to max_chunks chunks from seq, splitting as evenly as possible."""
+    n = len(seq)
+    if max_chunks <= 0 or max_chunks > n:
+        max_chunks = n
+    chunk_size = (n + max_chunks - 1) // max_chunks  # ceil division
+    for i in range(0, n, chunk_size):
+        yield seq[i:i + chunk_size]
+
+
 async def run_in_parallel(
     func: Callable[..., Any],
     params_list: Sequence[Sequence[Any]],
     cancel_event: threading.Event,
-    limit: int = 10,
+    limit: int = get_option('external_function.concurrency_limit'),
     transformer: Callable[[Any], Any] = identity,
 ) -> List[Any]:
     """"
@@ -306,26 +317,39 @@ async def run_in_parallel(
         The event to check for cancellation
     limit : int
         The maximum number of concurrent tasks to run
+    transformer : Callable[[Any], Any]
+        A function to transform the results

     Returns
     -------
     List[Any]
         The results of the function calls

     """
-    semaphore = asyncio.Semaphore(limit)
+    is_async = asyncio.iscoroutinefunction(func)

-    async def worker(params: Sequence[Any]) -> Any:
-        async with semaphore:
+    async def call(batch: Sequence[Any]) -> Any:
+        """Loop over batches of parameters and call the function."""
+        res = []
+        for params in batch:
             cancel_on_event(cancel_event)
-            if asyncio.iscoroutinefunction(func):
-                return transformer(await func(*params))
+            if is_async:
+                res.append(transformer(await func(*params)))
             else:
-                return transformer(await to_thread(func, *params))
+                res.append(transformer(func(*params)))
+        return res
+
+    async def thread_call(batch: Sequence[Any]) -> Any:
+        if is_async:
+            return await call(batch)
+        return await to_thread(lambda: asyncio.run(call(batch)))
+
+    # Create tasks in chunks to limit concurrency
+    tasks = [thread_call(batch) for batch in chunked(params_list, limit)]

-    tasks = [worker(p) for p in params_list]
+    results = await asyncio.gather(*tasks)

-    return await asyncio.gather(*tasks)
+    return list(itertools.chain.from_iterable(results))


 def build_udf_endpoint(
