@@ -288,6 +288,7 @@ def validate_model(
288288 repeat : int = 1 ,
289289 warmup : int = 0 ,
290290 inputs2 : int = 1 ,
291+ output_names : Optional [List [str ]] = None ,
291292) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
292293 """
293294 Validates a model.
@@ -338,6 +339,7 @@ def validate_model(
338339 :param inputs2: checks that the second set of inputs is reunning as well,
339340 this ensures that the model does support dynamism, the value is used
340341 as an increment to the first set of values (added to dimensions)
342+ :param output_names: output names the onnx exporter should use
341343 :return: two dictionaries, one with some metrics,
342344 another one with whatever the function produces
343345
@@ -631,6 +633,7 @@ def validate_model(
631633 optimization = optimization ,
632634 do_run = do_run ,
633635 dump_folder = dump_folder ,
636+ output_names = output_names ,
634637 )
635638 else :
636639 data ["inputs_export" ] = data ["inputs" ]
@@ -643,6 +646,7 @@ def validate_model(
643646 optimization = optimization ,
644647 do_run = do_run ,
645648 dump_folder = dump_folder ,
649+ output_names = output_names ,
646650 )
647651 summary .update (summary_export )
648652
@@ -868,6 +872,7 @@ def call_exporter(
868872 optimization : Optional [str ] = None ,
869873 do_run : bool = False ,
870874 dump_folder : Optional [str ] = None ,
875+ output_names : Optional [List [str ]] = None ,
871876) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
872877 """
873878 Calls an exporter on a model;
@@ -880,6 +885,7 @@ def call_exporter(
880885 :param optimization: optimization to do
881886 :param do_run: runs and compute discrepancies
882887 :param dump_folder: to dump additional information
888+ :param output_names: list of output names to use with the onnx exporter
883889 :return: two dictionaries, one with some metrics,
884890 another one with whatever the function produces
885891 """
@@ -902,6 +908,7 @@ def call_exporter(
902908 quiet = quiet ,
903909 verbose = verbose ,
904910 optimization = optimization ,
911+ output_names = output_names ,
905912 )
906913 return summary , data
907914 if exporter == "custom" or exporter .startswith ("custom" ):
@@ -913,6 +920,7 @@ def call_exporter(
913920 verbose = verbose ,
914921 optimization = optimization ,
915922 dump_folder = dump_folder ,
923+ output_names = output_names ,
916924 )
917925 return summary , data
918926 if exporter == "modelbuilder" :
@@ -923,6 +931,7 @@ def call_exporter(
923931 quiet = quiet ,
924932 verbose = verbose ,
925933 optimization = optimization ,
934+ output_names = output_names ,
926935 )
927936 return summary , data
928937 raise NotImplementedError (
@@ -1211,6 +1220,7 @@ def call_torch_export_onnx(
12111220 quiet : bool = False ,
12121221 verbose : int = 0 ,
12131222 optimization : Optional [str ] = None ,
1223+ output_names : Optional [List [str ]] = None ,
12141224) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
12151225 """
12161226 Exports a model into onnx.
@@ -1222,6 +1232,7 @@ def call_torch_export_onnx(
12221232 :param quiet: catch exception or not
12231233 :param verbose: verbosity
12241234 :param optimization: optimization to do
1235+ :param output_names: output names to use
12251236 :return: two dictionaries, one with some metrics,
12261237 another one with whatever the function produces
12271238 """
@@ -1276,6 +1287,8 @@ def call_torch_export_onnx(
12761287 print ("[call_torch_export_onnx] dynamo=False so..." )
12771288 print (f"[call_torch_export_onnx] args={ string_type (args , with_shape = True )} " )
12781289 print (f"[call_torch_export_onnx] kwargs={ string_type (kwargs , with_shape = True )} " )
1290+ if output_names :
1291+ export_export_kwargs ["output_names" ] = output_names
12791292 if opset :
12801293 export_export_kwargs ["opset_version" ] = opset
12811294 if verbose :
@@ -1346,6 +1359,7 @@ def call_torch_export_model_builder(
13461359 quiet : bool = False ,
13471360 verbose : int = 0 ,
13481361 optimization : Optional [str ] = None ,
1362+ output_names : Optional [List [str ]] = None ,
13491363) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
13501364 """
13511365 Exports a model into onnx with :epkg:`ModelBuilder`.
@@ -1356,6 +1370,7 @@ def call_torch_export_model_builder(
13561370 :param quiet: catch exception or not
13571371 :param verbose: verbosity
13581372 :param optimization: optimization to do
1373+ :param output_names: list of output names to use
13591374 :return: two dictionaries, one with some metrics,
13601375 another one with whatever the function produces
13611376 """
@@ -1369,6 +1384,9 @@ def call_torch_export_model_builder(
13691384 provider = data .get ("model_device" , "cpu" )
13701385 dump_folder = data .get ("model_dump_folder" , "" )
13711386 assert dump_folder , "dump_folder cannot be empty with ModelBuilder"
1387+ assert (
1388+ not output_names
1389+ ), f"output_names not empty, not supported yet, output_names={ output_names } "
13721390 cache_dir = os .path .join (dump_folder , "cache_mb" )
13731391 if not os .path .exists (cache_dir ):
13741392 os .makedirs (cache_dir )
@@ -1408,6 +1426,7 @@ def call_torch_export_custom(
14081426 verbose : int = 0 ,
14091427 optimization : Optional [str ] = None ,
14101428 dump_folder : Optional [str ] = None ,
1429+ output_names : Optional [List [str ]] = None ,
14111430) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
14121431 """
14131432 Exports a model into onnx.
@@ -1420,6 +1439,7 @@ def call_torch_export_custom(
14201439 :param verbose: verbosity
14211440 :param optimization: optimization to do
14221441 :param dump_folder: to store additional information
1442+ :param output_names: list of output names to use
14231443 :return: two dictionaries, one with some metrics,
14241444 another one with whatever the function produces
14251445 """
@@ -1504,6 +1524,8 @@ def call_torch_export_custom(
15041524 )
15051525 if opset :
15061526 kws ["target_opset" ] = opset
1527+ if output_names :
1528+ kws ["output_names" ] = output_names
15071529
15081530 epo , opt_stats = _quiet_or_not_quiet (
15091531 quiet ,
0 commit comments