Skip to content

Commit 5a8763c

Browse files
author
Dmytro Parfeniuk
committed
--batch-size CLI parameter is added
1 parent ecf2984 commit 5a8763c

File tree

3 files changed

+158
-18
lines changed

3 files changed

+158
-18
lines changed

src/guidellm/executor/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class Executor:
6363
:param max_duration: Maximum duration for generating requests for the scheduler,
6464
(a single report run), or None.
6565
:type max_duration: Optional[float]
66+
:type batch_size: Optional[int]
6667
"""
6768

6869
def __init__(
@@ -73,12 +74,14 @@ def __init__(
7374
rate: Optional[Union[float, Sequence[float]]] = None,
7475
max_number: Optional[int] = None,
7576
max_duration: Optional[float] = None,
77+
batch_size: Optional[int] = None,
7678
):
7779
self._backend = backend
7880
self._generator = request_generator
7981
self._max_number = max_number
8082
self._max_duration = max_duration
8183
self._profile_generator = ProfileGenerator(mode=mode, rate=rate)
84+
self._batch_size = batch_size
8285
logger.info("Executor initialized with mode: {}, rate: {}", mode, rate)
8386

8487
@property
@@ -131,6 +134,16 @@ def max_duration(self) -> Optional[float]:
131134
"""
132135
return self._max_duration
133136

137+
@property
def batch_size(self) -> Optional[int]:
    """
    Returns the configured batch size for inference requests.

    :return: The number of requests grouped into one batch, or None if
        batching is disabled.
    :rtype: Optional[int]
    """
    return self._batch_size
146+
134147
async def run(self) -> AsyncGenerator[ExecutorResult, None]:
135148
"""
136149
Runs the Executor, generating and scheduling tasks based on the profile
@@ -154,6 +167,7 @@ async def run(self) -> AsyncGenerator[ExecutorResult, None]:
154167
# limits args
155168
"max_number": self.max_number,
156169
"max_duration": self.max_duration,
170+
"batch_size": self.batch_size,
157171
}
158172
profile_index = -1
159173
logger.info("Starting Executor run")
@@ -175,6 +189,7 @@ async def run(self) -> AsyncGenerator[ExecutorResult, None]:
175189
rate=profile.load_gen_rate,
176190
max_number=self.max_number or profile.args.get("max_number", None),
177191
max_duration=self.max_duration,
192+
batch_size=self.batch_size,
178193
)
179194
profile_index += 1
180195

src/guidellm/main.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@
131131
"the server's performance to stabilize."
132132
),
133133
)
134+
@click.option(
135+
"--batch-size",
136+
type=int,
137+
default=None,
138+
help="The batch size of inference requests.",
139+
)
134140
@click.option(
135141
"--output-path",
136142
type=str,
@@ -162,6 +168,7 @@ def generate_benchmark_report_cli(
162168
rate: Optional[float],
163169
max_seconds: Optional[int],
164170
max_requests: Union[Literal["dataset"], int, None],
171+
batch_size: Optional[int],
165172
output_path: str,
166173
enable_continuous_refresh: bool,
167174
):
@@ -179,6 +186,7 @@ def generate_benchmark_report_cli(
179186
rate=rate,
180187
max_seconds=max_seconds,
181188
max_requests=max_requests,
189+
batch_size=batch_size,
182190
output_path=output_path,
183191
cont_refresh_table=enable_continuous_refresh,
184192
)
@@ -195,6 +203,7 @@ def generate_benchmark_report(
195203
rate: Optional[float],
196204
max_seconds: Optional[int],
197205
max_requests: Union[Literal["dataset"], int, None],
206+
batch_size: Optional[int],
198207
output_path: str,
199208
cont_refresh_table: bool,
200209
) -> GuidanceReport:
@@ -269,6 +278,7 @@ def generate_benchmark_report(
269278
len(request_generator) if max_requests == "dataset" else max_requests
270279
),
271280
max_duration=max_seconds,
281+
batch_size=batch_size,
272282
)
273283

274284
# Run executor
@@ -281,6 +291,7 @@ def generate_benchmark_report(
281291
"rate": rate,
282292
"max_number": max_requests,
283293
"max_duration": max_seconds,
294+
"batch_size": batch_size,
284295
},
285296
)
286297
report = asyncio.run(_run_executor_for_result(executor))

src/guidellm/scheduler/base.py

Lines changed: 132 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import math
33
import time
44
from dataclasses import dataclass
5-
from typing import AsyncGenerator, Literal, Optional, Union, get_args
5+
from typing import AsyncGenerator, List, Literal, Optional, Tuple, Union, get_args
66

77
from loguru import logger
88

@@ -35,13 +35,18 @@ class SchedulerResult:
3535
:type benchmark: TextGenerationBenchmark
3636
:param current_result: The result of the current request, if any.
3737
:type current_result: Optional[Union[TextGenerationResult, Exception]]
38+
:param batch_results: The result of the current batch of requests, if any
39+
:type batch_results: Optional[List[Union[TextGenerationResult, TextGenerationError]]]
3840
"""
3941

