Skip to content

Commit 1953f97

Browse files
committed
Add concurrency limit option
1 parent 9d78722 commit 1953f97

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

singlestoredb/functions/decorator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def _func(
103103
args: Optional[ParameterType] = None,
104104
returns: Optional[ReturnType] = None,
105105
timeout: Optional[int] = None,
106+
concurrency_limit: Optional[int] = None,
106107
) -> UDFType:
107108
"""Generic wrapper for UDF and TVF decorators."""
108109

@@ -112,6 +113,7 @@ def _func(
112113
args=expand_types(args),
113114
returns=expand_types(returns),
114115
timeout=timeout,
116+
concurrency_limit=concurrency_limit,
115117
).items() if v is not None
116118
}
117119

@@ -155,6 +157,7 @@ def udf(
155157
args: Optional[ParameterType] = None,
156158
returns: Optional[ReturnType] = None,
157159
timeout: Optional[int] = None,
160+
concurrency_limit: Optional[int] = None,
158161
) -> UDFType:
159162
"""
160163
Define a user-defined function (UDF).
@@ -185,6 +188,10 @@ def udf(
185188
timeout : int, optional
186189
The timeout in seconds for the UDF execution. If not specified,
187190
the global default timeout is used.
191+
concurrency_limit : int, optional
192+
The maximum number of concurrent subsets of rows that will be
193+
processed simultaneously by the UDF. If not specified,
194+
the global default concurrency limit is used.
188195
189196
Returns
190197
-------
@@ -197,4 +204,5 @@ def udf(
197204
args=args,
198205
returns=returns,
199206
timeout=timeout,
207+
concurrency_limit=concurrency_limit,
200208
)

singlestoredb/functions/ext/asgi.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,6 @@ async def run_in_parallel(
301301
func: Callable[..., Any],
302302
params_list: Sequence[Sequence[Any]],
303303
cancel_event: threading.Event,
304-
limit: int = get_option('external_function.concurrency_limit'),
305304
transformer: Callable[[Any], Any] = identity,
306305
) -> List[Any]:
307306
""""
@@ -315,8 +314,6 @@ async def run_in_parallel(
315314
The parameters to pass to the function
316315
cancel_event : threading.Event
317316
The event to check for cancellation
318-
limit : int
319-
The maximum number of concurrent tasks to run
320317
transformer : Callable[[Any], Any]
321318
A function to transform the results
322319
@@ -326,6 +323,7 @@ async def run_in_parallel(
326323
The results of the function calls
327324
328325
"""
326+
limit = get_concurrency_limit(func)
329327
is_async = asyncio.iscoroutinefunction(func)
330328

331329
async def call(batch: Sequence[Any]) -> Any:
@@ -352,6 +350,29 @@ async def thread_call(batch: Sequence[Any]) -> Any:
352350
return list(itertools.chain.from_iterable(results))
353351

354352

353+
def get_concurrency_limit(func: Callable[..., Any]) -> int:
354+
"""
355+
Get the concurrency limit for a function.
356+
357+
Parameters
358+
----------
359+
func : Callable
360+
The function to get the concurrency limit for
361+
362+
Returns
363+
-------
364+
int
365+
The concurrency limit for the function
366+
367+
"""
368+
return max(
369+
1, func._singlestoredb_attrs.get( # type: ignore
370+
'concurrency_limit',
371+
get_option('external_function.concurrency_limit'),
372+
),
373+
)
374+
375+
355376
def build_udf_endpoint(
356377
func: Callable[..., Any],
357378
returns_data_format: str,

0 commit comments

Comments
 (0)