Commit 15208c6

adds a script to compare models (#221)

* script to compare models
* more stats
* manual seed
* moves slow import to a location where it works better
* lint

1 parent 0f9667b commit 15208c6

File tree: 7 files changed (+230 / -24 lines)

.github/workflows/documentation.yml
Lines changed: 2 additions & 2 deletions

@@ -118,9 +118,9 @@ jobs:
           grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export'
           exit 1
         fi
-        if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string') ]]; then
+        if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache') ]]; then
           echo "Documentation produces warnings."
-          grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string'
+          grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache'
           exit 1
         fi
_scripts/compare_model_execution.py
Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
+"""
+Compares two ONNX models.
+"""
+
+print("-- import onnx")
+import onnx
+
+print("-- import onnx.helper")
+from onnx.helper import tensor_dtype_to_np_dtype
+
+print("-- import onnxruntime")
+import onnxruntime
+
+print("-- import torch")
+import torch
+
+print("-- import transformers")
+import transformers
+
+print("-- import huggingface_hub")
+import huggingface_hub
+
+print("-- import onnx-diagnostic.helper")
+from onnx_diagnostic.helpers.helper import flatten_object, string_type, max_diff, string_diff
+
+print("-- import onnx-diagnostic.torch_models.hghub")
+from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+print("-- done")
+
+model_id = "arnir0/Tiny-LLM"
+onnx1 = (
+    "dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op20/"
+    "arnir0_Tiny-LLM-custom-default-f16-cuda-op20.onnx"
+)
+onnx2 = (
+    "dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op21/"
+    "arnir0_Tiny-LLM-custom-default-f16-cuda-op21.onnx"
+)
+providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+
+print(f"-- load {onnx1!r}")
+onx1 = onnx.load(onnx1)
+print(f"-- load {onnx2!r}")
+onx2 = onnx.load(onnx2)
+
+print(f"-- getting inputs for model_id {model_id!r}")
+data = get_untrained_model_with_inputs(model_id)
+inputs = data["inputs"]
+print(f"-- inputs: {string_type(inputs, with_shape=True)}")
+flatten_inputs = flatten_object(inputs, drop_keys=True)
+print(f"-- flat inputs: {string_type(flatten_inputs, with_shape=True)}")
+
+names = [i.name for i in onx1.graph.input]
+itypes = [i.type.tensor_type.elem_type for i in onx1.graph.input]
+assert names == [
+    i.name for i in onx2.graph.input
+], f"Not the same names for both models {names} != {[i.name for i in onx2.graph.input]}"
+feeds = {
+    n: t.numpy().astype(tensor_dtype_to_np_dtype(itype))
+    for n, itype, t in zip(names, itypes, flatten_inputs)
+}
+print(f"-- feeds: {string_type(feeds, with_shape=True)}")
+
+print(f"-- creating session 1 from {onnx1!r}")
+opts = onnxruntime.SessionOptions()
+opts.optimized_model_filepath = "debug1_full.onnx"
+opts.log_severity_level = 0
+opts.log_verbosity_level = 0
+sess1 = onnxruntime.InferenceSession(onnx1, opts, providers=providers)
+print(f"-- creating session 2 from {onnx2!r}")
+opts.optimized_model_filepath = "debug2_full.onnx"
+opts.log_severity_level = 0
+opts.log_verbosity_level = 0
+sess2 = onnxruntime.InferenceSession(onnx2, opts, providers=providers)
+
+print("-- run session1")
+expected1 = sess1.run(None, feeds)
+print(f"-- got {string_type(expected1, with_shape=True)}")
+print("-- run session2")
+expected2 = sess2.run(None, feeds)
+print(f"-- got {string_type(expected2, with_shape=True)}")
+
+print("-- compute differences")
+diff = max_diff(expected1, expected2)
+print(f"-- diff={string_diff(diff)}")
+
+
+def get_names(onx: onnx.ModelProto) -> list[tuple[str, str, str]]:
+    names = []
+    for node in onx.graph.node:
+        for o in node.output:
+            names.append((o, node.op_type, node.name))
+    return names
+
+
+if diff["abs"] > 0.1:
+    print("--")
+    print("-- import select_model_inputs_outputs")
+    from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs
+
+    print("-- looking into intermediate results")
+    names1 = get_names(onx1)
+    names2 = get_names(onx2)
+    common = [n for n in names1 if n in (set(names1) & set(names2))]
+    print(f"-- {len(common)} names / {len(names1)}-{len(names2)}")
+    print(f"-- first names {common[:5]}")
+    for name, op_type, op_name in common:
+        x1 = select_model_inputs_outputs(onx1, [name])
+        x2 = select_model_inputs_outputs(onx2, [name])
+        s1 = onnxruntime.InferenceSession(x1.SerializeToString(), providers=providers)
+        s2 = onnxruntime.InferenceSession(x2.SerializeToString(), providers=providers)
+        e1 = s1.run(None, feeds)
+        e2 = s2.run(None, feeds)
+        diff = max_diff(e1, e2)
+        print(
+            f"-- name={name!r}: diff={string_diff(diff)} "
+            f"- op_type={op_type!r}, op_name={op_name!r}"
+        )
+        if diff["abs"] > 0.1:
+            opts = onnxruntime.SessionOptions()
+            opts.optimized_model_filepath = "debug1.onnx"
+            onnxruntime.InferenceSession(x1.SerializeToString(), opts, providers=providers)
+            opts.optimized_model_filepath = "debug2.onnx"
+            onnxruntime.InferenceSession(x2.SerializeToString(), opts, providers=providers)
+            print("--")
+            print("-- break here")
+            print(f"-- feeds {string_type(feeds, with_shape=True)}")
+            print(f"-- e1={string_type(e1, with_shape=True, with_min_max=True)}")
+            print(f"-- e2={string_type(e2, with_shape=True, with_min_max=True)}")
+            break

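For reference, the core of the script above reduces to running both models on identical feeds and measuring the largest discrepancy with max_diff. A minimal sketch, not part of the commit; the file paths and the feed name/shape below are placeholders:

    # hedged sketch: compare two ONNX files on the same inputs
    import numpy as np
    import onnxruntime
    from onnx_diagnostic.helpers.helper import max_diff, string_diff

    feeds = {"input_ids": np.zeros((1, 8), dtype=np.int64)}  # assumed input name/shape

    sess1 = onnxruntime.InferenceSession("model_a.onnx", providers=["CPUExecutionProvider"])
    sess2 = onnxruntime.InferenceSession("model_b.onnx", providers=["CPUExecutionProvider"])

    diff = max_diff(sess1.run(None, feeds), sess2.run(None, feeds))
    print(string_diff(diff))  # diff["abs"] holds the maximum absolute discrepancy

When that top-level difference is too large, the script then bisects the graph with select_model_inputs_outputs to locate the first intermediate result that diverges.
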
onnx_diagnostic/_command_lines_parser.py
Lines changed: 1 addition & 1 deletion

@@ -474,7 +474,7 @@ def get_parser_validate() -> ArgumentParser:
     )
     parser.add_argument(
         "--runtime",
-        choices=["onnxruntime", "torch", "ref"],
+        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
         default="onnxruntime",
         help="onnx runtime to use, `onnxruntime` by default",
     )

onnx_diagnostic/helpers/cache_helper.py
Lines changed: 8 additions & 6 deletions

@@ -4,11 +4,6 @@
 import transformers
 import transformers.cache_utils
 
-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-
 
 class CacheKeyValue:
     """

@@ -354,8 +349,15 @@ def make_encoder_decoder_cache(
     )
 
 
-def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+def make_mamba_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> "MambaCache":  # noqa: F821
     "Creates a ``MambaCache``."
+    # import is moved here because this part is slow.
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
     dtype = key_value_pairs[0][0].dtype
 
     class _config:

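The change above defers the MambaCache import into make_mamba_cache because loading transformers.models.mamba is slow; the module-level import is what made importing cache_helper expensive. The pattern in isolation, as a hedged sketch (the helper name below is illustrative, not library API):

    def _mamba_cache_class():
        # deferred import: keeps `import onnx_diagnostic.helpers.cache_helper` cheap,
        # the cost is paid only on the first call that actually needs MambaCache
        try:
            from transformers.models.mamba.modeling_mamba import MambaCache
        except ImportError:
            # fallback for transformers versions exposing it in cache_utils
            from transformers.cache_utils import MambaCache
        return MambaCache

Since MambaCache is no longer a module-level name, the return annotation becomes a string and mypy's name-defined check is disabled for this module in pyproject.toml (see below).
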
onnx_diagnostic/torch_models/hghub/model_inputs.py
Lines changed: 4 additions & 0 deletions

@@ -228,6 +228,8 @@ def get_untrained_model_with_inputs(
             f"and use_pretrained=True."
         )
 
+    seed = int(os.environ.get("SEED", "17"))
+    torch.manual_seed(seed)
     try:
         if type(config) is dict:
             model = cls_model(**config)

@@ -239,6 +241,8 @@
         ) from e
 
     # input kwargs
+    seed = int(os.environ.get("SEED", "17")) + 1
+    torch.manual_seed(seed)
     kwargs, fct = random_input_kwargs(config, task)  # type: ignore[arg-type]
     if verbose:
         print(f"[get_untrained_model_with_inputs] use fct={fct}")

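The two torch.manual_seed calls make both the untrained weights and the generated inputs reproducible: the seed comes from the SEED environment variable (default "17"), and the input generation reseeds with SEED + 1. A small sketch of the intended effect (the model id is the one used in the comparison script):

    import os
    os.environ["SEED"] = "17"  # must be set before the call; any integer string works

    from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

    data_a = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
    data_b = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
    # with the same SEED, both calls should produce identical random weights
    # and identical data_a["inputs"] / data_b["inputs"]

This is what allows the comparison script to rebuild the same feeds for two ONNX files exported in separate runs.
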
onnx_diagnostic/torch_models/validate.py
Lines changed: 78 additions & 15 deletions

@@ -7,8 +7,6 @@
 import time
 import numpy as np
 import onnx
-import onnxscript
-import onnxscript.rewriter.ort_fusions as ort_fusions
 import torch
 from ..export import CoupleInputsDynamicShapes
 from ..helpers import max_diff, string_type, string_diff

@@ -249,6 +247,7 @@ def _quiet_or_not_quiet(
         summary[f"time_{suffix}_latency_std"] = a.std()
         summary[f"time_{suffix}_latency_min"] = a.min()
         summary[f"time_{suffix}_latency_min"] = a.max()
+        summary[f"time_{suffix}_n"] = len(a)
     return res
 
 

@@ -337,7 +336,8 @@ def validate_model(
     :param subfolder: version or subfolders to uses when retrieving a model id
     :param opset: onnx opset to use for the conversion
     :param runtime: onnx runtime to use to check about discrepancies,
-        only if `do_run` is true
+        possible values ``onnxruntime``, ``torch``, ``orteval``,
+        ``orteval10``, ``ref`` only if `do_run` is true
     :param repeat: number of time to measure the model
     :param warmup: warmup the model first
     :param inputs2: checks that the second set of inputs is reunning as well,

@@ -364,7 +364,13 @@
 
     The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
     exported model returns the same outputs as the original one, otherwise,
-    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
+    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
+    if ``runtime == 'torch'`` or
+    :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
+    if ``runtime == 'orteval'`` or
+    :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
+    if ``runtime == 'ref'``,
+    ``orteval10`` increases the verbosity.
     """
     if isinstance(patch, bool):
         patch_kwargs = (

@@ -846,15 +852,24 @@ def node_iter(proto):
         raise NotImplementedError(f"Unexpected type={type(proto)}")
 
     counts: Dict[str, Union[float, int]] = {}
+    n_nodes = 0
+    n_nodes_nocst = 0
     for proto in node_iter(onx):
         if isinstance(proto, onnx.NodeProto):
             key = f"n_node_{proto.op_type}"
+            n_nodes += 1
+            if proto.op_type != "Constant":
+                n_nodes_nocst += 1
         else:
             key = f"n_node_initializer_{proto.data_type}"
 
         if key not in counts:
             counts[key] = 0
         counts[key] += 1
+
+    counts["n_node_nodes"] = n_nodes
+    counts["n_node_nodes_nocst"] = n_nodes_nocst
+    counts["n_node_functions"] = len(onx.functions)
     return counts
 
 

@@ -1155,7 +1170,7 @@ def validate_onnx_model(
     :param quiet: catch exception or not
     :param verbose: verbosity
     :param flavour: use a different version of the inputs
-    :param runtime: onnx runtime to use, onnxruntime or torch
+    :param runtime: onnx runtime to use, onnxruntime, torch, orteval, ref
     :param repeat: run that number of times the model
     :param warmup: warmup the model
     :param inputs2: to validate the model on the second input set

@@ -1202,23 +1217,66 @@ def _mk(key, flavour=flavour):
            f"{providers}..., flavour={flavour!r}"
        )
 
-    if runtime != "onnxruntime":
+    if runtime == "onnxruntime":
+        if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"):
+            opts = onnxruntime.SessionOptions()
+            opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx"
+            if verbose:
+                print(
+                    f"[validate_onnx_model] saved optimized onnxruntime "
+                    f"in {opts.optimized_model_filepath!r}"
+                )
+            onnxruntime.InferenceSession(data["onnx_filename"], opts, providers=providers)
+            if verbose:
+                print("[validate_onnx_model] -- done")
+
+        if verbose:
+            print("[validate_onnx_model] runtime is onnxruntime")
+        cls_runtime = lambda model, providers: onnxruntime.InferenceSession(
+            (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
+            providers=providers,
+        )
+    elif runtime == "torch":
         from ..reference import TorchOnnxEvaluator
 
-    cls_runtime = (
-        (
-            lambda model, providers: onnxruntime.InferenceSession(
-                (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
-                providers=providers,
+        if verbose:
+            print("[validate_onnx_model] runtime is TorchOnnxEvaluator")
+        cls_runtime = (
+            lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
+                model, providers=providers, verbose=max(verbose - 1, 0)
             )
         )
-        if runtime == "onnxruntime"
-        else (
-            lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
+    elif runtime == "orteval":
+        from ..reference import OnnxruntimeEvaluator
+
+        if verbose:
+            print("[validate_onnx_model] runtime is OnnxruntimeEvaluator")
+        cls_runtime = (
+            lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
                 model, providers=providers, verbose=max(verbose - 1, 0)
             )
         )
-    )
+    elif runtime == "orteval10":
+        from ..reference import OnnxruntimeEvaluator
+
+        if verbose:
+            print("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)")
+        cls_runtime = (
+            lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
+                model, providers=providers, verbose=10
+            )
+        )
+    elif runtime == "ref":
+        from ..reference import ExtendedReferenceEvaluator
+
+        if verbose:
+            print("[validate_onnx_model] runtime is ExtendedReferenceEvaluator")
+        cls_runtime = lambda model, providers, _cls_=ExtendedReferenceEvaluator: _cls_(  # type: ignore[misc]
+            model, verbose=max(verbose - 1, 0)
+        )
+    else:
+        raise ValueError(f"Unexpected runtime={runtime!r}")
+
     sess = _quiet_or_not_quiet(
         quiet,
         _mk("create_onnx_ort"),

@@ -1399,6 +1457,8 @@ def call_torch_export_onnx(
     if optimization == "ir":
         label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
     else:
+        import onnxscript
+        import onnxscript.rewriter.ort_fusions as ort_fusions
 
         def _os_ort_optim(epo):
             onnxscript.optimizer.optimize_ir(epo.model)

@@ -1683,6 +1743,9 @@ def call_torch_export_custom(
         print("[call_torch_export_custom] done (export)")
 
     if os_ort:
+        import onnxscript
+        import onnxscript.rewriter.ort_fusions as ort_fusions
+
         if verbose:
             print("[call_torch_export_custom] conversion to IR...")
         begin = time.perf_counter()

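The new `runtime` handling in validate_onnx_model amounts to dispatching on the runtime name to pick a session factory. Condensed as a standalone sketch (the wrapper function below is illustrative; the evaluator classes and their arguments follow the diff above):

    def make_runtime(runtime, model, providers, verbose=0):
        # "model" may be an onnx.ModelProto or a path, as in the diff above
        if runtime == "onnxruntime":
            import onnxruntime

            return onnxruntime.InferenceSession(
                model.SerializeToString() if hasattr(model, "SerializeToString") else model,
                providers=providers,
            )
        if runtime == "torch":
            from onnx_diagnostic.reference import TorchOnnxEvaluator

            return TorchOnnxEvaluator(model, providers=providers, verbose=max(verbose - 1, 0))
        if runtime in ("orteval", "orteval10"):
            from onnx_diagnostic.reference import OnnxruntimeEvaluator

            # orteval10 forces verbose=10, otherwise the caller's verbosity minus one
            v = 10 if runtime == "orteval10" else max(verbose - 1, 0)
            return OnnxruntimeEvaluator(model, providers=providers, verbose=v)
        if runtime == "ref":
            from onnx_diagnostic.reference import ExtendedReferenceEvaluator

            return ExtendedReferenceEvaluator(model, verbose=max(verbose - 1, 0))
        raise ValueError(f"Unexpected runtime={runtime!r}")

Setting the DUMPORTOPT environment variable to 1 additionally makes the onnxruntime branch write the runtime-optimized graph next to the exported model, which is useful when chasing discrepancies with the comparison script.
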
pyproject.toml
Lines changed: 6 additions & 0 deletions

@@ -28,6 +28,10 @@ disable_error_code = ["arg-type", "assignment", "import-untyped", "misc", "name-
 module = ["onnx_diagnostic.helpers.args_helper"]
 disable_error_code = ["arg-type", "call-overload", "index"]
 
+[[tool.mypy.overrides]]
+module = ["onnx_diagnostic.helpers.cache_helper"]
+disable_error_code = ["name-defined"]
+
 [[tool.mypy.overrides]]
 module = ["onnx_diagnostic.helpers.helper"]
 disable_error_code = ["arg-type", "assignment", "attr-defined", "call-overload", "misc", "name-defined", "union-attr"]

@@ -123,6 +127,7 @@ select = [
 "_doc/examples/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_doc/notebooks/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_doc/recipes/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
+"_scripts/compare_model_execution.py" = ["E402", "F401"]
 "_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"]
 "onnx_diagnostic/export/__init__.py" = ["F401"]

@@ -131,6 +136,7 @@ select = [
 "onnx_diagnostic/reference/torch_ops/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_models/hghub/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py" = ["PIE804"]
+"onnx_diagnostic/torch_models/validate.py" = ["E731"]
 "onnx_diagnostic/torch_export_patches/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_export_patches/patches/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_models/llms.py" = ["F401"]
