|
5 | 5 | from typing import get_args |
6 | 6 |
|
7 | 7 | import click |
| 8 | +from pydantic import ValidationError |
8 | 9 |
|
9 | 10 | from guidellm.backend import BackendType |
10 | 11 | from guidellm.benchmark import ProfileType |
@@ -67,13 +68,10 @@ def cli(): |
67 | 68 | "--scenario", |
68 | 69 | type=str, |
69 | 70 | default=None, |
70 | | - help=( |
71 | | - "TODO: A scenario or path to config" |
72 | | - ), |
| 71 | + help=("TODO: A scenario or path to config"), |
73 | 72 | ) |
74 | 73 | @click.option( |
75 | 74 | "--target", |
76 | | - required=True, |
77 | 75 | type=str, |
78 | 76 | help="The target path for the backend to run benchmarks against. For example, http://localhost:8000", |
79 | 77 | ) |
@@ -125,7 +123,6 @@ def cli(): |
125 | 123 | ) |
126 | 124 | @click.option( |
127 | 125 | "--data", |
128 | | - required=True, |
129 | 126 | type=str, |
130 | 127 | help=( |
131 | 128 | "The HuggingFace dataset ID, a path to a HuggingFace dataset, " |
@@ -153,7 +150,6 @@ def cli(): |
153 | 150 | ) |
154 | 151 | @click.option( |
155 | 152 | "--rate-type", |
156 | | - required=True, |
157 | 153 | type=click.Choice(STRATEGY_PROFILE_CHOICES), |
158 | 154 | help=( |
159 | 155 | "The type of benchmark to run. " |
@@ -305,12 +301,19 @@ def benchmark( |
305 | 301 | random_seed=random_seed, |
306 | 302 | ) |
307 | 303 |
|
308 | | - # If a scenario file was specified read from it |
309 | | - if scenario is None: |
310 | | - _scenario = GenerativeTextScenario.model_validate(overrides) |
311 | | - else: |
312 | | - # TODO: Support pre-defined scenarios |
313 | | - _scenario = GenerativeTextScenario.from_file(scenario, overrides) |
| 304 | + try: |
| 305 | + # If a scenario file was specified read from it |
| 306 | + if scenario is None: |
| 307 | + _scenario = GenerativeTextScenario.model_validate(overrides) |
| 308 | + else: |
| 309 | + # TODO: Support pre-defined scenarios |
| 310 | + _scenario = GenerativeTextScenario.from_file(scenario, overrides) |
| 311 | + except ValidationError as e: |
| 312 | + errs = e.errors(include_url=False, include_context=True, include_input=True) |
| 313 | + param_name = "--" + str(errs[0]["loc"][0]).replace("_", "-") |
| 314 | + raise click.BadParameter( |
| 315 | + errs[0]["msg"], ctx=click_ctx, param_hint=param_name |
| 316 | + ) from e |
314 | 317 |
|
315 | 318 | asyncio.run( |
316 | 319 | benchmark_with_scenario( |
|
0 commit comments