4042
completed: bool
4143
count_total: int
4244
count_completed: int
4345
benchmark: TextGenerationBenchmark
4446
current_result: Optional[Union[TextGenerationResult, TextGenerationError]] = None
47+
batch_results: Optional[List[Union[TextGenerationResult, TextGenerationError]]] = (
48+
None
49+
)
4550

4651

4752
class Scheduler:
@@ -74,6 +79,7 @@ def __init__(
7479
rate: Optional[float] = None,
7580
max_number: Optional[int] = None,
7681
max_duration: Optional[float] = None,
82+
batch_size: Optional[int] = None,
7783
):
7884
logger.info(
7985
"Scheduler initialized with params: generator={}, worker={}, mode={}, "
@@ -114,13 +120,20 @@ def __init__(
114120
logger.error(err)
115121
raise err
116122

123+
if batch_size and batch_size <= 0:
124+
err = ValueError(f"batch_size must be > 0, given: {batch_size}")
125+
logger.error(err)
126+
raise err
127+
117128
self._generator = generator
118129
self._worker = worker
119130
self._mode = mode
120131
self._rate = rate
121132
self._max_number = max_number
122133
self._max_duration = max_duration
123134

135+
self._batch_size = batch_size
136+
124137
self._load_generator = LoadGenerator(mode, rate)
125138

126139
@property
@@ -209,6 +222,20 @@ def benchmark_mode(self) -> Literal["asynchronous", "synchronous", "throughput"]
209222

210223
return "asynchronous"
211224

225+
@property
def batch_size(self) -> Optional[int]:
    """
    Returns the configured batch size for scheduling requests.

    :return: The number of requests processed together in one batch, or
        None if batching is disabled.
    :rtype: Optional[int]
    """

    return self._batch_size
238+
212239
async def run(self) -> AsyncGenerator[SchedulerResult, None]:
213240
"""
214241
Run the scheduler to process requests based on the configured mode, rate,
@@ -223,15 +250,17 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
223250
start_time = time.time()
224251
end_time = start_time + self.max_duration if self.max_duration else math.inf
225252
max_number = float(self.max_number) if self.max_number else math.inf
226-
runner = self._run_sync if self._mode == "synchronous" else self._run_async
227253
count_total = (
228254
self.max_number
229255
if self.max_number
230-
else round(self.max_duration)
231-
if self.max_duration
232-
else 0
256+
else round(self.max_duration) if self.max_duration else 0
233257
)
234258

259+
if self.batch_size:
260+
runner = self._run_batch
261+
else:
262+
runner = self._run_sync if self._mode == "synchronous" else self._run_async
263+
235264
# yield initial result for progress tracking
236265
yield SchedulerResult(
237266
completed=False,
@@ -243,21 +272,30 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
243272
run_count = 0
244273
async for res in runner(benchmark, end_time, max_number):
245274
run_count += 1
275+
246276
count_completed = (
247277
min(run_count, self.max_number)
248278
if self.max_number
249-
else round(time.time() - start_time)
250-
if self.max_duration
251-
else 0
279+
else round(time.time() - start_time) if self.max_duration else 0
252280
)
253281

254-
yield SchedulerResult(
255-
completed=False,
256-
count_total=count_total,
257-
count_completed=count_completed,
258-
benchmark=benchmark,
259-
current_result=res,
260-
)
282+
if self.batch_size:
283+
284+
yield SchedulerResult(
285+
completed=False,
286+
count_total=count_total,
287+
count_completed=count_completed,
288+
benchmark=benchmark,
289+
batch_results=res,
290+
)
291+
else:
292+
yield SchedulerResult(
293+
completed=False,
294+
count_total=count_total,
295+
count_completed=count_completed,
296+
benchmark=benchmark,
297+
current_result=res,
298+
)
261299

262300
logger.info("Scheduler run completed")
263301

@@ -267,9 +305,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
267305
count_completed=(
268306
benchmark.request_count + benchmark.error_count
269307
if self.max_number
270-
else round(time.time() - start_time)
271-
if self.max_duration
272-
else 0
308+
else round(time.time() - start_time) if self.max_duration else 0
273309
),
274310
benchmark=benchmark,
275311
)
@@ -372,3 +408,81 @@ async def _submit_task_coroutine(
372408
logger.warning("Request {} failed: {}", request, exc)
373409

374410
return TextGenerationError(request=request, message=str(exc))
411+
412+
async def _run_batch(
    self, benchmark: TextGenerationBenchmark, end_time: float, max_number: float
) -> AsyncGenerator[List[Union[TextGenerationResult, TextGenerationError]], None]:
    """
    Collect generated requests into fixed-size batches and process each batch.

    Yields the list of per-request results for every completed batch,
    including a trailing partial batch when the generator is exhausted or a
    limit is reached first.

    :param benchmark: The benchmark that records request lifecycle events.
    :param end_time: Absolute wall-clock time (seconds) after which no new
        batch is started.
    :param max_number: Maximum number of successful results to produce.
    :raises ValueError: If batch_size was not configured on this scheduler.
    """

    if self.batch_size is None:
        raise ValueError("--batch-size CLI parameter is not set")

    batch: List[Tuple[TextGenerationRequest, float]] = []
    count_completed = 0

    for request, submit_at in zip(self.generator, self.load_generator.times()):
        if time.time() >= end_time or count_completed >= max_number:
            break

        # the batch is always below capacity here (it is flushed as soon as
        # it fills), so the request can be appended unconditionally
        batch.append((request, submit_at))

        if len(batch) >= self.batch_size:
            results = await self._process_batch(batch, benchmark, end_time)
            # only successful results count toward the max_number limit
            count_completed += sum(
                1 for res in results if not isinstance(res, TextGenerationError)
            )

            yield results
            batch = []

    # flush the trailing partial batch, if any
    if batch:
        results = await self._process_batch(batch, benchmark, end_time)
        count_completed += sum(
            1 for res in results if not isinstance(res, TextGenerationError)
        )
        yield results
444+
445+
async def _process_batch(
    self,
    batch: List[Tuple[TextGenerationRequest, float]],
    benchmark: TextGenerationBenchmark,
    end_time: float,
) -> List[Union[TextGenerationResult, TextGenerationError]]:
    """
    Submit every request in the batch concurrently and wait for all of them.

    Each request is scheduled at its planned submit time via
    _delayed_submit; failures are converted into TextGenerationError
    entries so the returned list always has one item per request.

    :param batch: Pairs of (request, planned submit time).
    :param benchmark: The benchmark that records request lifecycle events.
    :param end_time: Absolute wall-clock deadline for the whole batch.
    :return: Per-request results in the same order as the batch.
    """
    try:
        benchmark.request_started()

        pending = [
            self._delayed_submit(request, submit_at, end_time)
            for request, submit_at in batch
        ]

        # bound the wait by the remaining run time; no deadline means no timeout
        remaining = end_time - time.time() if end_time < math.inf else None

        outcomes = await asyncio.wait_for(
            asyncio.gather(*pending, return_exceptions=True), timeout=remaining
        )

        processed: List[Union[TextGenerationResult, TextGenerationError]] = []
        for (request, _), outcome in zip(batch, outcomes):
            if isinstance(outcome, Exception):
                # wrap the raw exception so callers see a uniform error type
                outcome = TextGenerationError(request=request, message=str(outcome))
            benchmark.request_completed(outcome)
            processed.append(outcome)
        return processed
    except asyncio.TimeoutError:
        # the whole batch ran out of time: report an error for every request
        return [
            TextGenerationError(request=request, message="Batch timeout")
            for request, _ in batch
        ]
479+
480+
async def _delayed_submit(
    self, request: TextGenerationRequest, submit_at: float, end_time: float
) -> Union[TextGenerationResult, TextGenerationError]:
    """
    Sleep until the request's planned submit time, then submit it.

    :param request: The request to hand to the worker.
    :param submit_at: Absolute wall-clock time at which to submit.
    :param end_time: Absolute deadline; submission past it is aborted.
    :raises asyncio.TimeoutError: If the deadline has passed once the
        planned submit time arrives.
    :return: The worker's result for this request.
    """
    delay = submit_at - time.time()
    if delay > 0:
        await asyncio.sleep(delay)
    if time.time() >= end_time:
        raise asyncio.TimeoutError("Submission time exceeded end_time")

    return await self._worker.submit(request)

0 commit comments

Comments
 (0)