Skip to content

Commit c8db1be

Browse files
committed
Add a helper method to get scenario defaults
1 parent 511e9f2 commit c8db1be

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

src/guidellm/__main__.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import codecs
33
import json
44
from pathlib import Path
5-
from typing import get_args
5+
from typing import Any, get_args
66

77
import click
88
from pydantic import ValidationError
@@ -29,7 +29,7 @@ def parse_json(ctx, param, value): # noqa: ARG001
2929
raise click.BadParameter(f"{param.name} must be a valid JSON string.") from err
3030

3131

32-
def set_if_not_default(ctx: click.Context, **kwargs):
32+
def set_if_not_default(ctx: click.Context, **kwargs) -> dict[str, Any]:
3333
"""
3434
Set the value of a click option if it is not the default value.
3535
This is useful for setting options that are not None by default.
@@ -68,20 +68,20 @@ def cli():
6868
"The type of backend to use to run requests against. Defaults to 'openai_http'."
6969
f" Supported types: {', '.join(get_args(BackendType))}"
7070
),
71-
default=GenerativeTextScenario.model_fields["backend_type"].default,
71+
default=GenerativeTextScenario.get_default("backend_type"),
7272
)
7373
@click.option(
7474
"--backend-args",
7575
callback=parse_json,
76-
default=GenerativeTextScenario.model_fields["backend_args"].default,
76+
default=GenerativeTextScenario.get_default("backend_args"),
7777
help=(
7878
"A JSON string containing any arguments to pass to the backend as a "
7979
"dict with **kwargs."
8080
),
8181
)
8282
@click.option(
8383
"--model",
84-
default=GenerativeTextScenario.model_fields["model"].default,
84+
default=GenerativeTextScenario.get_default("model"),
8585
type=str,
8686
help=(
8787
"The ID of the model to benchmark within the backend. "
@@ -90,7 +90,7 @@ def cli():
9090
)
9191
@click.option(
9292
"--processor",
93-
default=GenerativeTextScenario.model_fields["processor"].default,
93+
default=GenerativeTextScenario.get_default("processor"),
9494
type=str,
9595
help=(
9696
"The processor or tokenizer to use to calculate token counts for statistics "
@@ -100,7 +100,7 @@ def cli():
100100
)
101101
@click.option(
102102
"--processor-args",
103-
default=GenerativeTextScenario.model_fields["processor_args"].default,
103+
default=GenerativeTextScenario.get_default("processor_args"),
104104
callback=parse_json,
105105
help=(
106106
"A JSON string containing any arguments to pass to the processor constructor "
@@ -118,7 +118,7 @@ def cli():
118118
)
119119
@click.option(
120120
"--data-args",
121-
default=GenerativeTextScenario.model_fields["data_args"].default,
121+
default=GenerativeTextScenario.get_default("data_args"),
122122
callback=parse_json,
123123
help=(
124124
"A JSON string containing any arguments to pass to the dataset creation "
@@ -127,7 +127,7 @@ def cli():
127127
)
128128
@click.option(
129129
"--data-sampler",
130-
default=GenerativeTextScenario.model_fields["data_sampler"].default,
130+
default=GenerativeTextScenario.get_default("data_sampler"),
131131
type=click.Choice(["random"]),
132132
help=(
133133
"The data sampler type to use. 'random' will add a random shuffle on the data. "
@@ -144,7 +144,7 @@ def cli():
144144
)
145145
@click.option(
146146
"--rate",
147-
default=GenerativeTextScenario.model_fields["rate"].default,
147+
default=GenerativeTextScenario.get_default("rate"),
148148
help=(
149149
"The rates to run the benchmark at. "
150150
"Can be a single number or a comma-separated list of numbers. "
@@ -157,7 +157,7 @@ def cli():
157157
@click.option(
158158
"--max-seconds",
159159
type=float,
160-
default=GenerativeTextScenario.model_fields["max_seconds"].default,
160+
default=GenerativeTextScenario.get_default("max_seconds"),
161161
help=(
162162
"The maximum number of seconds each benchmark can run for. "
163163
"If None, will run until max_requests or the data is exhausted."
@@ -166,7 +166,7 @@ def cli():
166166
@click.option(
167167
"--max-requests",
168168
type=int,
169-
default=GenerativeTextScenario.model_fields["max_requests"].default,
169+
default=GenerativeTextScenario.get_default("max_requests"),
170170
help=(
171171
"The maximum number of requests each benchmark can run for. "
172172
"If None, will run until max_seconds or the data is exhausted."
@@ -175,7 +175,7 @@ def cli():
175175
@click.option(
176176
"--warmup-percent",
177177
type=float,
178-
default=GenerativeTextScenario.model_fields["warmup_percent"].default,
178+
default=GenerativeTextScenario.get_default("warmup_percent"),
179179
help=(
180180
"The percent of the benchmark (based on max-seconds, max-requets, "
181181
"or lenth of dataset) to run as a warmup and not include in the final results. "
@@ -185,7 +185,7 @@ def cli():
185185
@click.option(
186186
"--cooldown-percent",
187187
type=float,
188-
default=GenerativeTextScenario.model_fields["cooldown_percent"].default,
188+
default=GenerativeTextScenario.get_default("cooldown_percent"),
189189
help=(
190190
"The percent of the benchmark (based on max-seconds, max-requets, or lenth "
191191
"of dataset) to run as a cooldown and not include in the final results. "
@@ -230,11 +230,11 @@ def cli():
230230
"The number of samples to save in the output file. "
231231
"If None (default), will save all samples."
232232
),
233-
default=GenerativeTextScenario.model_fields["output_sampling"].default,
233+
default=GenerativeTextScenario.get_default("output_sampling"),
234234
)
235235
@click.option(
236236
"--random-seed",
237-
default=GenerativeTextScenario.model_fields["random_seed"].default,
237+
default=GenerativeTextScenario.get_default("random_seed"),
238238
type=int,
239239
help="The random seed to use for benchmarking to ensure reproducibility.",
240240
)

src/guidellm/benchmark/entrypoints.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from guidellm.request import GenerativeRequestLoader
2020
from guidellm.scheduler import StrategyType
2121

22-
type benchmark_type = Literal["generative_text"]
23-
2422

2523
async def benchmark_with_scenario(scenario: Scenario, **kwargs):
2624
"""

src/guidellm/benchmark/scenario.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ def parse_float_list(value: Union[str, float, list[float]]) -> list[float]:
4141
class Scenario(StandardBaseModel):
4242
target: str
4343

44+
@classmethod
45+
def get_default(cls: type[T], field: str) -> Any:
46+
"""Get default values for model fields"""
47+
return cls.model_fields[field].default
48+
4449
@classmethod
4550
def from_file(
4651
cls: type[T], filename: Union[str, Path], overrides: Optional[dict] = None

0 commit comments

Comments
 (0)