Skip to content

Commit c896507

Browse files
committed
Handle required arg parsing with pydantic
1 parent 603a1f8 commit c896507

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

src/guidellm/__main__.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import get_args
66

77
import click
8+
from pydantic import ValidationError
89

910
from guidellm.backend import BackendType
1011
from guidellm.benchmark import ProfileType
@@ -67,13 +68,10 @@ def cli():
6768
"--scenario",
6869
type=str,
6970
default=None,
70-
help=(
71-
"TODO: A scenario or path to config"
72-
),
71+
help=("TODO: A scenario or path to config"),
7372
)
7473
@click.option(
7574
"--target",
76-
required=True,
7775
type=str,
7876
help="The target path for the backend to run benchmarks against. For example, http://localhost:8000",
7977
)
@@ -125,7 +123,6 @@ def cli():
125123
)
126124
@click.option(
127125
"--data",
128-
required=True,
129126
type=str,
130127
help=(
131128
"The HuggingFace dataset ID, a path to a HuggingFace dataset, "
@@ -153,7 +150,6 @@ def cli():
153150
)
154151
@click.option(
155152
"--rate-type",
156-
required=True,
157153
type=click.Choice(STRATEGY_PROFILE_CHOICES),
158154
help=(
159155
"The type of benchmark to run. "
@@ -305,12 +301,19 @@ def benchmark(
305301
random_seed=random_seed,
306302
)
307303

308-
# If a scenario file was specified read from it
309-
if scenario is None:
310-
_scenario = GenerativeTextScenario.model_validate(overrides)
311-
else:
312-
# TODO: Support pre-defined scenarios
313-
_scenario = GenerativeTextScenario.from_file(scenario, overrides)
304+
try:
305+
# If a scenario file was specified read from it
306+
if scenario is None:
307+
_scenario = GenerativeTextScenario.model_validate(overrides)
308+
else:
309+
# TODO: Support pre-defined scenarios
310+
_scenario = GenerativeTextScenario.from_file(scenario, overrides)
311+
except ValidationError as e:
312+
errs = e.errors(include_url=False, include_context=True, include_input=True)
313+
param_name = "--" + str(errs[0]["loc"][0]).replace("_", "-")
314+
raise click.BadParameter(
315+
errs[0]["msg"], ctx=click_ctx, param_hint=param_name
316+
) from e
314317

315318
asyncio.run(
316319
benchmark_with_scenario(

0 commit comments

Comments
 (0)