|
1 | 1 | import asyncio |
2 | | -from typing import Literal, Optional, get_args |
| 2 | +from typing import Literal, Optional, Union, get_args |
3 | 3 |
|
4 | 4 | import click |
5 | 5 | from loguru import logger |
|
13 | 13 | TransformersDatasetRequestGenerator, |
14 | 14 | ) |
15 | 15 | from guidellm.request.base import RequestGenerator |
16 | | -from guidellm.utils import BenchmarkReportProgress |
| 16 | +from guidellm.utils import BenchmarkReportProgress, cli_params |
17 | 17 |
|
18 | 18 | __all__ = ["generate_benchmark_report"] |
19 | 19 |
|
|
120 | 120 | ) |
121 | 121 | @click.option( |
122 | 122 | "--max-requests", |
123 | | - type=int, |
| 123 | + type=cli_params.MAX_REQUESTS, |
124 | 124 | default=None, |
125 | 125 | help=( |
126 | 126 | "The maximum number of requests for each benchmark run. " |
@@ -161,7 +161,7 @@ def generate_benchmark_report_cli( |
161 | 161 | rate_type: ProfileGenerationMode, |
162 | 162 | rate: Optional[float], |
163 | 163 | max_seconds: Optional[int], |
164 | | - max_requests: Optional[int], |
| 164 | + max_requests: Union[Literal["dataset"], int, None], |
165 | 165 | output_path: str, |
166 | 166 | enable_continuous_refresh: bool, |
167 | 167 | ): |
@@ -194,7 +194,7 @@ def generate_benchmark_report( |
194 | 194 | rate_type: ProfileGenerationMode, |
195 | 195 | rate: Optional[float], |
196 | 196 | max_seconds: Optional[int], |
197 | | - max_requests: Optional[int], |
| 197 | + max_requests: Union[Literal["dataset"], int, None], |
198 | 198 | output_path: str, |
199 | 199 | cont_refresh_table: bool, |
200 | 200 | ) -> GuidanceReport: |
@@ -256,13 +256,18 @@ def generate_benchmark_report( |
256 | 256 | else: |
257 | 257 | raise ValueError(f"Unknown data type: {data_type}") |
258 | 258 |
|
| 259 | + if data_type == "emulated" and max_requests == "dataset": |
| 260 | + raise ValueError("Cannot use 'dataset' for emulated data") |
| 261 | + |
259 | 262 | # Create executor |
260 | 263 | executor = Executor( |
261 | 264 | backend=backend_inst, |
262 | 265 | request_generator=request_generator, |
263 | 266 | mode=rate_type, |
264 | 267 | rate=rate if rate_type in ("constant", "poisson") else None, |
265 | | - max_number=max_requests, |
| 268 | + max_number=( |
| 269 | + len(request_generator) if max_requests == "dataset" else max_requests |
| 270 | + ), |
266 | 271 | max_duration=max_seconds, |
267 | 272 | ) |
268 | 273 |
|
|
0 commit comments