
Commit 3ff0c54

noinline for command line validate (#139)

* noinline
* ut
* better graphs
* example
* add missing BartModel

1 parent 474329e commit 3ff0c54

File tree

9 files changed: +162 -61 lines changed

_doc/technical/plot_layer_norm_discrepancies.py

Lines changed: 53 additions & 14 deletions

@@ -6,21 +6,26 @@
 :ref:`l-plot-parallelized-reduction`, reduction operations
 are sensitive to parallelization.

-We consider a small model including a layer normalization
-followed by a matrix multiplication and we show that replacing
-a kernel by another one may significantly impact the output.
+Methodology
++++++++++++
+
+We consider a simple model with a LayerNormalization followed by a MatMul.
+Each operator can be run with :epkg:`onnxruntime` or :epkg:`pytorch`.
+We compare the four combinations.

 The model
 +++++++++

 """

 import itertools
+import numpy as np
 import pandas
 import onnx
 import onnx.helper as oh
 import onnxruntime
 import torch
 from onnx_array_api.plotting.graphviz_helper import plot_dot
+from onnx_diagnostic.doc import rotate_align, save_fig, plot_histogram, title
 from onnx_diagnostic.ext_test_case import unit_test_going
 from onnx_diagnostic.helpers import max_diff, string_diff, string_type
 from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name, onnx_dtype_to_np_dtype
@@ -79,6 +84,8 @@ def make_feeds(last_dim: int):


 def cast_feeds(itype, provider, feeds):
+    ttype = onnx_dtype_to_torch_dtype(itype)
+    np_dtype = onnx_dtype_to_np_dtype(itype)
     np_feeds = {k: v.detach().numpy() for k, v in feeds.items()}
     if provider == "CUDA":
         if not torch.cuda.is_available():
@@ -101,8 +108,6 @@ def cast_feeds(itype, provider, feeds):
 baseline = {}

 for provider, itype in itertools.product(["CPU", "CUDA"], [TFLOAT, TFLOAT16]):
-    ttype = onnx_dtype_to_torch_dtype(itype)
-    np_dtype = onnx_dtype_to_np_dtype(itype)
     tch_feeds, ort_feeds = cast_feeds(itype, provider, feeds)
     if tch_feeds is None:
         continue
@@ -143,13 +148,34 @@ def cast_feeds(itype, provider, feeds):
 # %%
 # Visually.

-df["abs"].plot.bar(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
+save_fig(
+    rotate_align(
+        df[["abs"]].plot.bar(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
+    ),
+    "plot_layer_norm_discrepancies_1.png",
+)

 # %%
 # The discrepancies are significant on CUDA, higher for float16.
 # Let's see which operator is responsible for them,
 # *LayerNormalization* or *MatMul*.

+# %%
+# Distribution of the results
+# +++++++++++++++++++++++++++
+
+tensor = baseline[TFLOAT16, "CPU", "ort"][0].ravel().astype(np.float32)
+print(pandas.DataFrame({"expected": tensor}).describe())
+
+# %%
+# Histogram.
+
+save_fig(
+    title(plot_histogram(tensor), "Distribution of the computed results"),
+    "plot_layer_norm_discrepancies_hist.png",
+)
+
+
 # %%
 # The discrepancies come from?
 # ++++++++++++++++++++++++++++
@@ -159,19 +185,18 @@ def cast_feeds(itype, provider, feeds):
 data = []

 for mod, provider, itype in itertools.product(
-    ["ORT-TORCH", "TORCH-ORT"], ["CPU", "CUDA"], [TFLOAT, TFLOAT16]
+    ["ORT-ORT", "ORT-TORCH", "TORCH-ORT", "TORCH-TORCH"], ["CPU", "CUDA"], [TFLOAT, TFLOAT16]
 ):
     ttype = onnx_dtype_to_torch_dtype(itype)
     np_dtype = onnx_dtype_to_np_dtype(itype)
     tch_feeds, _ = cast_feeds(itype, provider, feeds)
     if tch_feeds is None:
         continue

+    ker1, ker2 = mod.split("-")
     custom_kernels = (
-        {("", "LayerNormalization"): LayerNormalizationOrt}
-        if mod == "ORT-TORCH"
-        else {("", "MatMul"): MatMulOrt}
-    )
+        {("", "LayerNormalization"): LayerNormalizationOrt} if ker1 == "ORT" else {}
+    ) | ({("", "MatMul"): MatMulOrt} if ker2 == "ORT" else {})

     model = get_model(itype)
     print()
@@ -200,13 +225,27 @@ def cast_feeds(itype, provider, feeds):
 )

 # %%
-df = pandas.DataFrame(data).set_index(["model", "provider", "dtype"])
+df = pandas.DataFrame(data).set_index(["dtype", "provider", "model"])
 df = df.sort_index()
 print(df)

 # %%
 # Visually.

-df[["diff_ort", "diff_torch"]].plot.bar(
-    title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B"
+save_fig(
+    rotate_align(
+        df[["diff_ort", "diff_torch"]].plot.bar(
+            title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B",
+            figsize=(10, 4),
+        )
+    ),
+    "plot_layer_norm_discrepancies_2.png",
 )
+
+# %%
+# Conclusion
+# ++++++++++
+#
+# :epkg:`torch` seems able to replicate the same results if the same computation
+# is run multiple times. :epkg:`onnxruntime` is only able to do that on CUDA.
+# With float16 and CUDA, LayerNormalization seems to introduce some discrepancies.
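A minimal sketch of the measurement this example automates, assuming random inputs and illustrative shapes (not the ones used in the script): compute LayerNorm(X) @ W + B in float16 and compare it against a float64 reference. The helper name layer_norm_matmul is made up.

import torch

X, W, B = torch.randn(2, 1024, 1536), torch.randn(1536, 1536), torch.randn(1536)


def layer_norm_matmul(x, w, b, dtype):
    # Cast to the target dtype, then compute LayerNorm(X) @ W + B.
    x, w, b = x.to(dtype), w.to(dtype), b.to(dtype)
    return torch.nn.functional.layer_norm(x, x.shape[-1:]) @ w + b


ref = layer_norm_matmul(X, W, B, torch.float64)
# float16 on CPU needs a recent torch; move the tensors to CUDA otherwise.
got = layer_norm_matmul(X, W, B, torch.float16).to(torch.float64)
print("max absolute difference:", (ref - got).abs().max().item())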

_doc/technical/plot_parallelized_reduction.py

Lines changed: 7 additions & 0 deletions

@@ -23,6 +23,12 @@

 With :math:`\\mathbb{E}X = mean(X)`,
 :math:`\\mathbb{V}X = mean\\left(\\left(X - mean(X)\\right)^2\\right)`.
+
+Methodology
++++++++++++
+
+**Permutation should not change the average.**
+
 We draw 128 random permutations of X. The average or mean should not change.
 And the normalized vector should have the same values. In the first case, we compute
 the difference between the highest and the lowest values obtained for the average.
@@ -188,6 +194,7 @@ def make_value(base, value):
 # Visually.

 ax = df.plot.bar(logy=True)
+ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
 fig = ax.get_figure()
 fig.savefig("plot_parallelized_reduction.png")
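The methodology rests on a fact worth restating: floating-point addition is not associative, so permuting X changes the rounding order and thus the computed mean. A minimal numpy sketch, assuming float32 data (the sizes are illustrative):

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal(2**20).astype(np.float32)

# The exact mean is permutation-invariant; the computed one is not.
means = [float(X[rng.permutation(X.size)].mean()) for _ in range(128)]
print("spread of the mean over 128 permutations:", max(means) - min(means))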

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting
+
+
+class TestPatchRewrite(ExtTestCase):
+    def test_code_needing_rewriting(self):
+        res = code_needing_rewriting("BartModel")
+        self.assertEqual(len(res), 2)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 1 addition & 1 deletion

@@ -176,7 +176,7 @@ def test_validate_model_custom_torch(self):
             mid,
             do_run=True,
             verbose=10,
-            exporter="custom-inline",
+            exporter="custom-noinline",
             dump_folder="dump_test_validate_model_custom_torch",
             patch=True,
             stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,

k.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

onnx_diagnostic/doc.py

Lines changed: 46 additions & 0 deletions

@@ -1,3 +1,7 @@
+from typing import Optional
+import numpy as np
+
+
 def reset_torch_transformers(gallery_conf, fname):
     "Resets torch dynamo for :epkg:`sphinx-gallery`."
     import matplotlib.pyplot as plt
@@ -30,3 +34,45 @@ def plot_legend(
     ax.grid(False)
     ax.set_axis_off()
     return ax
+
+
+def rotate_align(ax, angle=15, align="right"):
+    """Rotates the x-labels and aligns them to the right. Returns ax."""
+    for label in ax.get_xticklabels():
+        label.set_rotation(angle)
+        label.set_horizontalalignment(align)
+    return ax
+
+
+def save_fig(ax, name: str):
+    """Applies ``tight_layout`` and saves the figure. Returns ax."""
+    import matplotlib.pyplot as plt
+
+    plt.tight_layout()
+    fig = ax.get_figure()
+    fig.savefig(name)
+    return ax
+
+
+def title(ax: "plt.axes", title: str) -> "plt.axes":  # noqa: F821
+    "Adds a title to axes and returns them."
+    ax.set_title(title)
+    return ax
+
+
+def plot_histogram(
+    tensor: np.ndarray,
+    ax: Optional["plt.axes"] = None,  # noqa: F821
+    bins: int = 30,
+    color: str = "orange",
+    alpha: float = 0.7,
+) -> "plt.axes":  # noqa: F821
+    "Plots the distribution of a tensor."
+    if ax is None:
+        import matplotlib.pyplot as plt
+
+        ax = plt.gca()
+    ax.cla()
+    ax.hist(tensor, bins=bins, color=color, alpha=alpha)
+    ax.set_yscale("log")
+    return ax
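The new helpers are designed to chain, which is how the example script above uses them; a hypothetical standalone call (the data and the file name are made up):

import numpy as np
from onnx_diagnostic.doc import plot_histogram, save_fig, title

tensor = np.random.rand(10_000).astype(np.float32)  # made-up data
# plot_histogram returns the axes; title and save_fig pass them through.
save_fig(title(plot_histogram(tensor), "Distribution"), "histogram.png")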

onnx_diagnostic/helpers/doc_helper.py

Lines changed: 25 additions & 5 deletions

@@ -1,19 +1,31 @@
-from typing import Dict, Optional, Tuple
+import os
+from typing import Dict, List, Optional, Tuple
 import onnx
 import onnx.helper as oh
 import torch
 from ..reference.torch_ops import OpRunKernel, OpRunTensor
 from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from .ort_session import InferenceSessionForTorch

+_SAVED: List[str] = []
+_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))
+
+
+def _get_model_name(op_name: str, provider: str) -> Optional[str]:
+    if _SAVE_OPTIMIZED_MODEL_:
+        name = f"dump_doc_layer_norm_{provider}_{len(_SAVED)}.onnx"
+        _SAVED.append(name)
+        return name
+    return None
+

 class LayerNormalizationOrt(OpRunKernel):
     "LayerNormalization with onnxruntime"

     @classmethod
     def device_dependent(cls) -> bool:
         "Needs device."
-        return False
+        return True

     def __init__(
         self,
@@ -70,7 +82,11 @@ def _make_model(self, itype: int, rank: int, has_bias: bool) -> onnx.ModelProto:
         )
         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
         self._provider = provider
-        return InferenceSessionForTorch(layer_model, providers=[provider])
+        return InferenceSessionForTorch(
+            layer_model,
+            optimized_model_filepath=_get_model_name("layer_norm", provider),
+            providers=[provider],
+        )

     def run(self, x, scale, bias=None):
         itype = torch_dtype_to_onnx_dtype(x.dtype)
@@ -94,7 +110,7 @@ class MatMulOrt(OpRunKernel):
     @classmethod
     def device_dependent(cls) -> bool:
         "Needs device."
-        return False
+        return True

     def __init__(
         self,
@@ -127,7 +143,11 @@ def _make_model(self, itype: int, ranka: int, rankb: int) -> onnx.ModelProto:
         )
         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
         self._provider = provider
-        return InferenceSessionForTorch(model, providers=[provider])
+        return InferenceSessionForTorch(
+            model,
+            optimized_model_filepath=_get_model_name("matmul", provider),
+            providers=[provider],
+        )

     def run(self, a, b):
         itype = torch_dtype_to_onnx_dtype(a.dtype)
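The dump is driven by the DUMP_ONNX environment variable, which this diff reads once at import time, so it must be set before the module is imported; a sketch of the assumed usage:

import os

os.environ["DUMP_ONNX"] = "1"  # read once when doc_helper is imported

from onnx_diagnostic.helpers.doc_helper import LayerNormalizationOrt  # noqa: E402,F401

# From here on, every onnxruntime session built by these kernels also writes
# its optimized model to dump_doc_layer_norm_<provider>_<index>.onnx.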

onnx_diagnostic/torch_export_patches/patch_module_helper.py

Lines changed: 1 addition & 0 deletions

@@ -80,6 +80,7 @@ def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
         "AutoformerModel": "AutoformerEncoderLayer",
         "BartEncoderLayer": "BartEncoderLayer",
         "BartForConditionalGeneration": "BartEncoderLayer",
+        "BartModel": "BartEncoderLayer",
         "BigBirdPegasusForConditionalGeneration": "BigBirdPegasusEncoderLayer",
         "BigBirdPegasusForQuestionAnswering": "BigBirdPegasusEncoderLayer",
         "BigBirdPegasusForCausalLM": "BigBirdPegasusEncoderLayer",

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 16 additions & 11 deletions

@@ -387,6 +387,12 @@ def validate_model(
         if model_options:
             print(f"[validate_model] model_options={model_options!r}")
         print(f"[validate_model] get dummy inputs with input_options={input_options}...")
+        print(
+            f"[validate_model] rewrite={rewrite}, patch={patch}, "
+            f"stop_if_static={stop_if_static}"
+        )
+        print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
+        print(f"[validate_model] dump_folder={dump_folder!r}")
     summary["model_id"] = model_id
     summary["model_subfolder"] = subfolder or ""

@@ -446,6 +452,8 @@ def validate_model(
         print(f"[validate_model] model_rewrite={summary['model_rewrite']}")
     else:
         del data["rewrite"]
+        if verbose:
+            print("[validate_model] no rewrite")
     if os.environ.get("PRINT_CONFIG", "0") in (1, "1"):
         print("[validate_model] -- PRINT CONFIG")
         print("-- type(config)", type(data["configuration"]))
@@ -1334,13 +1342,13 @@ def call_torch_export_custom(
         "custom-nostrict",
         "custom-nostrict-default",
         "custom-nostrict-all",
-        "custom-inline",
-        "custom-strict-inline",
-        "custom-strict-default-inline",
-        "custom-strict-all-inline",
-        "custom-nostrict-inline",
-        "custom-nostrict-default-inline",
-        "custom-nostrict-all-inline",
+        "custom-noinline",
+        "custom-strict-noinline",
+        "custom-strict-default-noinline",
+        "custom-strict-all-noinline",
+        "custom-nostrict-noinline",
+        "custom-nostrict-default-noinline",
+        "custom-nostrict-all-noinline",
     }
     assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
     assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -1381,10 +1389,7 @@ def call_torch_export_custom(
         ),
         save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
     )
-    inline = "-inline" in exporter
-    if inline:
-        export_options.aten_as_function = set()
-
+    inline = "-noinline" not in exporter
     options = OptimizationOptions(patterns=optimization) if optimization else None
     model = data["model"]
     kws = dict(
