@@ -557,7 +557,7 @@ def get_parser_validate() -> ArgumentParser:
557557 "--quiet-input-sets" ,
558558 default = "" ,
559559 help = "Avoids raising an exception when an input sets does not work with "
560- "the exported model, example : --quiet-input-sets=inputs,inputs22" ,
560+ "the exported model. \n Example : --quiet-input-sets=inputs,inputs22" ,
561561 )
562562 return parser
563563
@@ -631,6 +631,94 @@ def _cmd_validate(argv: List[Any]):
631631 print (f":{ k } ,{ v } ;" )
632632
633633
634+ def _cmd_export_sample (argv : List [Any ]):
635+ from .helpers import string_type
636+ from .torch_models .validate import get_inputs_for_task , _make_folder_name
637+ from .torch_models .code_sample import code_sample
638+ from .tasks import supported_tasks
639+
640+ parser = get_parser_validate ()
641+ args = parser .parse_args (argv [1 :])
642+ if not args .task and not args .mid :
643+ print ("-- list of supported tasks:" )
644+ print ("\n " .join (supported_tasks ()))
645+ elif not args .mid :
646+ data = get_inputs_for_task (args .task )
647+ if args .verbose :
648+ print (f"task: { args .task } " )
649+ max_length = max (len (k ) for k in data ["inputs" ]) + 1
650+ print ("-- inputs" )
651+ for k , v in data ["inputs" ].items ():
652+ print (f" + { k .ljust (max_length )} : { string_type (v , with_shape = True )} " )
653+ print ("-- dynamic_shapes" )
654+ for k , v in data ["dynamic_shapes" ].items ():
655+ print (f" + { k .ljust (max_length )} : { string_type (v )} " )
656+ else :
657+ # Let's skip any invalid combination if known to be unsupported
658+ if (
659+ "onnx" not in (args .export or "" )
660+ and "custom" not in (args .export or "" )
661+ and (args .opt or "" )
662+ ):
663+ print (f"code-sample - unsupported args: export={ args .export !r} , opt={ args .opt !r} " )
664+ return
665+ patch_dict = args .patch if isinstance (args .patch , dict ) else {"patch" : args .patch }
666+ code = code_sample (
667+ model_id = args .mid ,
668+ task = args .task ,
669+ do_run = args .run ,
670+ verbose = args .verbose ,
671+ quiet = args .quiet ,
672+ same_as_pretrained = args .same_as_trained ,
673+ use_pretrained = args .trained ,
674+ dtype = args .dtype ,
675+ device = args .device ,
676+ patch = patch_dict ,
677+ rewrite = args .rewrite and patch_dict .get ("patch" , True ),
678+ stop_if_static = args .stop_if_static ,
679+ optimization = args .opt ,
680+ exporter = args .export ,
681+ dump_folder = args .dump_folder ,
682+ drop_inputs = None if not args .drop else args .drop .split ("," ),
683+ input_options = args .iop ,
684+ model_options = args .mop ,
685+ subfolder = args .subfolder ,
686+ opset = args .opset ,
687+ runtime = args .runtime ,
688+ output_names = (
689+ None if len (args .outnames .strip ()) < 2 else args .outnames .strip ().split ("," )
690+ ),
691+ )
692+ if args .dump_folder :
693+ os .makedirs (args .dump_folder , exist_ok = True )
694+ name = (
695+ _make_folder_name (
696+ model_id = args .model_id ,
697+ exporter = args .exporter ,
698+ optimization = args .optimization ,
699+ dtype = args .dtype ,
700+ device = args .device ,
701+ subfolder = args .subfolder ,
702+ opset = args .opset ,
703+ drop_inputs = None if not args .drop else args .drop .split ("," ),
704+ same_as_pretrained = args .same_as_pretrained ,
705+ use_pretrained = args .use_pretrained ,
706+ task = args .task ,
707+ ).replace ("/" , "-" )
708+ + ".py"
709+ )
710+ fullname = os .path .join (args .dump_folder , name )
711+ if args .verbose :
712+ print (f"-- prints code in { fullname !r} " )
713+ print ("--" )
714+ with open (fullname , "w" ) as f :
715+ f .write (code )
716+ if args .verbose :
717+ print ("-- done" )
718+ else :
719+ print (code )
720+
721+
634722def get_parser_stats () -> ArgumentParser :
635723 parser = ArgumentParser (
636724 prog = "stats" ,
@@ -960,14 +1048,15 @@ def get_main_parser() -> ArgumentParser:
9601048 Type 'python -m onnx_diagnostic <cmd> --help'
9611049 to get help for a specific command.
9621050
963- agg - aggregates statistics from multiple files
964- config - prints a configuration for a model id
965- find - find node consuming or producing a result
966- lighten - makes an onnx model lighter by removing the weights,
967- print - prints the model on standard output
968- stats - produces statistics on a model
969- unlighten - restores an onnx model produces by the previous experiment
970- validate - validate a model
1051+ agg - aggregates statistics from multiple files
1052+ config - prints a configuration for a model id
1053+ exportsample - produces a code to export a model
1054+ find - find node consuming or producing a result
1055+ lighten - makes an onnx model lighter by removing the weights,
1056+ print - prints the model on standard output
1057+ stats - produces statistics on a model
1058+ unlighten - restores an onnx model produces by the previous experiment
1059+ validate - validate a model
9711060 """
9721061 ),
9731062 )
@@ -976,6 +1065,7 @@ def get_main_parser() -> ArgumentParser:
9761065 choices = [
9771066 "agg" ,
9781067 "config" ,
1068+ "exportsample" ,
9791069 "find" ,
9801070 "lighten" ,
9811071 "print" ,
@@ -998,6 +1088,7 @@ def main(argv: Optional[List[Any]] = None):
9981088 validate = _cmd_validate ,
9991089 stats = _cmd_stats ,
10001090 agg = _cmd_agg ,
1091+ exportsample = _cmd_export_sample ,
10011092 )
10021093
10031094 if argv is None :
@@ -1020,6 +1111,7 @@ def main(argv: Optional[List[Any]] = None):
10201111 validate = get_parser_validate ,
10211112 stats = get_parser_stats ,
10221113 agg = get_parser_agg ,
1114+ exportsample = get_parser_validate ,
10231115 )
10241116 cmd = argv [0 ]
10251117 if cmd not in parsers :
0 commit comments