4141_endpoint_type_map = {"chat" : "v1/chat/completions" , "completions" : "v1/completions" }
4242
4343
44+ def _check_model_args (
45+ parser : argparse .ArgumentParser , args : argparse .Namespace
46+ ) -> argparse .Namespace :
47+ """
48+ Check if model name is provided.
49+ """
50+ if not args .subcommand and not args .model :
51+ parser .error ("The -m/--model option is required and cannot be empty." )
52+ return args
53+
54+
55+ def _check_compare_args (
56+ parser : argparse .ArgumentParser , args : argparse .Namespace
57+ ) -> argparse .Namespace :
58+ """
59+ Check compare subcommand args
60+ """
61+ if args .subcommand == "compare" :
62+ if not args .config and not args .files :
63+ parser .error ("Either the --config or --files option must be specified." )
64+ return args
65+
66+
4467def _check_conditional_args (
4568 parser : argparse .ArgumentParser , args : argparse .Namespace
4669) -> argparse .Namespace :
@@ -132,15 +155,6 @@ def _convert_str_to_enum_entry(args, option, enum):
132155 return args
133156
134157
135- ### Handlers ###
136-
137-
138- def handler (args , extra_args ):
139- from genai_perf .wrapper import Profiler
140-
141- Profiler .run (args = args , extra_args = extra_args )
142-
143-
144158### Parsers ###
145159
146160
@@ -286,7 +300,7 @@ def _add_endpoint_args(parser):
286300 "-m" ,
287301 "--model" ,
288302 type = str ,
289- required = True ,
303+ default = None ,
290304 help = f"The name of the model to benchmark." ,
291305 )
292306
@@ -437,6 +451,47 @@ def get_extra_inputs_as_dict(args: argparse.Namespace) -> dict:
437451 return request_inputs
438452
439453
454+ def _parse_compare_args (subparsers ) -> argparse .ArgumentParser :
455+ compare = subparsers .add_parser (
456+ "compare" ,
457+ description = "Subcommand to generate plots that compare multiple profile runs." ,
458+ )
459+ compare_group = compare .add_argument_group ("Compare" )
460+ mx_group = compare_group .add_mutually_exclusive_group (required = False )
461+ mx_group .add_argument (
462+ "--config" ,
463+ type = Path ,
464+ default = None ,
465+ help = "The path to the YAML file that specifies plot configurations for "
466+ "comparing multiple runs." ,
467+ )
468+ mx_group .add_argument (
469+ "-f" ,
470+ "--files" ,
471+ nargs = "+" ,
472+ default = [],
473+ help = "List of paths to the profile export JSON files. Users can specify "
474+ "this option instead of the `--config` option if they would like "
475+ "GenAI-Perf to generate default plots as well as initial YAML config file." ,
476+ )
477+ compare .set_defaults (func = compare_handler )
478+ return compare
479+
480+
481+ ### Handlers ###
482+
483+
484+ def profile_handler (args , extra_args ):
485+ from genai_perf .wrapper import Profiler
486+
487+ Profiler .run (args = args , extra_args = extra_args )
488+
489+
490+ def compare_handler (args : argparse .Namespace ):
491+ # TMA-1893: parse yaml file
492+ pass
493+
494+
440495### Entrypoint ###
441496
442497
@@ -448,7 +503,7 @@ def parse_args():
448503 description = "CLI to profile LLMs and Generative AI models with Perf Analyzer" ,
449504 formatter_class = argparse .ArgumentDefaultsHelpFormatter ,
450505 )
451- parser .set_defaults (func = handler )
506+ parser .set_defaults (func = profile_handler )
452507
453508 # Conceptually group args for easier visualization
454509 _add_endpoint_args (parser )
@@ -457,6 +512,12 @@ def parse_args():
457512 _add_output_args (parser )
458513 _add_other_args (parser )
459514
515+ # Add subcommands
516+ subparsers = parser .add_subparsers (
517+ help = "List of subparser commands." , dest = "subcommand"
518+ )
519+ compare_parser = _parse_compare_args (subparsers )
520+
460521 # Check for passthrough args
461522 if "--" in argv :
462523 passthrough_index = argv .index ("--" )
@@ -466,7 +527,9 @@ def parse_args():
466527
467528 args = parser .parse_args (argv [1 :passthrough_index ])
468529 args = _infer_prompt_source (args )
530+ args = _check_model_args (parser , args )
469531 args = _check_conditional_args (parser , args )
532+ args = _check_compare_args (compare_parser , args )
470533 args = _update_load_manager_args (args )
471534
472535 return args , argv [passthrough_index + 1 :]
0 commit comments