diff --git a/src/guidellm/executor/base.py b/src/guidellm/executor/base.py
index 865ab30d..1bca4373 100644
--- a/src/guidellm/executor/base.py
+++ b/src/guidellm/executor/base.py
@@ -63,6 +63,7 @@ class Executor:
     :param max_duration: Maximum duration for generating requests for the scheduler,
         (a single report run), or None.
     :type max_duration: Optional[float]
+    :type batch_size: Optional[int]
     """
 
     def __init__(
@@ -73,12 +74,14 @@ def __init__(
         rate: Optional[Union[float, Sequence[float]]] = None,
         max_number: Optional[int] = None,
         max_duration: Optional[float] = None,
+        batch_size: Optional[int] = None,
     ):
         self._backend = backend
         self._generator = request_generator
         self._max_number = max_number
         self._max_duration = max_duration
         self._profile_generator = ProfileGenerator(mode=mode, rate=rate)
+        self._batch_size = batch_size
         logger.info("Executor initialized with mode: {}, rate: {}", mode, rate)
 
     @property
@@ -131,6 +134,16 @@ def max_duration(self) -> Optional[float]:
         """
         return self._max_duration
 
+    @property
+    def batch_size(self) -> Optional[int]:
+        """
+        Returns the configured batch size.
+
+        :return: The batch size, or None if batching is disabled.
+        :rtype: Optional[int]
+        """
+        return self._batch_size
+
     async def run(self) -> AsyncGenerator[ExecutorResult, None]:
         """
         Runs the Executor, generating and scheduling tasks based on the profile
@@ -154,6 +167,7 @@ async def run(self) -> AsyncGenerator[ExecutorResult, None]:
             # limits args
             "max_number": self.max_number,
             "max_duration": self.max_duration,
+            "batch_size": self.batch_size,
         }
         profile_index = -1
         logger.info("Starting Executor run")
@@ -175,6 +189,7 @@ async def run(self) -> AsyncGenerator[ExecutorResult, None]:
                 rate=profile.load_gen_rate,
                 max_number=self.max_number or profile.args.get("max_number", None),
                 max_duration=self.max_duration,
+                batch_size=self.batch_size,
             )
             profile_index += 1
 
diff --git a/src/guidellm/main.py b/src/guidellm/main.py
index 4016ecec..7ab868c2 100644
--- a/src/guidellm/main.py
+++ b/src/guidellm/main.py
@@ -131,6 +131,12 @@
         "the server's performance to stabilize."
     ),
 )
+@click.option(
+    "--batch-size",
+    type=int,
+    default=None,
+    help="Number of inference requests to submit per batch.",
+)
 @click.option(
     "--output-path",
     type=str,
@@ -162,6 +168,7 @@ def generate_benchmark_report_cli(
     rate: Optional[float],
     max_seconds: Optional[int],
     max_requests: Union[Literal["dataset"], int, None],
+    batch_size: Optional[int],
     output_path: str,
     enable_continuous_refresh: bool,
 ):
@@ -179,6 +186,7 @@ def generate_benchmark_report_cli(
         rate=rate,
         max_seconds=max_seconds,
         max_requests=max_requests,
+        batch_size=batch_size,
         output_path=output_path,
         cont_refresh_table=enable_continuous_refresh,
     )
@@ -195,6 +203,7 @@ def generate_benchmark_report(
     rate: Optional[float],
     max_seconds: Optional[int],
     max_requests: Union[Literal["dataset"], int, None],
+    batch_size: Optional[int],
     output_path: str,
     cont_refresh_table: bool,
 ) -> GuidanceReport:
@@ -269,6 +278,7 @@ def generate_benchmark_report(
             len(request_generator) if max_requests == "dataset" else max_requests
         ),
         max_duration=max_seconds,
+        batch_size=batch_size,
     )
 
     # Run executor
@@ -281,6 +291,7 @@ def generate_benchmark_report(
             "rate": rate,
             "max_number": max_requests,
             "max_duration": max_seconds,
+            "batch_size": batch_size,
         },
     )
     report = asyncio.run(_run_executor_for_result(executor))
diff --git a/src/guidellm/scheduler/base.py b/src/guidellm/scheduler/base.py
index 602166b0..8b717b58 100644
--- a/src/guidellm/scheduler/base.py
+++ b/src/guidellm/scheduler/base.py
@@ -2,7 +2,7 @@
 import math
 import time
 from dataclasses import dataclass
-from typing import AsyncGenerator, Literal, Optional, Union, get_args
+from typing import AsyncGenerator, List, Literal, Optional, Tuple, Union, get_args
 
 from loguru import logger
 
@@ -35,6 +35,10 @@ class SchedulerResult:
     :type benchmark: TextGenerationBenchmark
     :param current_result: The result of the current request, if any.
     :type current_result: Optional[Union[TextGenerationResult, Exception]]
+    :param batch_results: The results of the current batch of requests, if any.
+    :type batch_results: Optional[List[
+        Union[TextGenerationResult, TextGenerationError]]
+    ]
     """
 
     completed: bool
@@ -42,6 +46,9 @@ class SchedulerResult:
     count_completed: int
     benchmark: TextGenerationBenchmark
     current_result: Optional[Union[TextGenerationResult, TextGenerationError]] = None
+    batch_results: Optional[List[Union[TextGenerationResult, TextGenerationError]]] = (
+        None
+    )
 
 
 class Scheduler:
@@ -74,6 +81,7 @@ def __init__(
         rate: Optional[float] = None,
         max_number: Optional[int] = None,
         max_duration: Optional[float] = None,
+        batch_size: Optional[int] = None,
     ):
         logger.info(
             "Scheduler initialized with params: generator={}, worker={}, mode={}, "
@@ -114,6 +122,11 @@ def __init__(
             logger.error(err)
             raise err
 
+        if batch_size is not None and batch_size <= 0:
+            err = ValueError(f"batch_size must be > 0, given: {batch_size}")
+            logger.error(err)
+            raise err
+
         self._generator = generator
         self._worker = worker
         self._mode = mode
@@ -121,6 +134,8 @@ def __init__(
         self._max_number = max_number
         self._max_duration = max_duration
 
+        self._batch_size = batch_size
+
         self._load_generator = LoadGenerator(mode, rate)
 
     @property
@@ -209,6 +224,17 @@ def benchmark_mode(self) -> Literal["asynchronous", "synchronous", "throughput"]
 
         return "asynchronous"
 
+    @property
+    def batch_size(self) -> Optional[int]:
+        """
+        Returns the batch size used to group requests for submission.
+
+        :return: The batch size, or None if batching is disabled.
+        :rtype: Optional[int]
+        """
+
+        return self._batch_size
+
     async def run(self) -> AsyncGenerator[SchedulerResult, None]:
         """
         Run the scheduler to process requests based on the configured mode, rate,
@@ -223,15 +249,17 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
         start_time = time.time()
         end_time = start_time + self.max_duration if self.max_duration else math.inf
         max_number = float(self.max_number) if self.max_number else math.inf
-        runner = self._run_sync if self._mode == "synchronous" else self._run_async
         count_total = (
             self.max_number
             if self.max_number
-            else round(self.max_duration)
-            if self.max_duration
-            else 0
+            else round(self.max_duration) if self.max_duration else 0
         )
 
+        if self.batch_size:
+            runner = self._run_batch
+        else:
+            runner = self._run_sync if self._mode == "synchronous" else self._run_async
+
         # yield initial result for progress tracking
         yield SchedulerResult(
             completed=False,
@@ -243,21 +271,30 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
         run_count = 0
         async for res in runner(benchmark, end_time, max_number):
             run_count += 1
+
             count_completed = (
                 min(run_count, self.max_number)
                 if self.max_number
-                else round(time.time() - start_time)
-                if self.max_duration
-                else 0
+                else round(time.time() - start_time) if self.max_duration else 0
             )
 
-            yield SchedulerResult(
-                completed=False,
-                count_total=count_total,
-                count_completed=count_completed,
-                benchmark=benchmark,
-                current_result=res,
-            )
+            if self.batch_size:
+
+                yield SchedulerResult(
+                    completed=False,
+                    count_total=count_total,
+                    count_completed=count_completed,
+                    benchmark=benchmark,
+                    batch_results=res,
+                )
+            else:
+                yield SchedulerResult(
+                    completed=False,
+                    count_total=count_total,
+                    count_completed=count_completed,
+                    benchmark=benchmark,
+                    current_result=res,
+                )
 
         logger.info("Scheduler run completed")
 
@@ -267,9 +304,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
             count_completed=(
                 benchmark.request_count + benchmark.error_count
                 if self.max_number
-                else round(time.time() - start_time)
-                if self.max_duration
-                else 0
+                else round(time.time() - start_time) if self.max_duration else 0
             ),
             benchmark=benchmark,
         )
@@ -372,3 +407,81 @@ async def _submit_task_coroutine(
 
             logger.warning("Request {} failed: {}", request, exc)
             return TextGenerationError(request=request, message=str(exc))
+
+    async def _run_batch(
+        self, benchmark: TextGenerationBenchmark, end_time: float, max_number: float
+    ) -> AsyncGenerator[List[Union[TextGenerationResult, TextGenerationError]], None]:
+
+        if self.batch_size is None:
+            raise ValueError("--batch-size CLI parameter is not set")
+
+        batch = []
+        count_completed = 0
+
+        for request, submit_at in zip(self.generator, self.load_generator.times()):
+            if time.time() >= end_time or count_completed >= max_number:
+                break
+
+            if len(batch) < self.batch_size:
+                batch.append((request, submit_at))
+
+            if len(batch) >= self.batch_size:
+                results = await self._process_batch(batch, benchmark, end_time)
+                count_completed += len(
+                    [r for r in results if not isinstance(r, TextGenerationError)]
+                )
+
+                yield results
+                batch = []
+
+        if batch:
+            results = await self._process_batch(batch, benchmark, end_time)
+            count_completed += len(
+                [r for r in results if not isinstance(r, TextGenerationError)]
+            )
+            yield results
+
+    async def _process_batch(
+        self,
+        batch: List[Tuple[TextGenerationRequest, float]],
+        benchmark: TextGenerationBenchmark,
+        end_time: float,
+    ) -> List[Union[TextGenerationResult, TextGenerationError]]:
+        try:
+
+            benchmark.request_started()
+            tasks = [
+                self._delayed_submit(request, submit_at, end_time)
+                for request, submit_at in batch
+            ]
+
+            timeout = end_time - time.time() if end_time < math.inf else None
+
+            results = await asyncio.wait_for(
+                asyncio.gather(*tasks, return_exceptions=True), timeout=timeout
+            )
+            processed_results = []
+            for (req, _), result in zip(batch, results):
+                if isinstance(result, Exception):
+                    error = TextGenerationError(request=req, message=str(result))
+                    benchmark.request_completed(error)
+                    processed_results.append(error)
+                else:
+                    benchmark.request_completed(result)
+                    processed_results.append(result)
+            return processed_results
+        except asyncio.TimeoutError:
+            return [
+                TextGenerationError(request=req, message="Batch timeout")
+                for req, _ in batch
+            ]
+
+    async def _delayed_submit(
+        self, request: TextGenerationRequest, submit_at: float, end_time: float
+    ) -> Union[TextGenerationResult, TextGenerationError]:
+        if submit_at > time.time():
+            await asyncio.sleep(submit_at - time.time())
+        if time.time() >= end_time:
+            raise asyncio.TimeoutError("Submission time exceeded end_time")
+
+        return await self._worker.submit(request)
diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py
index 82de3edf..80c15cd8 100644
--- a/tests/unit/test_main.py
+++ b/tests/unit/test_main.py
@@ -252,6 +252,7 @@ def test_generate_benchmark_report_invoke_smoke(
         max_requests=10,
         output_path="benchmark_report.json",
         cont_refresh_table=False,
+        batch_size=None,
     )
     assert report is not None
 
@@ -308,6 +309,7 @@ def test_generate_benchmark_report_emulated_with_dataset_requests(
         rate=None,
         max_seconds=10,
         max_requests="dataset",
+        batch_size=None,
         output_path="benchmark_report.json",
         cont_refresh_table=False,
     )
@@ -397,6 +399,7 @@ def test_generate_benchmark_report_openai_limited_by_file_dataset(
         rate=rate,
         max_seconds=None,
         max_requests="dataset",
+        batch_size=None,
         output_path="benchmark_report.json",
         cont_refresh_table=False,
     )
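
Illustrative only, not part of the patch: a minimal sketch of how a caller of Scheduler.run() might consume the new batch_results field alongside the existing current_result. It assumes a Scheduler instance already constructed with batch_size set; handle_result is a hypothetical stand-in for whatever the caller does with each request outcome.

def handle_result(result) -> None:
    # Hypothetical per-result hook; replace with real logging or aggregation.
    print(type(result).__name__)


async def consume(scheduler) -> None:
    # `scheduler` is assumed to be a guidellm Scheduler constructed with batch_size set.
    async for update in scheduler.run():
        if update.batch_results is not None:
            # Batched mode: one SchedulerResult carries the whole batch of outcomes.
            for item in update.batch_results:
                handle_result(item)
        elif update.current_result is not None:
            # Unbatched mode: results still arrive one SchedulerResult at a time.
            handle_result(update.current_result)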