|
1 | 1 | import asyncio
|
2 |
| -from typing import Literal, Optional, get_args |
| 2 | +from typing import Literal, Optional, Union, get_args |
3 | 3 |
|
4 | 4 | import click
|
5 | 5 | from loguru import logger
|
|
13 | 13 | TransformersDatasetRequestGenerator,
|
14 | 14 | )
|
15 | 15 | from guidellm.request.base import RequestGenerator
|
16 |
| -from guidellm.utils import BenchmarkReportProgress |
| 16 | +from guidellm.utils import BenchmarkReportProgress, cli_params |
17 | 17 |
|
18 | 18 | __all__ = ["generate_benchmark_report"]
|
19 | 19 |
|
|
120 | 120 | )
|
121 | 121 | @click.option(
|
122 | 122 | "--max-requests",
|
123 |
| - type=int, |
| 123 | + type=cli_params.MAX_REQUESTS, |
124 | 124 | default=None,
|
125 | 125 | help=(
|
126 | 126 | "The maximum number of requests for each benchmark run. "
|
@@ -161,7 +161,7 @@ def generate_benchmark_report_cli(
|
161 | 161 | rate_type: ProfileGenerationMode,
|
162 | 162 | rate: Optional[float],
|
163 | 163 | max_seconds: Optional[int],
|
164 |
| - max_requests: Optional[int], |
| 164 | + max_requests: Union[Literal["dataset"], int, None], |
165 | 165 | output_path: str,
|
166 | 166 | enable_continuous_refresh: bool,
|
167 | 167 | ):
|
@@ -194,7 +194,7 @@ def generate_benchmark_report(
|
194 | 194 | rate_type: ProfileGenerationMode,
|
195 | 195 | rate: Optional[float],
|
196 | 196 | max_seconds: Optional[int],
|
197 |
| - max_requests: Optional[int], |
| 197 | + max_requests: Union[Literal["dataset"], int, None], |
198 | 198 | output_path: str,
|
199 | 199 | cont_refresh_table: bool,
|
200 | 200 | ) -> GuidanceReport:
|
@@ -256,13 +256,18 @@ def generate_benchmark_report(
|
256 | 256 | else:
|
257 | 257 | raise ValueError(f"Unknown data type: {data_type}")
|
258 | 258 |
|
| 259 | + if data_type == "emulated" and max_requests == "dataset": |
| 260 | + raise ValueError("Cannot use 'dataset' for emulated data") |
| 261 | + |
259 | 262 | # Create executor
|
260 | 263 | executor = Executor(
|
261 | 264 | backend=backend_inst,
|
262 | 265 | request_generator=request_generator,
|
263 | 266 | mode=rate_type,
|
264 | 267 | rate=rate if rate_type in ("constant", "poisson") else None,
|
265 |
| - max_number=max_requests, |
| 268 | + max_number=( |
| 269 | + len(request_generator) if max_requests == "dataset" else max_requests |
| 270 | + ), |
266 | 271 | max_duration=max_seconds,
|
267 | 272 | )
|
268 | 273 |
|
|
0 commit comments