Skip to content

Commit 206fc79

Browse files
committed
Second Scenario implementation attempt
1 parent db7b534 commit 206fc79

File tree

3 files changed

+160
-34
lines changed

3 files changed

+160
-34
lines changed

src/guidellm/__main__.py

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import click
88

99
from guidellm.backend import BackendType
10-
from guidellm.benchmark import ProfileType, benchmark_generative_text
10+
from guidellm.benchmark import ProfileType
11+
from guidellm.benchmark.entrypoints import benchmark_with_scenario
12+
from guidellm.benchmark.scenario import GenerativeTextScenario
1113
from guidellm.config import print_config
1214
from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
1315
from guidellm.scheduler import StrategyType
@@ -40,6 +42,19 @@ def parse_number_str(ctx, param, value): # noqa: ARG001
4042
) from err
4143

4244

45+
def set_if_not_default(ctx: click.Context, **kwargs):
    """
    Filter the given kwargs down to the options the user explicitly set.

    Any option whose value was filled in from the click default (as
    reported by ``ctx.get_parameter_source``) is dropped, so the result
    can be used to override a scenario's settings without clobbering
    them with defaults.
    """
    return {
        key: value
        for key, value in kwargs.items()
        if ctx.get_parameter_source(key) != click.core.ParameterSource.DEFAULT
    }
