Skip to content

Commit d705153

Browse files
authored
Add attribute output_names to the command line validate (#186)
* Add attribute output_names to the command line validate * changres * fix output names
1 parent d5e5fbc commit d705153

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.7.5
55
+++++
66

7+
* :pr:`186`: add parameter --output_names to command line validate to change the output names of the onnx exported model
78
* :pr:`185`: remove the use of _seen_tokens in DynamicCache (removed in transformers>4.53),
89
updates dummpy inputs for feature-extraction
910
* :pr:`184`: implements side-by-side

onnx_diagnostic/_command_lines_parser.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,12 @@ def get_parser_validate() -> ArgumentParser:
483483
parser.add_argument(
484484
"--warmup", default=0, type=int, help="number of times to run the model to do warmup"
485485
)
486+
parser.add_argument(
487+
"--outnames",
488+
help="This comma separated list defines the output names "
489+
"the onnx exporter should use.",
490+
default="",
491+
)
486492
return parser
487493

488494

@@ -542,6 +548,9 @@ def _cmd_validate(argv: List[Any]):
542548
repeat=args.repeat,
543549
warmup=args.warmup,
544550
inputs2=args.inputs2,
551+
output_names=(
552+
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
553+
),
545554
)
546555
print("")
547556
print("-- summary --")

onnx_diagnostic/torch_models/validate.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def validate_model(
288288
repeat: int = 1,
289289
warmup: int = 0,
290290
inputs2: int = 1,
291+
output_names: Optional[List[str]] = None,
291292
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
292293
"""
293294
Validates a model.
@@ -338,6 +339,7 @@ def validate_model(
338339
:param inputs2: checks that the second set of inputs is reunning as well,
339340
this ensures that the model does support dynamism, the value is used
340341
as an increment to the first set of values (added to dimensions)
342+
:param output_names: output names the onnx exporter should use
341343
:return: two dictionaries, one with some metrics,
342344
another one with whatever the function produces
343345
@@ -433,6 +435,7 @@ def validate_model(
433435
)
434436
print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
435437
print(f"[validate_model] dump_folder={dump_folder!r}")
438+
print(f"[validate_model] output_names={output_names}")
436439
summary["model_id"] = model_id
437440
summary["model_subfolder"] = subfolder or ""
438441

@@ -631,6 +634,7 @@ def validate_model(
631634
optimization=optimization,
632635
do_run=do_run,
633636
dump_folder=dump_folder,
637+
output_names=output_names,
634638
)
635639
else:
636640
data["inputs_export"] = data["inputs"]
@@ -643,6 +647,7 @@ def validate_model(
643647
optimization=optimization,
644648
do_run=do_run,
645649
dump_folder=dump_folder,
650+
output_names=output_names,
646651
)
647652
summary.update(summary_export)
648653

@@ -868,6 +873,7 @@ def call_exporter(
868873
optimization: Optional[str] = None,
869874
do_run: bool = False,
870875
dump_folder: Optional[str] = None,
876+
output_names: Optional[List[str]] = None,
871877
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
872878
"""
873879
Calls an exporter on a model;
@@ -880,6 +886,7 @@ def call_exporter(
880886
:param optimization: optimization to do
881887
:param do_run: runs and compute discrepancies
882888
:param dump_folder: to dump additional information
889+
:param output_names: list of output names to use with the onnx exporter
883890
:return: two dictionaries, one with some metrics,
884891
another one with whatever the function produces
885892
"""
@@ -902,6 +909,7 @@ def call_exporter(
902909
quiet=quiet,
903910
verbose=verbose,
904911
optimization=optimization,
912+
output_names=output_names,
905913
)
906914
return summary, data
907915
if exporter == "custom" or exporter.startswith("custom"):
@@ -913,6 +921,7 @@ def call_exporter(
913921
verbose=verbose,
914922
optimization=optimization,
915923
dump_folder=dump_folder,
924+
output_names=output_names,
916925
)
917926
return summary, data
918927
if exporter == "modelbuilder":
@@ -923,6 +932,7 @@ def call_exporter(
923932
quiet=quiet,
924933
verbose=verbose,
925934
optimization=optimization,
935+
output_names=output_names,
926936
)
927937
return summary, data
928938
raise NotImplementedError(
@@ -1211,6 +1221,7 @@ def call_torch_export_onnx(
12111221
quiet: bool = False,
12121222
verbose: int = 0,
12131223
optimization: Optional[str] = None,
1224+
output_names: Optional[List[str]] = None,
12141225
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
12151226
"""
12161227
Exports a model into onnx.
@@ -1222,6 +1233,7 @@ def call_torch_export_onnx(
12221233
:param quiet: catch exception or not
12231234
:param verbose: verbosity
12241235
:param optimization: optimization to do
1236+
:param output_names: output names to use
12251237
:return: two dictionaries, one with some metrics,
12261238
another one with whatever the function produces
12271239
"""
@@ -1276,6 +1288,8 @@ def call_torch_export_onnx(
12761288
print("[call_torch_export_onnx] dynamo=False so...")
12771289
print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
12781290
print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
1291+
if output_names:
1292+
export_export_kwargs["output_names"] = output_names
12791293
if opset:
12801294
export_export_kwargs["opset_version"] = opset
12811295
if verbose:
@@ -1346,6 +1360,7 @@ def call_torch_export_model_builder(
13461360
quiet: bool = False,
13471361
verbose: int = 0,
13481362
optimization: Optional[str] = None,
1363+
output_names: Optional[List[str]] = None,
13491364
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
13501365
"""
13511366
Exports a model into onnx with :epkg:`ModelBuilder`.
@@ -1356,6 +1371,7 @@ def call_torch_export_model_builder(
13561371
:param quiet: catch exception or not
13571372
:param verbose: verbosity
13581373
:param optimization: optimization to do
1374+
:param output_names: list of output names to use
13591375
:return: two dictionaries, one with some metrics,
13601376
another one with whatever the function produces
13611377
"""
@@ -1369,6 +1385,9 @@ def call_torch_export_model_builder(
13691385
provider = data.get("model_device", "cpu")
13701386
dump_folder = data.get("model_dump_folder", "")
13711387
assert dump_folder, "dump_folder cannot be empty with ModelBuilder"
1388+
assert (
1389+
not output_names
1390+
), f"output_names not empty, not supported yet, output_names={output_names}"
13721391
cache_dir = os.path.join(dump_folder, "cache_mb")
13731392
if not os.path.exists(cache_dir):
13741393
os.makedirs(cache_dir)
@@ -1408,6 +1427,7 @@ def call_torch_export_custom(
14081427
verbose: int = 0,
14091428
optimization: Optional[str] = None,
14101429
dump_folder: Optional[str] = None,
1430+
output_names: Optional[List[str]] = None,
14111431
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
14121432
"""
14131433
Exports a model into onnx.
@@ -1420,6 +1440,7 @@ def call_torch_export_custom(
14201440
:param verbose: verbosity
14211441
:param optimization: optimization to do
14221442
:param dump_folder: to store additional information
1443+
:param output_names: list of output names to use
14231444
:return: two dictionaries, one with some metrics,
14241445
another one with whatever the function produces
14251446
"""
@@ -1504,6 +1525,8 @@ def call_torch_export_custom(
15041525
)
15051526
if opset:
15061527
kws["target_opset"] = opset
1528+
if output_names:
1529+
kws["output_names"] = output_names
15071530

15081531
epo, opt_stats = _quiet_or_not_quiet(
15091532
quiet,

0 commit comments

Comments
 (0)