diff --git a/src/guidellm/scheduler/base.py b/src/guidellm/scheduler/base.py index 602166b0..f0087330 100644 --- a/src/guidellm/scheduler/base.py +++ b/src/guidellm/scheduler/base.py @@ -299,13 +299,13 @@ async def _run_async( self, benchmark: TextGenerationBenchmark, end_time: float, max_number: float ) -> AsyncGenerator[Union[TextGenerationResult, TextGenerationError], None]: tasks = [] - completed = 0 + pending = asyncio.Semaphore(settings.max_concurrency) for index, (request, submit_at) in enumerate( zip(self.generator, self.load_generator.times()) ): - while (index + 1 - completed) >= settings.max_concurrency: - await asyncio.sleep(0.1) + # wait until the number of pending tasks drops below max_concurrency + await pending.acquire() if index >= max_number or time.time() >= end_time or submit_at >= end_time: break @@ -317,8 +317,9 @@ async def _run_async( ) def _completed(_task: asyncio.Task) -> None: - nonlocal completed - completed += 1 + # NOTE: this is only ok because we don't use threads/processes + nonlocal pending + pending.release() _res = _task.result() if _res: @@ -333,7 +334,7 @@ def _completed(_task: asyncio.Task) -> None: tasks.append(task) # release control to the event loop for other tasks - await asyncio.sleep(0.001) + await asyncio.sleep(0) for compl_task in asyncio.as_completed(tasks): task_res = await compl_task