Skip to content

Commit cf31319

Browse files
committed
Move cli helpers to separate file and add click Union type back
1 parent bc937b5 commit cf31319

File tree

2 files changed

+76
-32
lines changed

2 files changed

+76
-32
lines changed

src/guidellm/__main__.py

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import asyncio
22
import codecs
3-
import json
43
from pathlib import Path
5-
from typing import Any, get_args
4+
from typing import get_args
65

76
import click
87
from pydantic import ValidationError
@@ -14,34 +13,13 @@
1413
from guidellm.config import print_config
1514
from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
1615
from guidellm.scheduler import StrategyType
16+
from guidellm.utils import cli as cli_tools
1717

1818
STRATEGY_PROFILE_CHOICES = set(
1919
list(get_args(ProfileType)) + list(get_args(StrategyType))
2020
)
2121

2222

23-
def parse_json(ctx, param, value): # noqa: ARG001
24-
if value is None:
25-
return None
26-
try:
27-
return json.loads(value)
28-
except json.JSONDecodeError as err:
29-
raise click.BadParameter(f"{param.name} must be a valid JSON string.") from err
30-
31-
32-
def set_if_not_default(ctx: click.Context, **kwargs) -> dict[str, Any]:
33-
"""
34-
Set the value of a click option if it is not the default value.
35-
This is useful for setting options that are not None by default.
36-
"""
37-
values = {}
38-
for k, v in kwargs.items():
39-
if ctx.get_parameter_source(k) != click.core.ParameterSource.DEFAULT:
40-
values[k] = v
41-
42-
return values
43-
44-
4523
@click.group()
4624
def cli():
4725
pass
@@ -52,7 +30,10 @@ def cli():
5230
)
5331
@click.option(
5432
"--scenario",
55-
type=str,
33+
type=cli_tools.Union(
34+
click.Path(exists=True, readable=True, file_okay=True, dir_okay=False),
35+
click.STRING
36+
),
5637
default=None,
5738
help=("TODO: A scenario or path to config"),
5839
)
@@ -72,7 +53,7 @@ def cli():
7253
)
7354
@click.option(
7455
"--backend-args",
75-
callback=parse_json,
56+
callback=cli_tools.parse_json,
7657
default=GenerativeTextScenario.get_default("backend_args"),
7758
help=(
7859
"A JSON string containing any arguments to pass to the backend as a "
@@ -101,7 +82,7 @@ def cli():
10182
@click.option(
10283
"--processor-args",
10384
default=GenerativeTextScenario.get_default("processor_args"),
104-
callback=parse_json,
85+
callback=cli_tools.parse_json,
10586
help=(
10687
"A JSON string containing any arguments to pass to the processor constructor "
10788
"as a dict with **kwargs."
@@ -119,7 +100,7 @@ def cli():
119100
@click.option(
120101
"--data-args",
121102
default=GenerativeTextScenario.get_default("data_args"),
122-
callback=parse_json,
103+
callback=cli_tools.parse_json,
123104
help=(
124105
"A JSON string containing any arguments to pass to the dataset creation "
125106
"as a dict with **kwargs."
@@ -220,7 +201,7 @@ def cli():
220201
)
221202
@click.option(
222203
"--output-extras",
223-
callback=parse_json,
204+
callback=cli_tools.parse_json,
224205
help="A JSON string of extra data to save with the output benchmarks",
225206
)
226207
@click.option(
@@ -265,7 +246,7 @@ def benchmark(
265246
):
266247
click_ctx = click.get_current_context()
267248

268-
overrides = set_if_not_default(
249+
overrides = cli_tools.set_if_not_default(
269250
click_ctx,
270251
target=target,
271252
backend_type=backend_type,
@@ -370,15 +351,15 @@ def preprocess():
370351
@click.option(
371352
"--processor-args",
372353
default=None,
373-
callback=parse_json,
354+
callback=cli_tools.parse_json,
374355
help=(
375356
"A JSON string containing any arguments to pass to the processor constructor "
376357
"as a dict with **kwargs."
377358
),
378359
)
379360
@click.option(
380361
"--data-args",
381-
callback=parse_json,
362+
callback=cli_tools.parse_json,
382363
help=(
383364
"A JSON string containing any arguments to pass to the dataset creation "
384365
"as a dict with **kwargs."

src/guidellm/utils/cli.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import json
2+
from typing import Any
3+
4+
import click
5+
6+
7+
def parse_json(ctx, param, value): # noqa: ARG001
8+
if value is None:
9+
return None
10+
try:
11+
return json.loads(value)
12+
except json.JSONDecodeError as err:
13+
raise click.BadParameter(f"{param.name} must be a valid JSON string.") from err
14+
15+
16+
def set_if_not_default(ctx: click.Context, **kwargs) -> dict[str, Any]:
17+
"""
18+
Set the value of a click option if it is not the default value.
19+
This is useful for setting options that are not None by default.
20+
"""
21+
values = {}
22+
for k, v in kwargs.items():
23+
if ctx.get_parameter_source(k) != click.core.ParameterSource.DEFAULT:
24+
values[k] = v
25+
26+
return values
27+
28+
29+
class Union(click.ParamType):
30+
"""
31+
A custom click parameter type that allows for multiple types to be accepted.
32+
"""
33+
34+
def __init__(self, *types: click.ParamType):
35+
self.types = types
36+
self.name = "".join(t.name for t in types)
37+
38+
def convert(self, value, param, ctx):
39+
fails = []
40+
for t in self.types:
41+
try:
42+
return t.convert(value, param, ctx)
43+
except click.BadParameter as e:
44+
fails.append(str(e))
45+
continue
46+
47+
self.fail("; ".join(fails) or f"Invalid value: {value}") # noqa: RET503
48+
49+
50+
def get_metavar(self, param: click.Parameter) -> str:
51+
def get_choices(t: click.ParamType) -> str:
52+
meta = t.get_metavar(param)
53+
return meta if meta is not None else t.name
54+
55+
# Get the choices for each type in the union.
56+
choices_str = "|".join(map(get_choices, self.types))
57+
58+
# Use curly braces to indicate a required argument.
59+
if param.required and param.param_type_name == "argument":
60+
return f"{{{choices_str}}}"
61+
62+
# Use square braces to indicate an option or optional argument.
63+
return f"[{choices_str}]"

0 commit comments

Comments
 (0)