@@ -371,30 +371,34 @@ def __call__(self, parser, namespace, values, option_string=None):
371371 setattr (namespace , self .dest , d )
372372
373373
374- def get_parser_validate () -> ArgumentParser :
374+ def get_parser_validate (name : str = "validate" ) -> ArgumentParser :
375375 parser = ArgumentParser (
376- prog = "validate" ,
376+ prog = name ,
377377 description = textwrap .dedent (
378378 """
379- Prints out dummy inputs for a particular task or a model id.
380- If both mid and task are empty, the command line displays the list
381- of supported tasks.
379+ Validates a model for a particular task given the model id.
380+ It exports the model and then validates it by computing the discrepancies
381+ on different input sets.
382+ """
383+ if name == "validate"
384+ else """
385+ Creates a script to export a model for a particular task given the model id.
382386 """
383387 ),
384388 epilog = textwrap .dedent (
385- """
389+ f """
386390 If the model id is specified, one untrained version of it is instantiated.
387391 Examples:
388392
389- python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
393+ python -m onnx_diagnostic { name } -m microsoft/Phi-4-mini-reasoning \\
390394 --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
391395 --dtype float16 --device cuda --patch --export onnx-dynamo --opt ir
392396
393- python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
397+ python -m onnx_diagnostic { name } -m microsoft/Phi-4-mini-reasoning \\
394398 --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
395399 --dtype float16 --device cuda --patch --export custom --opt default
396400
397- python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
401+ python -m onnx_diagnostic { name } -m microsoft/Phi-4-mini-reasoning \\
398402 --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
399403 --dtype float16 --device cuda --export modelbuilder
400404
@@ -405,12 +409,12 @@ def get_parser_validate() -> ArgumentParser:
405409 The behaviour may be modified compare the original configuration,
406410 the following argument can be rope_scaling to dynamic:
407411
408- --mop \" rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\" "
412+ --mop \" rope_scaling={{ 'rope_type': 'dynamic', 'factor': 10.0} }\" "
409413
410414 You can profile the command line by running:
411415
412- pyinstrument -m onnx_diagnostic validate ...
413- pyinstrument -r html -o profile.html -m onnx_diagnostic validate ...
416+ pyinstrument -m onnx_diagnostic { name } ...
417+ pyinstrument -r html -o profile.html -m onnx_diagnostic { name } ...
414418 """
415419 ),
416420 formatter_class = RawTextHelpFormatter ,
@@ -460,19 +464,19 @@ def get_parser_validate() -> ArgumentParser:
460464 "--same-as-trained" ,
461465 default = False ,
462466 action = BooleanOptionalAction ,
463- help = "Validates a model identical to the trained model but not trained." ,
467+ help = "Validates or exports a model identical to the trained model but not trained." ,
464468 )
465469 parser .add_argument (
466470 "--trained" ,
467471 default = False ,
468472 action = BooleanOptionalAction ,
469- help = "Validates the trained model (requires downloading)." ,
473+ help = "Validates or exports the trained model (requires downloading)." ,
470474 )
471475 parser .add_argument (
472476 "--inputs2" ,
473477 default = 1 ,
474478 type = int ,
475- help = "Validates the model on a second set of inputs\n "
479+ help = "Validates or exports the model on a second set of inputs\n "
476480 "to check the exported model supports dynamism. The values is used "
477481 "as an increment to the first set of inputs. A high value may trick "
478482 "a different behavior in the model and missed by the exporter." ,
@@ -504,13 +508,14 @@ def get_parser_validate() -> ArgumentParser:
504508 "--subfolder" ,
505509 help = "Subfolder where to find the model and the configuration." ,
506510 )
507- parser .add_argument (
508- "--ortfusiontype" ,
509- required = False ,
510- help = "Applies onnxruntime fusion, this parameter should contain the\n "
511- "model type or multiple values separated by `|`. `ALL` can be used\n "
512- "to run them all." ,
513- )
511+ if name == "validate" :
512+ parser .add_argument (
513+ "--ortfusiontype" ,
514+ required = False ,
515+ help = "Applies onnxruntime fusion, this parameter should contain the\n "
516+ "model type or multiple values separated by `|`. `ALL` can be used\n "
517+ "to run them all." ,
518+ )
514519 parser .add_argument ("-v" , "--verbose" , default = 0 , type = int , help = "verbosity" )
515520 parser .add_argument ("--dtype" , help = "Changes dtype if necessary." )
516521 parser .add_argument ("--device" , help = "Changes the device if necessary." )
@@ -532,33 +537,38 @@ def get_parser_validate() -> ArgumentParser:
532537 "--mop \" rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\" " ,
533538 action = _ParseDict ,
534539 )
535- parser .add_argument (
536- "--repeat" ,
537- default = 1 ,
538- type = int ,
539- help = "number of times to run the model to measures inference time" ,
540- )
541- parser .add_argument (
542- "--warmup" , default = 0 , type = int , help = "number of times to run the model to do warmup"
543- )
540+ if name == "validate" :
541+ parser .add_argument (
542+ "--repeat" ,
543+ default = 1 ,
544+ type = int ,
545+ help = "number of times to run the model to measures inference time" ,
546+ )
547+ parser .add_argument (
548+ "--warmup" ,
549+ default = 0 ,
550+ type = int ,
551+ help = "number of times to run the model to do warmup" ,
552+ )
544553 parser .add_argument (
545554 "--outnames" ,
546555 help = "This comma separated list defines the output names "
547556 "the onnx exporter should use." ,
548557 default = "" ,
549558 )
550- parser .add_argument (
551- "--ort-logs" ,
552- default = False ,
553- action = BooleanOptionalAction ,
554- help = "Enables onnxruntime logging when the session is created" ,
555- )
556- parser .add_argument (
557- "--quiet-input-sets" ,
558- default = "" ,
559- help = "Avoids raising an exception when an input sets does not work with "
560- "the exported model.\n Example: --quiet-input-sets=inputs,inputs22" ,
561- )
559+ if name == "validate" :
560+ parser .add_argument (
561+ "--ort-logs" ,
562+ default = False ,
563+ action = BooleanOptionalAction ,
564+ help = "Enables onnxruntime logging when the session is created" ,
565+ )
566+ parser .add_argument (
567+ "--quiet-input-sets" ,
568+ default = "" ,
569+ help = "Avoids raising an exception when an input sets does not work with "
570+ "the exported model.\n Example: --quiet-input-sets=inputs,inputs22" ,
571+ )
562572 return parser
563573
564574
@@ -637,7 +647,7 @@ def _cmd_export_sample(argv: List[Any]):
637647 from .torch_models .code_sample import code_sample
638648 from .tasks import supported_tasks
639649
640- parser = get_parser_validate ()
650+ parser = get_parser_validate ("exportsample" )
641651 args = parser .parse_args (argv [1 :])
642652 if not args .task and not args .mid :
643653 print ("-- list of supported tasks:" )
@@ -693,16 +703,16 @@ def _cmd_export_sample(argv: List[Any]):
693703 os .makedirs (args .dump_folder , exist_ok = True )
694704 name = (
695705 _make_folder_name (
696- model_id = args .model_id ,
697- exporter = args .exporter ,
698- optimization = args .optimization ,
706+ model_id = args .mid ,
707+ exporter = args .export ,
708+ optimization = args .opt ,
699709 dtype = args .dtype ,
700710 device = args .device ,
701711 subfolder = args .subfolder ,
702712 opset = args .opset ,
703713 drop_inputs = None if not args .drop else args .drop .split ("," ),
704- same_as_pretrained = args .same_as_pretrained ,
705- use_pretrained = args .use_pretrained ,
714+ same_as_pretrained = args .same_as_trained ,
715+ use_pretrained = args .trained ,
706716 task = args .task ,
707717 ).replace ("/" , "-" )
708718 + ".py"
@@ -1111,7 +1121,7 @@ def main(argv: Optional[List[Any]] = None):
11111121 validate = get_parser_validate ,
11121122 stats = get_parser_stats ,
11131123 agg = get_parser_agg ,
1114- exportsample = get_parser_validate ,
1124+ exportsample = lambda : get_parser_validate ( "exportsample" ), # type: ignore[operator]
11151125 )
11161126 cmd = argv [0 ]
11171127 if cmd not in parsers :
0 commit comments