Skip to content

Commit dd219f1

Browse files
authored
[Bugfix] Switch --rates CLI arg to handle a comma separated list of values (#433)
## Summary Since the refactor, --rates will not allow a comma separated list of values like --rates=1,2,3,4 ## Details This PR re-enables this feature, where a user can specify a static list of rates to sweep through. ## Test Plan - Testing locally ## Related Issues - None --- - [X] "I certify that all code in this PR is my own, except as noted below." ## Use of AI - [ ] Includes AI-assisted code completion - [X] Includes code generated by an AI application - [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`)
2 parents b05b0e7 + 9d98291 commit dd219f1

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

src/guidellm/__main__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ def benchmark():
156156
)
157157
@click.option(
158158
"--rate",
159-
type=float,
160-
multiple=True,
159+
type=str,
160+
callback=cli_tools.parse_list_floats,
161+
multiple=False,
161162
default=BenchmarkGenerativeTextArgs.get_default("rate"),
162163
help=(
163164
"Benchmark rate(s) to test. Meaning depends on profile: "
@@ -383,7 +384,7 @@ def run(**kwargs):
383384
kwargs.get("data_args"), default=[], simplify_single=False
384385
)
385386
kwargs["rate"] = cli_tools.format_list_arg(
386-
kwargs.get("rate"), default=None, simplify_single=True
387+
kwargs.get("rate"), default=None, simplify_single=False
387388
)
388389

389390
disable_console_outputs = kwargs.pop("disable_console_outputs", False)

src/guidellm/utils/cli.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,34 @@
33

44
import click
55

6-
__all__ = ["Union", "format_list_arg", "parse_json", "set_if_not_default"]
6+
__all__ = [
7+
"Union",
8+
"format_list_arg",
9+
"parse_json",
10+
"parse_list_floats",
11+
"set_if_not_default",
12+
]
713

814

15+
def parse_list_floats(ctx, param, value): # noqa: ARG001
16+
"""
17+
Callback to parse a comma-separated string into a list of floats.
18+
"""
19+
# This callback only runs if the --rate option is provided by the user.
20+
# If it's not, 'value' will be None, and Click will use the 'default'.
21+
if value is None:
22+
return None # Keep the default
23+
24+
try:
25+
# Split by comma, strip any whitespace, and convert to float
26+
return [float(item.strip()) for item in value.split(",")]
27+
except ValueError as e:
28+
# Raise a Click error if any part isn't a valid float
29+
raise click.BadParameter(
30+
f"Value '{value}' is not a valid comma-separated list "
31+
f"of floats/ints. Error: {e}"
32+
) from e
33+
934
def parse_json(ctx, param, value): # noqa: ARG001
1035
if value is None or value == [None]:
1136
return None

0 commit comments

Comments
 (0)