Skip to content

Commit 31b6337

Browse files
committed
ortfusion in command line
1 parent 18aead7 commit 31b6337

File tree

2 files changed

+119
-27
lines changed

2 files changed

+119
-27
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,12 @@ def get_parser_validate() -> ArgumentParser:
287287
help="drops the following inputs names, it should be a list "
288288
"with comma separated values",
289289
)
290+
parser.add_argument(
291+
"--ortfusiontype",
292+
required=False,
293+
help="applies onnxruntime fusion, this parameter should contain the "
294+
"model type or multiple values separated by |",
295+
)
290296
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
291297
parser.add_argument("--dtype", help="changes dtype if necessary")
292298
parser.add_argument("--device", help="changes the device if necessary")
@@ -338,6 +344,7 @@ def _cmd_validate(argv: List[Any]):
338344
exporter=args.export,
339345
dump_folder=args.dump_folder,
340346
drop_inputs=None if not args.drop else args.drop.split(","),
347+
ortfusiontype=args.ortfusiontype,
341348
)
342349
print("")
343350
print("-- summary --")

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 112 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def validate_model(
197197
stop_if_static: int = 1,
198198
dump_folder: Optional[str] = None,
199199
drop_inputs: Optional[List[str]] = None,
200+
ortfusiontype: Optional[str] = None,
200201
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
201202
"""
202203
Validates a model.
@@ -222,11 +223,33 @@ def validate_model(
222223
see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
223224
:param dump_folder: dumps everything in a subfolder of this one
224225
:param drop_inputs: drops this list of inputs (given their names)
226+
:param ortfusiontype: runs ort fusion, the parameter defines the fusion type,
227+
it accepts multiple values separated by ``|``,
228+
see :func:`onnx_diagnostic.torch_models.test_helper.run_ort_fusion`
225229
:return: two dictionaries, one with some metrics,
226230
another one with whatever the function produces
227231
"""
228232
assert not trained, f"trained={trained} not supported yet"
229233
summary = version_summary()
234+
235+
summary.update(
236+
dict(
237+
version_model_id=model_id,
238+
version_do_run=str(do_run),
239+
version_dtype=str(dtype or ""),
240+
version_device=str(device or ""),
241+
version_trained=str(trained),
242+
version_optimization=optimization or "",
243+
version_quiet=str(quiet),
244+
version_patch=str(patch),
245+
version_dump_folder=dump_folder or "",
246+
version_drop_inputs=str(list(drop_inputs or "")),
247+
version_ortfusiontype=ortfusiontype or "",
248+
version_stop_if_static=str(stop_if_static),
249+
version_exporter=exporter,
250+
)
251+
)
252+
230253
folder_name = None
231254
if dump_folder:
232255
folder_name = _make_folder_name(
@@ -456,15 +479,66 @@ def validate_model(
456479
if verbose:
457480
print("[validate_model] done (dump)")
458481

459-
if exporter and exporter.startswith(("onnx-", "custom-")) and do_run:
460-
summary_valid, data = validate_onnx_model(
461-
data=data,
462-
quiet=quiet,
463-
verbose=verbose,
464-
optimization=optimization,
465-
)
482+
if not exporter or not exporter.startswith(("onnx-", "custom-")):
483+
if verbose:
484+
print("[validate_model] -- done (final)")
485+
if dump_stats:
486+
with open(dump_stats, "w") as f:
487+
for k, v in sorted(summary.items()):
488+
f.write(f":{k}:{v};\n")
489+
return summary, data
490+
491+
if do_run:
492+
summary_valid, data = validate_onnx_model(data=data, quiet=quiet, verbose=verbose)
466493
summary.update(summary_valid)
467494

495+
if ortfusiontype and "onnx_filename" in data:
496+
assert (
497+
"configuration" in data
498+
), f"missing configuration in data, cannot run ort fusion for model_id={model_id}"
499+
config = data["configuration"]
500+
assert hasattr(
501+
config, "hidden_size"
502+
), f"Missing attribute hidden_size in configuration {config}"
503+
hidden_size = config.hidden_size
504+
assert hasattr(
505+
config, "num_attention_heads"
506+
), f"Missing attribute num_attention_heads in configuration {config}"
507+
num_attention_heads = config.num_attention_heads
508+
509+
model_types = ortfusiontype.split("|")
510+
for model_type in model_types:
511+
flavour = f"ort{model_type}"
512+
summary[f"version_{flavour}_hidden_size"] = hidden_size
513+
summary[f"version_{flavour}_num_attention_heads"] = num_attention_heads
514+
515+
begin = time.perf_counter()
516+
if verbose:
517+
print(f"[validate_model] run onnxruntime fusion for {model_type!r}")
518+
input_filename = data["onnx_filename"]
519+
output_path = f"{os.path.splitext(input_filename)[0]}.ort.{model_type}.onnx"
520+
run_ort_fusion(
521+
input_filename,
522+
output_path,
523+
model_type=model_type,
524+
num_attention_heads=num_attention_heads,
525+
hidden_size=hidden_size,
526+
)
527+
data[f"onnx_filename_{flavour}"] = output_path
528+
duration = time.perf_counter() - begin
529+
summary[f"time_ortfusion_{flavour}"] = duration
530+
if verbose:
531+
print(
532+
f"[validate_model] done {model_type!r} in {duration}, "
533+
f"saved into {output_path!r}"
534+
)
535+
536+
if do_run:
537+
summary_valid, data = validate_onnx_model(
538+
data=data, quiet=quiet, verbose=verbose, flavour=flavour
539+
)
540+
summary.update(summary_valid)
541+
468542
if verbose:
469543
print("[validate_model] -- done (final)")
470544
if dump_stats:
@@ -651,22 +725,27 @@ def validate_onnx_model(
651725
data: Dict[str, Any],
652726
quiet: bool = False,
653727
verbose: int = 0,
654-
optimization: Optional[str] = None,
728+
flavour: Optional[str] = None,
655729
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
656730
"""
657731
Verifies that an onnx model produces the same
658-
expected outputs.
732+
expected outputs. It uses ``data["onnx_filename"]`` as the input
733+
onnx filename or ``data["onnx_filename_{flavour}"]`` if *flavour*
734+
is specified.
659735
660736
:param data: dictionary with all the necessary inputs, the dictionary must
661737
contain keys ``model`` and ``inputs_export``
662738
:param quiet: catch exception or not
663739
:param verbose: verbosity
664-
:param optimization: optimization to do
740+
:param flavour: use a different version of the inputs
665741
:return: two dictionaries, one with some metrics,
666742
another one with whatever the function produces
667743
"""
668744
import onnxruntime
669745

746+
def _mk(key):
747+
return f"{key}_{flavour}" if flavour else key
748+
670749
summary = {}
671750
flat_inputs = flatten_object(data["inputs"], drop_keys=True)
672751
d = flat_inputs[0].get_device()
@@ -675,36 +754,42 @@ def validate_onnx_model(
675754
if d < 0
676755
else ["CUDAExecutionProvider", "CPUExecutionProvider"]
677756
)
678-
if "onnx_filename" in data:
679-
source = data["onnx_filename"]
680-
summary["onnx_filename"] = source
681-
summary["onnx_size"] = os.stat(source).st_size
757+
input_data_key = f"onnx_filename_{flavour}" if flavour else "onnx_filename"
758+
759+
if input_data_key in data:
760+
source = data[input_data_key]
761+
summary[input_data_key] = source
762+
summary[_mk("onnx_size")] = os.stat(source).st_size
682763
else:
764+
assert not flavour, f"flavour={flavour!r}, the filename must be saved."
683765
assert (
684766
"onnx_program" in data
685767
), f"onnx_program is missing from data which has {sorted(data)}"
686768
source = data["onnx_program"].model_proto.SerializeToString()
687769
assert len(source) < 2**31, f"The model is highger than 2Gb: {len(source) / 2**30} Gb"
688-
summary["onnx_size"] = len(source)
770+
summary[_mk("onnx_size")] = len(source)
689771
if verbose:
690-
print(f"[validate_onnx_model] verify onnx model with providers {providers}...")
772+
print(
773+
f"[validate_onnx_model] verify onnx model with providers "
774+
f"{providers}..., flavour={flavour!r}"
775+
)
691776

692777
begin = time.perf_counter()
693778
if quiet:
694779
try:
695780
sess = onnxruntime.InferenceSession(source, providers=providers)
696781
except Exception as e:
697-
summary["ERR_onnx_ort_create"] = str(e)
698-
data["ERR_onnx_ort_create"] = e
699-
summary["time_onnx_ort_create"] = time.perf_counter() - begin
782+
summary[_mk("ERR_onnx_ort_create")] = str(e)
783+
data[_mk("ERR_onnx_ort_create")] = e
784+
summary[_mk("time_onnx_ort_create")] = time.perf_counter() - begin
700785
return summary, data
701786
else:
702787
sess = onnxruntime.InferenceSession(source, providers=providers)
703788

704-
summary["time_onnx_ort_create"] = time.perf_counter() - begin
705-
data["onnx_ort_sess"] = sess
789+
summary[_mk("time_onnx_ort_create")] = time.perf_counter() - begin
790+
data[_mk("onnx_ort_sess")] = sess
706791
if verbose:
707-
print("[validate_onnx_model] done (ort_session)")
792+
print(f"[validate_onnx_model] done (ort_session) flavour={flavour!r}")
708793

709794
# make_feeds
710795
if verbose:
@@ -718,7 +803,7 @@ def validate_onnx_model(
718803
)
719804
if verbose:
720805
print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
721-
summary["onnx_ort_inputs"] = string_type(feeds, with_shape=True)
806+
summary[_mk("onnx_ort_inputs")] = string_type(feeds, with_shape=True)
722807
if verbose:
723808
print("[validate_onnx_model] done (make_feeds)")
724809

@@ -730,9 +815,9 @@ def validate_onnx_model(
730815
try:
731816
got = sess.run(None, feeds)
732817
except Exception as e:
733-
summary["ERR_onnx_ort_run"] = str(e)
734-
data["ERR_onnx_ort_run"] = e
735-
summary["time_onnx_ort_run"] = time.perf_counter() - begin
818+
summary[_mk("ERR_onnx_ort_run")] = str(e)
819+
data[_mk("ERR_onnx_ort_run")] = e
820+
summary[_mk("time_onnx_ort_run")] = time.perf_counter() - begin
736821
return summary, data
737822
else:
738823
got = sess.run(None, feeds)
@@ -745,7 +830,7 @@ def validate_onnx_model(
745830
if verbose:
746831
print(f"[validate_onnx_model] discrepancies={string_diff(disc)}")
747832
for k, v in disc.items():
748-
summary[f"disc_onnx_ort_run_{k}"] = v
833+
summary[_mk(f"disc_onnx_ort_run_{k}")] = v
749834
return summary, data
750835

751836

0 commit comments

Comments
 (0)