 import click
 
 from guidellm.backend import BackendType
-from guidellm.benchmark import ProfileType, benchmark_generative_text
+from guidellm.benchmark import ProfileType
+from guidellm.benchmark.entrypoints import benchmark_with_scenario
+from guidellm.benchmark.scenario import GenerativeTextScenario
 from guidellm.config import print_config
 from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
 from guidellm.scheduler import StrategyType
@@ -40,6 +42,19 @@ def parse_number_str(ctx, param, value): # noqa: ARG001
         ) from err
 
 
+def set_if_not_default(ctx: click.Context, **kwargs):
+    """
+    Set the value of a click option if it is not the default value.
+    This is useful for setting options that are not None by default.
+    """
+    values = {}
+    for k, v in kwargs.items():
+        if ctx.get_parameter_source(k) != click.core.ParameterSource.DEFAULT:
+            values[k] = v
+
+    return values
+
+
 @click.group()
 def cli():
     pass
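
Note on the helper added above: set_if_not_default() leans on click 8's parameter-source tracking, where ctx.get_parameter_source(name) reports ParameterSource.DEFAULT for any option the user never passed, so only explicitly supplied flags end up in the returned override dict. A minimal, self-contained sketch of that behavior outside this commit (the option names here are illustrative, not from guidellm):

import click

@click.command()
@click.option("--rate", type=float, default=None)
@click.option("--max-seconds", type=float, default=None)
def demo(rate, max_seconds):
    # Keep only the options the user explicitly provided on the command line;
    # anything left at its default reports ParameterSource.DEFAULT and is skipped.
    ctx = click.get_current_context()
    overrides = {
        name: value
        for name, value in {"rate": rate, "max_seconds": max_seconds}.items()
        if ctx.get_parameter_source(name) != click.core.ParameterSource.DEFAULT
    }
    click.echo(overrides)  # `demo --rate 5` prints {'rate': 5.0}; a bare `demo` prints {}

if __name__ == "__main__":
    demo()
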
@@ -48,6 +63,14 @@ def cli():
 @cli.command(
     help="Run a benchmark against a generative model using the specified arguments."
 )
+@click.option(
+    "--scenario",
+    type=str,
+    default=None,
+    help=(
+        "TODO: A scenario or path to config"
+    ),
+)
 @click.option(
     "--target",
     required=True,
@@ -61,20 +84,20 @@ def cli():
         "The type of backend to use to run requests against. Defaults to 'openai_http'."
         f" Supported types: {', '.join(get_args(BackendType))}"
     ),
-    default="openai_http",
+    default=GenerativeTextScenario.backend_type,
 )
 @click.option(
     "--backend-args",
     callback=parse_json,
-    default=None,
+    default=GenerativeTextScenario.backend_args,
     help=(
         "A JSON string containing any arguments to pass to the backend as a "
         "dict with **kwargs."
     ),
 )
 @click.option(
     "--model",
-    default=None,
+    default=GenerativeTextScenario.model,
     type=str,
     help=(
         "The ID of the model to benchmark within the backend. "
@@ -83,7 +106,7 @@ def cli():
 )
 @click.option(
     "--processor",
-    default=None,
+    default=GenerativeTextScenario.processor,
     type=str,
     help=(
         "The processor or tokenizer to use to calculate token counts for statistics "
@@ -93,7 +116,7 @@ def cli():
 )
 @click.option(
     "--processor-args",
-    default=None,
+    default=GenerativeTextScenario.processor_args,
     callback=parse_json,
     help=(
         "A JSON string containing any arguments to pass to the processor constructor "
@@ -112,6 +135,7 @@ def cli():
 )
 @click.option(
     "--data-args",
+    default=GenerativeTextScenario.data_args,
     callback=parse_json,
     help=(
         "A JSON string containing any arguments to pass to the dataset creation "
@@ -120,7 +144,7 @@ def cli():
 )
 @click.option(
     "--data-sampler",
-    default=None,
+    default=GenerativeTextScenario.data_sampler,
     type=click.Choice(["random"]),
     help=(
         "The data sampler type to use. 'random' will add a random shuffle on the data. "
@@ -138,7 +162,7 @@ def cli():
 )
 @click.option(
     "--rate",
-    default=None,
+    default=GenerativeTextScenario.rate,
     callback=parse_number_str,
     help=(
         "The rates to run the benchmark at. "
@@ -152,6 +176,7 @@ def cli():
 @click.option(
     "--max-seconds",
     type=float,
+    default=GenerativeTextScenario.max_seconds,
     help=(
         "The maximum number of seconds each benchmark can run for. "
         "If None, will run until max_requests or the data is exhausted."
@@ -160,6 +185,7 @@ def cli():
 @click.option(
     "--max-requests",
     type=int,
+    default=GenerativeTextScenario.max_requests,
     help=(
         "The maximum number of requests each benchmark can run for. "
         "If None, will run until max_seconds or the data is exhausted."
@@ -168,7 +194,7 @@ def cli():
 @click.option(
     "--warmup-percent",
     type=float,
-    default=None,
+    default=GenerativeTextScenario.warmup_percent,
     help=(
         "The percent of the benchmark (based on max-seconds, max-requests, "
         "or length of dataset) to run as a warmup and not include in the final results. "
@@ -178,6 +204,7 @@ def cli():
 @click.option(
     "--cooldown-percent",
     type=float,
+    default=GenerativeTextScenario.cooldown_percent,
     help=(
         "The percent of the benchmark (based on max-seconds, max-requests, or length "
         "of dataset) to run as a cooldown and not include in the final results. "
@@ -187,16 +214,19 @@ def cli():
 @click.option(
     "--disable-progress",
     is_flag=True,
+    default=not GenerativeTextScenario.show_progress,
     help="Set this flag to disable progress updates to the console",
 )
 @click.option(
     "--display-scheduler-stats",
     is_flag=True,
+    default=GenerativeTextScenario.show_progress_scheduler_stats,
     help="Set this flag to display stats for the processes running the benchmarks",
 )
 @click.option(
     "--disable-console-outputs",
     is_flag=True,
+    default=not GenerativeTextScenario.output_console,
     help="Set this flag to disable console output",
 )
 @click.option(
@@ -213,6 +243,7 @@ def cli():
 @click.option(
     "--output-extras",
     callback=parse_json,
+    default=GenerativeTextScenario.output_extras,
     help="A JSON string of extra data to save with the output benchmarks",
 )
 @click.option(
@@ -222,15 +253,16 @@ def cli():
         "The number of samples to save in the output file. "
         "If None (default), will save all samples."
     ),
-    default=None,
+    default=GenerativeTextScenario.output_sampling,
 )
 @click.option(
     "--random-seed",
-    default=42,
+    default=GenerativeTextScenario.random_seed,
     type=int,
     help="The random seed to use for benchmarking to ensure reproducibility.",
 )
 def benchmark(
+    scenario,
     target,
     backend_type,
     backend_args,
@@ -254,30 +286,48 @@ def benchmark(
     output_sampling,
     random_seed,
 ):
+    click_ctx = click.get_current_context()
+
+    # If a scenario file was specified read from it
+    # TODO: This should probably be a factory method
+    if scenario is None:
+        _scenario = {}
+    else:
+        # TODO: Support pre-defined scenarios
+        # TODO: Support other formats
+        with Path(scenario).open() as f:
+            _scenario = json.load(f)
+
+    # If any command line arguments are specified, override the scenario
+    _scenario.update(set_if_not_default(
+        click_ctx,
+        target=target,
+        backend_type=backend_type,
+        backend_args=backend_args,
+        model=model,
+        processor=processor,
+        processor_args=processor_args,
+        data=data,
+        data_args=data_args,
+        data_sampler=data_sampler,
+        rate_type=rate_type,
+        rate=rate,
+        max_seconds=max_seconds,
+        max_requests=max_requests,
+        warmup_percent=warmup_percent,
+        cooldown_percent=cooldown_percent,
+        show_progress=not disable_progress,
+        show_progress_scheduler_stats=display_scheduler_stats,
+        output_console=not disable_console_outputs,
+        output_path=output_path,
+        output_extras=output_extras,
+        output_sampling=output_sampling,
+        random_seed=random_seed,
+    ))
+
     asyncio.run(
-        benchmark_generative_text(
-            target=target,
-            backend_type=backend_type,
-            backend_args=backend_args,
-            model=model,
-            processor=processor,
-            processor_args=processor_args,
-            data=data,
-            data_args=data_args,
-            data_sampler=data_sampler,
-            rate_type=rate_type,
-            rate=rate,
-            max_seconds=max_seconds,
-            max_requests=max_requests,
-            warmup_percent=warmup_percent,
-            cooldown_percent=cooldown_percent,
-            show_progress=not disable_progress,
-            show_progress_scheduler_stats=display_scheduler_stats,
-            output_console=not disable_console_outputs,
-            output_path=output_path,
-            output_extras=output_extras,
-            output_sampling=output_sampling,
-            random_seed=random_seed,
+        benchmark_with_scenario(
+            scenario=GenerativeTextScenario(**_scenario)
         )
     )
 
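
Net effect of the rewiring in benchmark(): values loaded from the --scenario JSON file form the baseline, and any flag the user explicitly passes on the command line overrides the matching scenario field before GenerativeTextScenario is constructed and handed to benchmark_with_scenario. A hypothetical invocation, assuming guidellm's installed console script and a local file named scenario.json (the file name and values below are illustrative, not from the commit):

# scenario.json
# {"target": "http://localhost:8000", "rate_type": "constant", "rate": 1.0, "max_seconds": 60}

guidellm benchmark --scenario scenario.json --rate 5

Here the explicit --rate wins over the scenario's rate, while target, rate_type, and max_seconds keep the values from scenario.json; anything set in neither place falls back to the GenerativeTextScenario defaults.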