Commit 390e03f

script to compare models

1 parent 0f9667b commit 390e03f

4 files changed: +203 -16 lines
_scripts/compare_model_execution.py

Lines changed: 131 additions & 0 deletions

"""
Compares two ONNX models.
"""

print("-- import onnx")
import onnx

print("-- import onnx.helper")
from onnx.helper import tensor_dtype_to_np_dtype

print("-- import onnxruntime")
import onnxruntime

print("-- import torch")
import torch

print("-- import transformers")
import transformers

print("-- import huggingface_hub")
import huggingface_hub

print("-- import onnx-diagnostic.helper")
from onnx_diagnostic.helpers.helper import flatten_object, string_type, max_diff, string_diff

print("-- import onnx-diagnostic.torch_models")
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

print("-- done")

model_id = "arnir0/Tiny-LLM"
onnx1 = (
    "dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op20/"
    "arnir0_Tiny-LLM-custom-default-f16-cuda-op20.onnx"
)
onnx2 = (
    "dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op21/"
    "arnir0_Tiny-LLM-custom-default-f16-cuda-op21.onnx"
)
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

print(f"-- load {onnx1!r}")
onx1 = onnx.load(onnx1)
print(f"-- load {onnx2!r}")
onx2 = onnx.load(onnx2)

print(f"-- getting inputs for model_id {model_id!r}")
data = get_untrained_model_with_inputs(model_id)
inputs = data["inputs"]
print(f"-- inputs: {string_type(inputs, with_shape=True)}")
flatten_inputs = flatten_object(inputs, drop_keys=True)
print(f"-- flat inputs: {string_type(flatten_inputs, with_shape=True)}")

names = [i.name for i in onx1.graph.input]
itypes = [i.type.tensor_type.elem_type for i in onx1.graph.input]
assert names == [
    i.name for i in onx2.graph.input
], f"Not the same names for both models {names} != {[i.name for i in onx2.graph.input]}"
feeds = {
    n: t.numpy().astype(tensor_dtype_to_np_dtype(itype))
    for n, itype, t in zip(names, itypes, flatten_inputs)
}
print(f"-- feeds: {string_type(feeds, with_shape=True)}")

print(f"-- creating session 1 from {onnx1!r}")
opts = onnxruntime.SessionOptions()
opts.optimized_model_filepath = "debug1_full.onnx"
sess1 = onnxruntime.InferenceSession(onnx1, opts, providers=providers)
print(f"-- creating session 2 from {onnx2!r}")
opts.optimized_model_filepath = "debug2_full.onnx"
sess2 = onnxruntime.InferenceSession(onnx2, opts, providers=providers)

# save the optimized models again with the weights stored as external data
for n in ["debug1_full.onnx", "debug2_full.onnx"]:
    x = onnx.load(n, load_external_data=False)
    onnx.save(x, n.replace(".onnx", "-ext.onnx"), save_as_external_data=True)

print("-- run session1")
expected1 = sess1.run(None, feeds)
print(f"-- got {string_type(expected1, with_shape=True)}")
print("-- run session2")
expected2 = sess2.run(None, feeds)
print(f"-- got {string_type(expected2, with_shape=True)}")

print("-- compute differences")
diff = max_diff(expected1, expected2)
print(f"-- diff={string_diff(diff)}")


def get_names(onx: onnx.ModelProto) -> list[tuple[str, str, str]]:
    """Returns (output_name, op_type, node_name) for every node output."""
    names = []
    for node in onx.graph.node:
        for o in node.output:
            names.append((o, node.op_type, node.name))
    return names


if diff["abs"] > 0.1:
    print("--")
    print("-- import select_model_inputs_outputs")
    from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs

    print("-- looking into intermediate results")
    names1 = get_names(onx1)
    names2 = get_names(onx2)
    common = [n for n in names1 if n in (set(names1) & set(names2))]
    print(f"-- {len(common)} names / {len(names1)}-{len(names2)}")
    print(f"-- first names {common[:5]}")
    for name, op_type, op_name in common:
        x1 = select_model_inputs_outputs(onx1, [name])
        x2 = select_model_inputs_outputs(onx2, [name])
        s1 = onnxruntime.InferenceSession(x1.SerializeToString(), providers=providers)
        s2 = onnxruntime.InferenceSession(x2.SerializeToString(), providers=providers)
        e1 = s1.run(None, feeds)
        e2 = s2.run(None, feeds)
        diff = max_diff(e1, e2)
        print(
            f"-- name={name!r}: diff={string_diff(diff)} "
            f"- op_type={op_type!r}, op_name={op_name!r}"
        )
        if diff["abs"] > 0.1:
            opts = onnxruntime.SessionOptions()
            opts.optimized_model_filepath = "debug1.onnx"
            onnxruntime.InferenceSession(x1.SerializeToString(), opts, providers=providers)
            opts.optimized_model_filepath = "debug2.onnx"
            onnxruntime.InferenceSession(x2.SerializeToString(), opts, providers=providers)
            print("--")
            print("-- break here")
            print(f"-- feeds {string_type(feeds, with_shape=True)}")
            print(f"-- e1={string_type(e1, with_shape=True, with_min_max=True)}")
            print(f"-- e2={string_type(e2, with_shape=True, with_min_max=True)}")
            break
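
The script is meant to be run directly once both exported models exist (a sketch: the script path comes from the pyproject.toml hunk below, and the two dump_test folders are assumed to have been produced by earlier export runs):

python _scripts/compare_model_execution.py

It compares the final outputs first and, only when the absolute difference exceeds 0.1, truncates both graphs at every common intermediate output to locate the first node where they diverge; the offending sub-models are then dumped as debug1.onnx and debug2.onnx for inspection.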

onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
@@ -474,7 +474,7 @@ def get_parser_validate() -> ArgumentParser:
     )
     parser.add_argument(
         "--runtime",
-        choices=["onnxruntime", "torch", "ref"],
+        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
         default="onnxruntime",
         help="onnx runtime to use, `onnxruntime` by default",
     )
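
For illustration, a hedged sketch of selecting one of the new runtimes from the command line (the validate subcommand follows from get_parser_validate and --runtime from the hunk above; every other flag is elided):

python -m onnx_diagnostic validate --runtime orteval ...

orteval10 selects the same OnnxruntimeEvaluator but pins its verbosity to 10, as the validate.py changes below show.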

onnx_diagnostic/torch_models/validate.py

Lines changed: 69 additions & 15 deletions
@@ -7,8 +7,6 @@
 import time
 import numpy as np
 import onnx
-import onnxscript
-import onnxscript.rewriter.ort_fusions as ort_fusions
 import torch
 from ..export import CoupleInputsDynamicShapes
 from ..helpers import max_diff, string_type, string_diff
@@ -249,6 +247,7 @@ def _quiet_or_not_quiet(
         summary[f"time_{suffix}_latency_std"] = a.std()
         summary[f"time_{suffix}_latency_min"] = a.min()
         summary[f"time_{suffix}_latency_max"] = a.max()
+        summary[f"time_{suffix}_n"] = len(a)
     return res

@@ -337,7 +336,8 @@ def validate_model(
     :param subfolder: version or subfolder to use when retrieving a model id
     :param opset: onnx opset to use for the conversion
     :param runtime: onnx runtime to use to check for discrepancies,
-        only if `do_run` is true
+        possible values ``onnxruntime``, ``torch``, ``orteval``,
+        ``orteval10``, ``ref``; only if `do_run` is true
     :param repeat: number of times to measure the model
     :param warmup: warmup the model first
     :param inputs2: checks that the second set of inputs is running as well,
@@ -364,7 +364,13 @@ def validate_model(

     The default runtime, :epkg:`onnxruntime`, is used to validate a model and check that
     the exported model returns the same outputs as the original one; otherwise,
-    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
+    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used
+    if ``runtime == 'torch'``,
+    :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
+    if ``runtime == 'orteval'``, or
+    :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
+    if ``runtime == 'ref'``;
+    ``orteval10`` increases the verbosity.
     """
     if isinstance(patch, bool):
         patch_kwargs = (
11551161
:param quiet: catch exception or not
11561162
:param verbose: verbosity
11571163
:param flavour: use a different version of the inputs
1158-
:param runtime: onnx runtime to use, onnxruntime or torch
1164+
:param runtime: onnx runtime to use, onnxruntime, torch, orteval, ref
11591165
:param repeat: run that number of times the model
11601166
:param warmup: warmup the model
11611167
:param inputs2: to validate the model on the second input set
@@ -1202,23 +1208,66 @@ def _mk(key, flavour=flavour):
@@ -1202,23 +1208,66 @@ def _mk(key, flavour=flavour):
             f"{providers}..., flavour={flavour!r}"
         )

-    if runtime != "onnxruntime":
+    if runtime == "onnxruntime":
+        if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"):
+            opts = onnxruntime.SessionOptions()
+            opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx"
+            if verbose:
+                print(
+                    f"[validate_onnx_model] saved optimized onnxruntime "
+                    f"in {opts.optimized_model_filepath!r}"
+                )
+            onnxruntime.InferenceSession(data["onnx_filename"], opts, providers=providers)
+            if verbose:
+                print("[validate_onnx_model] -- done")
+
+        if verbose:
+            print("[validate_onnx_model] runtime is onnxruntime")
+        cls_runtime = lambda model, providers: onnxruntime.InferenceSession(
+            (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
+            providers=providers,
+        )
+    elif runtime == "torch":
         from ..reference import TorchOnnxEvaluator

-        cls_runtime = (
-            (
-                lambda model, providers: onnxruntime.InferenceSession(
-                    (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
-                    providers=providers,
+        if verbose:
+            print("[validate_onnx_model] runtime is TorchOnnxEvaluator")
+        cls_runtime = (
+            lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
+                model, providers=providers, verbose=max(verbose - 1, 0)
             )
         )
-            if runtime == "onnxruntime"
-            else (
-                lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
+    elif runtime == "orteval":
+        from ..reference import OnnxruntimeEvaluator
+
+        if verbose:
+            print("[validate_onnx_model] runtime is OnnxruntimeEvaluator")
+        cls_runtime = (
+            lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
                 model, providers=providers, verbose=max(verbose - 1, 0)
             )
         )
-    )
+    elif runtime == "orteval10":
+        from ..reference import OnnxruntimeEvaluator
+
+        if verbose:
+            print("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)")
+        cls_runtime = (
+            lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
+                model, providers=providers, verbose=10
+            )
+        )
+    elif runtime == "ref":
+        from ..reference import ExtendedReferenceEvaluator
+
+        if verbose:
+            print("[validate_onnx_model] runtime is ExtendedReferenceEvaluator")
+        cls_runtime = lambda model, providers, _cls_=ExtendedReferenceEvaluator: _cls_(  # type: ignore[misc]
+            model, verbose=max(verbose - 1, 0)
+        )
+    else:
+        raise ValueError(f"Unexpected runtime={runtime!r}")
+
     sess = _quiet_or_not_quiet(
         quiet,
         _mk("create_onnx_ort"),
@@ -1399,6 +1448,8 @@ def call_torch_export_onnx(
     if optimization == "ir":
         label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
     else:
+        import onnxscript
+        import onnxscript.rewriter.ort_fusions as ort_fusions

         def _os_ort_optim(epo):
             onnxscript.optimizer.optimize_ir(epo.model)
@@ -1683,6 +1734,9 @@ def call_torch_export_custom(
         print("[call_torch_export_custom] done (export)")

     if os_ort:
+        import onnxscript
+        import onnxscript.rewriter.ort_fusions as ort_fusions
+
         if verbose:
             print("[call_torch_export_custom] conversion to IR...")
         begin = time.perf_counter()
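
Both hunks defer the onnxscript imports into the only branches that use them, so importing onnx_diagnostic.torch_models.validate no longer requires onnxscript unless an onnxscript-based optimization is requested. A minimal sketch of the pattern (the helper name is illustrative; epo stands for the exported program object and onnxscript.optimizer.optimize_ir is the call used in the hunk above):

def optimize(epo, optimization: str):
    if optimization == "ir":
        epo.optimize()  # mirrors the "ir" branch above, no extra dependency
        return
    import onnxscript  # deferred: evaluated only when this branch runs
    onnxscript.optimizer.optimize_ir(epo.model)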

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -123,6 +123,7 @@ select = [
 "_doc/examples/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_doc/notebooks/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_doc/recipes/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
+"_scripts/compare_model_execution.py" = ["E402", "F401"]
 "_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"]
 "onnx_diagnostic/export/__init__.py" = ["F401"]
@@ -131,6 +132,7 @@ select = [
 "onnx_diagnostic/reference/torch_ops/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_models/hghub/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py" = ["PIE804"]
+"onnx_diagnostic/torch_models/validate.py" = ["E731"]
 "onnx_diagnostic/torch_export_patches/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_export_patches/patches/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_models/llms.py" = ["F401"]
