Skip to content

Commit f703ad6

Browse files
committed
adds speed up
1 parent cca5703 commit f703ad6

File tree

1 file changed

+57
-3
lines changed

1 file changed

+57
-3
lines changed

onnx_diagnostic/torch_models/validate.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,8 @@ def validate_model(
841841
)
842842
summary.update(summary_valid)
843843

844+
_compute_final_statistics(summary)
845+
844846
if verbose:
845847
print("[validate_model] -- done (final)")
846848
if dump_stats:
@@ -853,15 +855,24 @@ def validate_model(
853855
def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
854856
"""Computes some statistics on the model itself."""
855857
onx = onnx.load(onnx_filename, load_external_data=False)
858+
cache_functions = {(f.domain, f.name): f for f in onx.functions}
859+
local_domains = set(f.domain for f in onx.functions)
856860

857861
def node_iter(proto):
858862
if isinstance(proto, onnx.ModelProto):
859-
yield from node_iter(proto.graph)
860863
for f in proto.functions:
861864
yield from node_iter(f)
865+
yield from node_iter(proto.graph)
862866
elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
863867
for node in proto.node:
864868
yield node
869+
870+
# Inline local functions: when the node calls a function defined in the model, also yield the nodes of that function's body
871+
key = node.domain, node.op_type
872+
if key in cache_functions:
873+
yield from node_iter(cache_functions[key])
874+
875+
# Then descend into the subgraphs stored in the node's attributes
865876
for att in node.attribute:
866877
if att.type == onnx.AttributeProto.GRAPH:
867878
yield from node_iter(att.g)
@@ -879,6 +890,11 @@ def node_iter(proto):
879890
n_nodes += 1
880891
if proto.op_type != "Constant":
881892
n_nodes_nocst += 1
893+
if proto.domain in local_domains:
894+
key = "n_node_local_function"
895+
if key not in counts:
896+
counts[key] = 0
897+
counts[key] += 1
882898
else:
883899
key = f"n_node_initializer_{proto.data_type}"
884900

@@ -1400,7 +1416,7 @@ def call_torch_export_onnx(
14001416
:return: two dictionaries, one with some metrics,
14011417
another one with whatever the function produces
14021418
"""
1403-
available = {None, "", "ir", "os_ort"}
1419+
available = {None, "", "ir", "os_ort", "ir+default"}
14041420
assert (
14051421
optimization in available
14061422
), f"unexpected value for optimization={optimization}, available={available}"
@@ -1490,11 +1506,31 @@ def call_torch_export_onnx(
14901506
print(epo)
14911507
print("[call_torch_export_onnx] -- End of ONNXProgram")
14921508

1493-
if optimization in {"ir", "os_ort"}:
1509+
if optimization in {"ir", "os_ort", "ir+default"}:
14941510
if verbose:
14951511
print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
14961512
if optimization == "ir":
14971513
label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
1514+
elif optimization == "ir+default":
1515+
import onnxscript
1516+
from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
1517+
1518+
def _ir_default_opt(epo):
1519+
onnxscript.optimizer.optimize_ir(epo.model)
1520+
onx = epo.model_proto
1521+
# not very efficient: the model goes IR -> ModelProto -> GraphBuilder -> ONNX container -> IR
1522+
gr = GraphBuilder(
1523+
onx,
1524+
infer_shapes_options=True,
1525+
optimization_options=OptimizationOptions(patterns="default"),
1526+
)
1527+
cont = gr.to_onnx(large_model=True)
1528+
epo.model = cont.to_ir()
1529+
1530+
label, f_optim = "export_onnx_opt_ir_default", (
1531+
lambda epo=epo: _ir_default_opt(epo)
1532+
)
1533+
14981534
else:
14991535
import onnxscript
15001536
import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -1893,3 +1929,21 @@ def run_ort_fusion(
18931929
f"opt_ort_{model_type}_duration": duration,
18941930
f"opt_ort_{model_type}_duration_save": d,
18951931
}, {f"opt_ort_{model_type}": output_path}
1932+
1933+
1934+
def _compute_final_statistics(summary: Dict[str, Any]):
1935+
"""
1936+
Updates inline the list of statistics. It adds:
1937+
1938+
- speedup
1939+
"""
1940+
stats = {}
1941+
if (
1942+
"time_run_latency" in summary
1943+
and "time_run_onnx_ort_latency" in summary
1944+
and summary["time_run_onnx_ort_latency"] > 0
1945+
):
1946+
stats["stat_estimated_speedup_ort"] = (
1947+
summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
1948+
)
1949+
summary.update(stats)

0 commit comments

Comments
 (0)