Skip to content

Commit f86c55d

Browse files
authored
Adds speed-up to command-line validate, avoids subfolder being None (#227)
* adds speed up * fix none value * changes * fix mypy
1 parent cca5703 commit f86c55d

File tree

3 files changed

+73
-10
lines changed

3 files changed

+73
-10
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ Change Logs
44
0.7.12
55
++++++
66

7+
* :pr:`227`: better support for ``model_id//pretrained``, adds a speed-up computation when running the validate command
78
* :pr:`226`: fix input order for models created with modelbuilder
89

9-
1010
0.7.11
1111
++++++
1212

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def get_untrained_model_with_inputs(
189189
f"subfolder={subfolder!r}"
190190
)
191191
model = transformers.AutoModel.from_pretrained(
192-
model_id, subfolder=subfolder, trust_remote_code=True, **mkwargs
192+
model_id, subfolder=subfolder or "", trust_remote_code=True, **mkwargs
193193
)
194194
if verbose:
195195
print(

onnx_diagnostic/torch_models/validate.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,18 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
264264
return new_cfg
265265

266266

267-
def _preprocess_model_id(model_id, subfolder):
267+
def _preprocess_model_id(
268+
model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
269+
) -> Tuple[str, Optional[str], bool, bool]:
268270
if subfolder or "//" not in model_id:
269-
return model_id, subfolder
271+
return model_id, subfolder, same_as_pretrained, use_pretrained
270272
spl = model_id.split("//")
273+
if spl[-1] == "pretrained":
274+
return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
271275
if spl[-1] in {"transformer", "vae"}:
272276
# known subfolder
273-
return "//".join(spl[:-1]), spl[-1]
274-
return model_id, subfolder
277+
return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
278+
return model_id, subfolder, same_as_pretrained, use_pretrained
275279

276280

277281
def validate_model(
@@ -384,7 +388,12 @@ def validate_model(
384388
if ``runtime == 'ref'``,
385389
``orteval10`` increases the verbosity.
386390
"""
387-
model_id, subfolder = _preprocess_model_id(model_id, subfolder)
391+
model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
392+
model_id,
393+
subfolder,
394+
same_as_pretrained=same_as_pretrained,
395+
use_pretrained=use_pretrained,
396+
)
388397
if isinstance(patch, bool):
389398
patch_kwargs = (
390399
dict(patch_transformers=True, patch_diffusers=True, patch=True)
@@ -841,6 +850,8 @@ def validate_model(
841850
)
842851
summary.update(summary_valid)
843852

853+
_compute_final_statistics(summary)
854+
844855
if verbose:
845856
print("[validate_model] -- done (final)")
846857
if dump_stats:
@@ -853,15 +864,24 @@ def validate_model(
853864
def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
854865
"""Computes some statistics on the model itself."""
855866
onx = onnx.load(onnx_filename, load_external_data=False)
867+
cache_functions = {(f.domain, f.name): f for f in onx.functions}
868+
local_domains = set(f.domain for f in onx.functions)
856869

857870
def node_iter(proto):
858871
if isinstance(proto, onnx.ModelProto):
859-
yield from node_iter(proto.graph)
860872
for f in proto.functions:
861873
yield from node_iter(f)
874+
yield from node_iter(proto.graph)
862875
elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
863876
for node in proto.node:
864877
yield node
878+
879+
# Let's inline the function
880+
key = node.domain, node.op_type
881+
if key in cache_functions:
882+
yield from node_iter(cache_functions[key])
883+
884+
# Let's continue
865885
for att in node.attribute:
866886
if att.type == onnx.AttributeProto.GRAPH:
867887
yield from node_iter(att.g)
@@ -879,6 +899,11 @@ def node_iter(proto):
879899
n_nodes += 1
880900
if proto.op_type != "Constant":
881901
n_nodes_nocst += 1
902+
if proto.domain in local_domains:
903+
key = "n_node_local_function"
904+
if key not in counts:
905+
counts[key] = 0
906+
counts[key] += 1
882907
else:
883908
key = f"n_node_initializer_{proto.data_type}"
884909

@@ -1400,7 +1425,7 @@ def call_torch_export_onnx(
14001425
:return: two dictionaries, one with some metrics,
14011426
another one with whatever the function produces
14021427
"""
1403-
available = {None, "", "ir", "os_ort"}
1428+
available = {None, "", "ir", "os_ort", "ir+default"}
14041429
assert (
14051430
optimization in available
14061431
), f"unexpected value for optimization={optimization}, available={available}"
@@ -1490,11 +1515,31 @@ def call_torch_export_onnx(
14901515
print(epo)
14911516
print("[call_torch_export_onnx] -- End of ONNXProgram")
14921517

1493-
if optimization in {"ir", "os_ort"}:
1518+
if optimization in {"ir", "os_ort", "ir+default"}:
14941519
if verbose:
14951520
print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
14961521
if optimization == "ir":
14971522
label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
1523+
elif optimization == "ir+default":
1524+
import onnxscript
1525+
from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
1526+
1527+
def _ir_default_opt(epo):
1528+
onnxscript.optimizer.optimize_ir(epo.model)
1529+
onx = epo.model_proto
1530+
# not very efficient
1531+
gr = GraphBuilder(
1532+
onx,
1533+
infer_shapes_options=True,
1534+
optimization_options=OptimizationOptions(patterns="default"),
1535+
)
1536+
cont = gr.to_onnx(large_model=True)
1537+
epo.model = cont.to_ir()
1538+
1539+
label, f_optim = "export_onnx_opt_ir_default", (
1540+
lambda epo=epo: _ir_default_opt(epo)
1541+
)
1542+
14981543
else:
14991544
import onnxscript
15001545
import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -1893,3 +1938,21 @@ def run_ort_fusion(
18931938
f"opt_ort_{model_type}_duration": duration,
18941939
f"opt_ort_{model_type}_duration_save": d,
18951940
}, {f"opt_ort_{model_type}": output_path}
1941+
1942+
1943+
def _compute_final_statistics(summary: Dict[str, Any]):
1944+
"""
1945+
Updates inline the list of statistics. It adds:
1946+
1947+
- speedup
1948+
"""
1949+
stats = {}
1950+
if (
1951+
"time_run_latency" in summary
1952+
and "time_run_onnx_ort_latency" in summary
1953+
and summary["time_run_onnx_ort_latency"] > 0
1954+
):
1955+
stats["stat_estimated_speedup_ort"] = (
1956+
summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
1957+
)
1958+
summary.update(stats)

0 commit comments

Comments
 (0)