56+
57+
4358
@click.group()
4459
def cli():
4560
pass
@@ -48,6 +63,14 @@ def cli():
4863
@cli.command(
4964
help="Run a benchmark against a generative model using the specified arguments."
5065
)
66+
@click.option(
67+
"--scenario",
68+
type=str,
69+
default=None,
70+
help=(
71+
"TODO: A scenario or path to config"
72+
),
73+
)
5174
@click.option(
5275
"--target",
5376
required=True,
@@ -61,20 +84,20 @@ def cli():
6184
"The type of backend to use to run requests against. Defaults to 'openai_http'."
6285
f" Supported types: {', '.join(get_args(BackendType))}"
6386
),
64-
default="openai_http",
87+
default=GenerativeTextScenario.backend_type,
6588
)
6689
@click.option(
6790
"--backend-args",
6891
callback=parse_json,
69-
default=None,
92+
default=GenerativeTextScenario.backend_args,
7093
help=(
7194
"A JSON string containing any arguments to pass to the backend as a "
7295
"dict with **kwargs."
7396
),
7497
)
7598
@click.option(
7699
"--model",
77-
default=None,
100+
default=GenerativeTextScenario.model,
78101
type=str,
79102
help=(
80103
"The ID of the model to benchmark within the backend. "
@@ -83,7 +106,7 @@ def cli():
83106
)
84107
@click.option(
85108
"--processor",
86-
default=None,
109+
default=GenerativeTextScenario.processor,
87110
type=str,
88111
help=(
89112
"The processor or tokenizer to use to calculate token counts for statistics "
@@ -93,7 +116,7 @@ def cli():
93116
)
94117
@click.option(
95118
"--processor-args",
96-
default=None,
119+
default=GenerativeTextScenario.processor_args,
97120
callback=parse_json,
98121
help=(
99122
"A JSON string containing any arguments to pass to the processor constructor "
@@ -112,6 +135,7 @@ def cli():
112135
)
113136
@click.option(
114137
"--data-args",
138+
default=GenerativeTextScenario.data_args,
115139
callback=parse_json,
116140
help=(
117141
"A JSON string containing any arguments to pass to the dataset creation "
@@ -120,7 +144,7 @@ def cli():
120144
)
121145
@click.option(
122146
"--data-sampler",
123-
default=None,
147+
default=GenerativeTextScenario.data_sampler,
124148
type=click.Choice(["random"]),
125149
help=(
126150
"The data sampler type to use. 'random' will add a random shuffle on the data. "
@@ -138,7 +162,7 @@ def cli():
138162
)
139163
@click.option(
140164
"--rate",
141-
default=None,
165+
default=GenerativeTextScenario.rate,
142166
callback=parse_number_str,
143167
help=(
144168
"The rates to run the benchmark at. "
@@ -152,6 +176,7 @@ def cli():
152176
@click.option(
153177
"--max-seconds",
154178
type=float,
179+
default=GenerativeTextScenario.max_seconds,
155180
help=(
156181
"The maximum number of seconds each benchmark can run for. "
157182
"If None, will run until max_requests or the data is exhausted."
@@ -160,6 +185,7 @@ def cli():
160185
@click.option(
161186
"--max-requests",
162187
type=int,
188+
default=GenerativeTextScenario.max_requests,
163189
help=(
164190
"The maximum number of requests each benchmark can run for. "
165191
"If None, will run until max_seconds or the data is exhausted."
@@ -168,7 +194,7 @@ def cli():
168194
@click.option(
169195
"--warmup-percent",
170196
type=float,
171-
default=None,
197+
default=GenerativeTextScenario.warmup_percent,
172198
help=(
173199
"The percent of the benchmark (based on max-seconds, max-requets, "
174200
"or lenth of dataset) to run as a warmup and not include in the final results. "
@@ -178,6 +204,7 @@ def cli():
178204
@click.option(
179205
"--cooldown-percent",
180206
type=float,
207+
default=GenerativeTextScenario.cooldown_percent,
181208
help=(
182209
"The percent of the benchmark (based on max-seconds, max-requets, or lenth "
183210
"of dataset) to run as a cooldown and not include in the final results. "
@@ -187,16 +214,19 @@ def cli():
187214
@click.option(
188215
"--disable-progress",
189216
is_flag=True,
217+
default=not GenerativeTextScenario.show_progress,
190218
help="Set this flag to disable progress updates to the console",
191219
)
192220
@click.option(
193221
"--display-scheduler-stats",
194222
is_flag=True,
223+
default=GenerativeTextScenario.show_progress_scheduler_stats,
195224
help="Set this flag to display stats for the processes running the benchmarks",
196225
)
197226
@click.option(
198227
"--disable-console-outputs",
199228
is_flag=True,
229+
default=not GenerativeTextScenario.output_console,
200230
help="Set this flag to disable console output",
201231
)
202232
@click.option(
@@ -213,6 +243,7 @@ def cli():
213243
@click.option(
214244
"--output-extras",
215245
callback=parse_json,
246+
default=GenerativeTextScenario.output_extras,
216247
help="A JSON string of extra data to save with the output benchmarks",
217248
)
218249
@click.option(
@@ -222,15 +253,16 @@ def cli():
222253
"The number of samples to save in the output file. "
223254
"If None (default), will save all samples."
224255
),
225-
default=None,
256+
default=GenerativeTextScenario.output_sampling,
226257
)
227258
@click.option(
228259
"--random-seed",
229-
default=42,
260+
default=GenerativeTextScenario.random_seed,
230261
type=int,
231262
help="The random seed to use for benchmarking to ensure reproducibility.",
232263
)
233264
def benchmark(
265+
scenario,
234266
target,
235267
backend_type,
236268
backend_args,
@@ -254,30 +286,48 @@ def benchmark(
254286
output_sampling,
255287
random_seed,
256288
):
289+
click_ctx = click.get_current_context()
290+
291+
# If a scenario file was specified read from it
292+
# TODO: This should probably be a factory method
293+
if scenario is None:
294+
_scenario = {}
295+
else:
296+
# TODO: Support pre-defined scenarios
297+
# TODO: Support other formats
298+
with Path(scenario).open() as f:
299+
_scenario = json.load(f)
300+
301+
# If any command line arguments are specified, override the scenario
302+
_scenario.update(set_if_not_default(
303+
click_ctx,
304+
target=target,
305+
backend_type=backend_type,
306+
backend_args=backend_args,
307+
model=model,
308+
processor=processor,
309+
processor_args=processor_args,
310+
data=data,
311+
data_args=data_args,
312+
data_sampler=data_sampler,
313+
rate_type=rate_type,
314+
rate=rate,
315+
max_seconds=max_seconds,
316+
max_requests=max_requests,
317+
warmup_percent=warmup_percent,
318+
cooldown_percent=cooldown_percent,
319+
show_progress=not disable_progress,
320+
show_progress_scheduler_stats=display_scheduler_stats,
321+
output_console=not disable_console_outputs,
322+
output_path=output_path,
323+
output_extras=output_extras,
324+
output_sampling=output_sampling,
325+
random_seed=random_seed,
326+
))
327+
257328
asyncio.run(
258-
benchmark_generative_text(
259-
target=target,
260-
backend_type=backend_type,
261-
backend_args=backend_args,
262-
model=model,
263-
processor=processor,
264-
processor_args=processor_args,
265-
data=data,
266-
data_args=data_args,
267-
data_sampler=data_sampler,
268-
rate_type=rate_type,
269-
rate=rate,
270-
max_seconds=max_seconds,
271-
max_requests=max_requests,
272-
warmup_percent=warmup_percent,
273-
cooldown_percent=cooldown_percent,
274-
show_progress=not disable_progress,
275-
show_progress_scheduler_stats=display_scheduler_stats,
276-
output_console=not disable_console_outputs,
277-
output_path=output_path,
278-
output_extras=output_extras,
279-
output_sampling=output_sampling,
280-
random_seed=random_seed,
329+
benchmark_with_scenario(
330+
scenario=GenerativeTextScenario(**_scenario)
281331
)
282332
)
283333

src/guidellm/benchmark/entrypoints.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,23 @@
1515
)
1616
from guidellm.benchmark.profile import ProfileType, create_profile
1717
from guidellm.benchmark.progress import GenerativeTextBenchmarkerProgressDisplay
18+
from guidellm.benchmark.scenario import GenerativeTextScenario, Scenario
1819
from guidellm.request import GenerativeRequestLoader
1920
from guidellm.scheduler import StrategyType
2021

22+
type benchmark_type = Literal["generative_text"]
23+
24+
25+
async def benchmark_with_scenario(scenario: Scenario, **kwargs):
    """
    Dispatch a benchmark run for the given scenario, forwarding any
    extra keyword arguments to the concrete benchmark entrypoint.
    """
    # Dispatch on the concrete scenario type; only generative-text
    # scenarios are supported so far.
    if not isinstance(scenario, GenerativeTextScenario):
        raise ValueError(f"Unsupported Scenario type {type(scenario)}")
    return await benchmark_generative_text(**vars(scenario), **kwargs)
34+
2135

2236
async def benchmark_generative_text(
2337
target: str,

src/guidellm/benchmark/scenario.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from collections.abc import Iterable
2+
from pathlib import Path
3+
from typing import Any, Literal, Optional, Self, Union
4+
5+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
6+
from transformers.tokenization_utils_base import ( # type: ignore[import]
7+
PreTrainedTokenizerBase,
8+
)
9+
10+
from guidellm.backend.backend import BackendType
11+
from guidellm.benchmark.profile import ProfileType
12+
from guidellm.objects.pydantic import StandardBaseModel
13+
from guidellm.scheduler.strategy import StrategyType
14+
15+
# The module export list must be spelled ``__all__`` (lowercase) for
# ``from ... import *`` and linters to honor it; ``__ALL__`` is ignored.
__all__ = ["Scenario", "GenerativeTextScenario"]
16+
17+
18+
class Scenario(StandardBaseModel):
    """
    Base model for benchmark scenarios.

    Concrete scenarios subclass this and add their own configuration
    fields; only the benchmark target is required here.
    """

    target: str

    def _update(self, **fields: Any) -> Self:
        """Assign each given field on self; unknown names raise ValueError."""
        for k, v in fields.items():
            if hasattr(self, k):
                setattr(self, k, v)
            else:
                raise ValueError(f"Invalid field {k}")
        return self

    def update(self, **fields: Any) -> Self:
        """Assign the given fields on self, skipping any whose value is None."""
        non_null = {k: v for k, v in fields.items() if v is not None}
        return self._update(**non_null)
31+
32+
33+
class GenerativeTextScenario(Scenario):
    """
    Scenario describing a generative-text benchmark run.

    The field defaults here double as the defaults for the ``benchmark``
    CLI command, which reads them via class-attribute access.

    NOTE(review): several annotations use non-pydantic-native types
    (PreTrainedTokenizerBase, Dataset, IterableDataset, ...) — confirm
    that StandardBaseModel enables ``arbitrary_types_allowed``, otherwise
    this model will fail to build.
    """

    # --- backend selection ---
    backend_type: BackendType = "openai_http"
    backend_args: Optional[dict[str, Any]] = None
    # Model ID within the backend; backend may infer it when None.
    model: Optional[str] = None
    # Tokenizer/processor used for token-count statistics.
    processor: Optional[Union[str, Path, PreTrainedTokenizerBase]] = None
    processor_args: Optional[dict[str, Any]] = None
    # --- dataset configuration ---
    # Required: the data source — a path/name, raw samples, or a loaded
    # HuggingFace dataset object.
    data: Union[
        str,
        Path,
        Iterable[Union[str, dict[str, Any]]],
        Dataset,
        DatasetDict,
        IterableDataset,
        IterableDatasetDict,
    ]
    data_args: Optional[dict[str, Any]] = None
    data_sampler: Optional[Literal["random"]] = None
    # --- load profile ---
    # Required: scheduling strategy or profile type for request rate.
    rate_type: Union[StrategyType, ProfileType]
    rate: Optional[Union[int, float, list[Union[int, float]]]] = None
    # Stopping conditions; None means run until the other limit or the
    # data is exhausted.
    max_seconds: Optional[float] = None
    max_requests: Optional[int] = None
    # Warmup/cooldown fractions excluded from final results.
    warmup_percent: Optional[float] = None
    cooldown_percent: Optional[float] = None
    # --- console/progress output ---
    show_progress: bool = True
    # NOTE(review): the CLI exposes this as the store-true flag
    # --display-scheduler-stats; defaulting it to True makes that flag a
    # no-op — confirm this default is intended (the pre-scenario CLI
    # behavior defaulted to False).
    show_progress_scheduler_stats: bool = True
    output_console: bool = True
    # --- result output ---
    output_path: Optional[Union[str, Path]] = None
    output_extras: Optional[dict[str, Any]] = None
    # Number of samples to keep in the output; None keeps all.
    output_sampling: Optional[int] = None
    # Seed for reproducible benchmarking.
    random_seed: int = 42

0 commit comments

Comments
 (0)