Skip to content

Commit 55a597a

Browse files
committed
Add attribute output_names to the command line validate
1 parent ebecb67 commit 55a597a

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,12 @@ def get_parser_validate() -> ArgumentParser:
483483
parser.add_argument(
484484
"--warmup", default=0, type=int, help="number of times to run the model to do warmup"
485485
)
486+
parser.add_argument(
487+
"--outnames",
488+
help="This comma separated list defines the output names "
489+
"the onnx exporter should use.",
490+
default="",
491+
)
486492
return parser
487493

488494

@@ -542,6 +548,7 @@ def _cmd_validate(argv: List[Any]):
542548
repeat=args.repeat,
543549
warmup=args.warmup,
544550
inputs2=args.inputs2,
551+
output_names=args.outnames.strip().split(","),
545552
)
546553
print("")
547554
print("-- summary --")

onnx_diagnostic/torch_models/validate.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def validate_model(
288288
repeat: int = 1,
289289
warmup: int = 0,
290290
inputs2: int = 1,
291+
output_names: Optional[List[str]] = None,
291292
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
292293
"""
293294
Validates a model.
@@ -338,6 +339,7 @@ def validate_model(
338339
:param inputs2: checks that the second set of inputs is reunning as well,
339340
this ensures that the model does support dynamism, the value is used
340341
as an increment to the first set of values (added to dimensions)
342+
:param output_names: output names the onnx exporter should use
341343
:return: two dictionaries, one with some metrics,
342344
another one with whatever the function produces
343345
@@ -631,6 +633,7 @@ def validate_model(
631633
optimization=optimization,
632634
do_run=do_run,
633635
dump_folder=dump_folder,
636+
output_names=output_names,
634637
)
635638
else:
636639
data["inputs_export"] = data["inputs"]
@@ -643,6 +646,7 @@ def validate_model(
643646
optimization=optimization,
644647
do_run=do_run,
645648
dump_folder=dump_folder,
649+
output_names=output_names,
646650
)
647651
summary.update(summary_export)
648652

@@ -868,6 +872,7 @@ def call_exporter(
868872
optimization: Optional[str] = None,
869873
do_run: bool = False,
870874
dump_folder: Optional[str] = None,
875+
output_names: Optional[List[str]] = None,
871876
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
872877
"""
873878
Calls an exporter on a model;
@@ -880,6 +885,7 @@ def call_exporter(
880885
:param optimization: optimization to do
881886
:param do_run: runs and compute discrepancies
882887
:param dump_folder: to dump additional information
888+
:param output_names: list of output names to use with the onnx exporter
883889
:return: two dictionaries, one with some metrics,
884890
another one with whatever the function produces
885891
"""
@@ -902,6 +908,7 @@ def call_exporter(
902908
quiet=quiet,
903909
verbose=verbose,
904910
optimization=optimization,
911+
output_names=output_names,
905912
)
906913
return summary, data
907914
if exporter == "custom" or exporter.startswith("custom"):
@@ -913,6 +920,7 @@ def call_exporter(
913920
verbose=verbose,
914921
optimization=optimization,
915922
dump_folder=dump_folder,
923+
output_names=output_names,
916924
)
917925
return summary, data
918926
if exporter == "modelbuilder":
@@ -923,6 +931,7 @@ def call_exporter(
923931
quiet=quiet,
924932
verbose=verbose,
925933
optimization=optimization,
934+
output_names=output_names,
926935
)
927936
return summary, data
928937
raise NotImplementedError(
@@ -1211,6 +1220,7 @@ def call_torch_export_onnx(
12111220
quiet: bool = False,
12121221
verbose: int = 0,
12131222
optimization: Optional[str] = None,
1223+
output_names: Optional[List[str]] = None,
12141224
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
12151225
"""
12161226
Exports a model into onnx.
@@ -1222,6 +1232,7 @@ def call_torch_export_onnx(
12221232
:param quiet: catch exception or not
12231233
:param verbose: verbosity
12241234
:param optimization: optimization to do
1235+
:param output_names: output names to use
12251236
:return: two dictionaries, one with some metrics,
12261237
another one with whatever the function produces
12271238
"""
@@ -1276,6 +1287,8 @@ def call_torch_export_onnx(
12761287
print("[call_torch_export_onnx] dynamo=False so...")
12771288
print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
12781289
print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
1290+
if output_names:
1291+
export_export_kwargs["output_names"] = output_names
12791292
if opset:
12801293
export_export_kwargs["opset_version"] = opset
12811294
if verbose:
@@ -1346,6 +1359,7 @@ def call_torch_export_model_builder(
13461359
quiet: bool = False,
13471360
verbose: int = 0,
13481361
optimization: Optional[str] = None,
1362+
output_names: Optional[List[str]] = None,
13491363
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
13501364
"""
13511365
Exports a model into onnx with :epkg:`ModelBuilder`.
@@ -1356,6 +1370,7 @@ def call_torch_export_model_builder(
13561370
:param quiet: catch exception or not
13571371
:param verbose: verbosity
13581372
:param optimization: optimization to do
1373+
:param output_names: list of output names to use
13591374
:return: two dictionaries, one with some metrics,
13601375
another one with whatever the function produces
13611376
"""
@@ -1369,6 +1384,9 @@ def call_torch_export_model_builder(
13691384
provider = data.get("model_device", "cpu")
13701385
dump_folder = data.get("model_dump_folder", "")
13711386
assert dump_folder, "dump_folder cannot be empty with ModelBuilder"
1387+
assert (
1388+
not output_names
1389+
), f"output_names not empty, not supported yet, output_names={output_names}"
13721390
cache_dir = os.path.join(dump_folder, "cache_mb")
13731391
if not os.path.exists(cache_dir):
13741392
os.makedirs(cache_dir)
@@ -1408,6 +1426,7 @@ def call_torch_export_custom(
14081426
verbose: int = 0,
14091427
optimization: Optional[str] = None,
14101428
dump_folder: Optional[str] = None,
1429+
output_names: Optional[List[str]] = None,
14111430
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
14121431
"""
14131432
Exports a model into onnx.
@@ -1420,6 +1439,7 @@ def call_torch_export_custom(
14201439
:param verbose: verbosity
14211440
:param optimization: optimization to do
14221441
:param dump_folder: to store additional information
1442+
:param output_names: list of output names to use
14231443
:return: two dictionaries, one with some metrics,
14241444
another one with whatever the function produces
14251445
"""
@@ -1504,6 +1524,8 @@ def call_torch_export_custom(
15041524
)
15051525
if opset:
15061526
kws["target_opset"] = opset
1527+
if output_names:
1528+
kws["output_names"] = output_names
15071529

15081530
epo, opt_stats = _quiet_or_not_quiet(
15091531
quiet,

0 commit comments

Comments
 (0)