@@ -288,6 +288,7 @@ def validate_model(
288288 repeat : int = 1 ,
289289 warmup : int = 0 ,
290290 inputs2 : int = 1 ,
291+ output_names : Optional [List [str ]] = None ,
291292) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
292293 """
293294 Validates a model.
@@ -338,6 +339,7 @@ def validate_model(
338339 :param inputs2: checks that the second set of inputs is reunning as well,
339340 this ensures that the model does support dynamism, the value is used
340341 as an increment to the first set of values (added to dimensions)
342+ :param output_names: output names the onnx exporter should use
341343 :return: two dictionaries, one with some metrics,
342344 another one with whatever the function produces
343345
@@ -433,6 +435,7 @@ def validate_model(
433435 )
434436 print (f"[validate_model] exporter={ exporter !r} , optimization={ optimization !r} " )
435437 print (f"[validate_model] dump_folder={ dump_folder !r} " )
438+ print (f"[validate_model] output_names={ output_names } " )
436439 summary ["model_id" ] = model_id
437440 summary ["model_subfolder" ] = subfolder or ""
438441
@@ -631,6 +634,7 @@ def validate_model(
631634 optimization = optimization ,
632635 do_run = do_run ,
633636 dump_folder = dump_folder ,
637+ output_names = output_names ,
634638 )
635639 else :
636640 data ["inputs_export" ] = data ["inputs" ]
@@ -643,6 +647,7 @@ def validate_model(
643647 optimization = optimization ,
644648 do_run = do_run ,
645649 dump_folder = dump_folder ,
650+ output_names = output_names ,
646651 )
647652 summary .update (summary_export )
648653
@@ -868,6 +873,7 @@ def call_exporter(
868873 optimization : Optional [str ] = None ,
869874 do_run : bool = False ,
870875 dump_folder : Optional [str ] = None ,
876+ output_names : Optional [List [str ]] = None ,
871877) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
872878 """
873879 Calls an exporter on a model;
@@ -880,6 +886,7 @@ def call_exporter(
880886 :param optimization: optimization to do
881887 :param do_run: runs and compute discrepancies
882888 :param dump_folder: to dump additional information
889+ :param output_names: list of output names to use with the onnx exporter
883890 :return: two dictionaries, one with some metrics,
884891 another one with whatever the function produces
885892 """
@@ -902,6 +909,7 @@ def call_exporter(
902909 quiet = quiet ,
903910 verbose = verbose ,
904911 optimization = optimization ,
912+ output_names = output_names ,
905913 )
906914 return summary , data
907915 if exporter == "custom" or exporter .startswith ("custom" ):
@@ -913,6 +921,7 @@ def call_exporter(
913921 verbose = verbose ,
914922 optimization = optimization ,
915923 dump_folder = dump_folder ,
924+ output_names = output_names ,
916925 )
917926 return summary , data
918927 if exporter == "modelbuilder" :
@@ -923,6 +932,7 @@ def call_exporter(
923932 quiet = quiet ,
924933 verbose = verbose ,
925934 optimization = optimization ,
935+ output_names = output_names ,
926936 )
927937 return summary , data
928938 raise NotImplementedError (
@@ -1211,6 +1221,7 @@ def call_torch_export_onnx(
12111221 quiet : bool = False ,
12121222 verbose : int = 0 ,
12131223 optimization : Optional [str ] = None ,
1224+ output_names : Optional [List [str ]] = None ,
12141225) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
12151226 """
12161227 Exports a model into onnx.
@@ -1222,6 +1233,7 @@ def call_torch_export_onnx(
12221233 :param quiet: catch exception or not
12231234 :param verbose: verbosity
12241235 :param optimization: optimization to do
1236+ :param output_names: output names to use
12251237 :return: two dictionaries, one with some metrics,
12261238 another one with whatever the function produces
12271239 """
@@ -1276,6 +1288,8 @@ def call_torch_export_onnx(
12761288 print ("[call_torch_export_onnx] dynamo=False so..." )
12771289 print (f"[call_torch_export_onnx] args={ string_type (args , with_shape = True )} " )
12781290 print (f"[call_torch_export_onnx] kwargs={ string_type (kwargs , with_shape = True )} " )
1291+ if output_names :
1292+ export_export_kwargs ["output_names" ] = output_names
12791293 if opset :
12801294 export_export_kwargs ["opset_version" ] = opset
12811295 if verbose :
@@ -1346,6 +1360,7 @@ def call_torch_export_model_builder(
13461360 quiet : bool = False ,
13471361 verbose : int = 0 ,
13481362 optimization : Optional [str ] = None ,
1363+ output_names : Optional [List [str ]] = None ,
13491364) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
13501365 """
13511366 Exports a model into onnx with :epkg:`ModelBuilder`.
@@ -1356,6 +1371,7 @@ def call_torch_export_model_builder(
13561371 :param quiet: catch exception or not
13571372 :param verbose: verbosity
13581373 :param optimization: optimization to do
1374+ :param output_names: list of output names to use
13591375 :return: two dictionaries, one with some metrics,
13601376 another one with whatever the function produces
13611377 """
@@ -1369,6 +1385,9 @@ def call_torch_export_model_builder(
13691385 provider = data .get ("model_device" , "cpu" )
13701386 dump_folder = data .get ("model_dump_folder" , "" )
13711387 assert dump_folder , "dump_folder cannot be empty with ModelBuilder"
1388+ assert (
1389+ not output_names
1390+ ), f"output_names not empty, not supported yet, output_names={ output_names } "
13721391 cache_dir = os .path .join (dump_folder , "cache_mb" )
13731392 if not os .path .exists (cache_dir ):
13741393 os .makedirs (cache_dir )
@@ -1408,6 +1427,7 @@ def call_torch_export_custom(
14081427 verbose : int = 0 ,
14091428 optimization : Optional [str ] = None ,
14101429 dump_folder : Optional [str ] = None ,
1430+ output_names : Optional [List [str ]] = None ,
14111431) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
14121432 """
14131433 Exports a model into onnx.
@@ -1420,6 +1440,7 @@ def call_torch_export_custom(
14201440 :param verbose: verbosity
14211441 :param optimization: optimization to do
14221442 :param dump_folder: to store additional information
1443+ :param output_names: list of output names to use
14231444 :return: two dictionaries, one with some metrics,
14241445 another one with whatever the function produces
14251446 """
@@ -1504,6 +1525,8 @@ def call_torch_export_custom(
15041525 )
15051526 if opset :
15061527 kws ["target_opset" ] = opset
1528+ if output_names :
1529+ kws ["output_names" ] = output_names
15071530
15081531 epo , opt_stats = _quiet_or_not_quiet (
15091532 quiet ,
0 commit comments