Skip to content

Commit f1cdf11

Browse files
committed
add one example
1 parent 97cc48c commit f1cdf11

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
LayerNormalization implementation cannot be exchanged
3+
=====================================================
4+
5+
This example applies what was illustrated in
6+
:ref:`l-plot-parallelized-reduction`: reduction operations
7+
are sensitive to parallelization.
8+
9+
We consider a small model including a layer normalization
10+
followed by a matrix multiplication and we show that replacing
11+
a kernel by another one may significantly impact the output.
12+
13+
The model
14+
+++++++++
15+
"""
16+
17+
import pandas
18+
import onnx
19+
import onnx.helper as oh
20+
import onnxruntime
21+
import torch
22+
from onnx_array_api.plotting.graphviz_helper import plot_dot
23+
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
24+
from onnx_diagnostic.reference import TorchOnnxEvaluator
25+
26+
TFLOAT16 = onnx.TensorProto.FLOAT16
27+
28+
model = oh.make_model(
29+
oh.make_graph(
30+
[
31+
oh.make_node("LayerNormalization", ["X", "scale", "bias"], ["norm"], axis=-1),
32+
oh.make_node("MatMul", ["norm", "weights"], ["mm"]),
33+
oh.make_node("Add", ["mm", "bias2"], ["Z"]),
34+
],
35+
"layer_norm_matmul_add",
36+
[
37+
oh.make_tensor_value_info("X", TFLOAT16, ["a", "b", "c"]),
38+
oh.make_tensor_value_info("scale", TFLOAT16, ["c"]),
39+
oh.make_tensor_value_info("bias", TFLOAT16, ["c"]),
40+
oh.make_tensor_value_info("weights", TFLOAT16, ["c", "c"]),
41+
oh.make_tensor_value_info("bias2", TFLOAT16, ["c"]),
42+
],
43+
[oh.make_tensor_value_info("Z", TFLOAT16, ["a", "b", "c"])],
44+
),
45+
ir_version=9,
46+
opset_imports=[oh.make_opsetid("", 18)],
47+
)
48+
49+
plot_dot(model)
50+
51+
# %%
52+
# Let's compare two runtimes
53+
# ++++++++++++++++++++++++++
54+
#
55+
# That will be :epkg:`onnxruntime` and
56+
# :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`.
57+
58+
feeds = {
59+
"X": (torch.rand((32, 1024, 1152), dtype=torch.float16) - 0.5) * 120,
60+
"scale": torch.rand((1152,), dtype=torch.float16),
61+
"bias": torch.rand((1152,), dtype=torch.float16),
62+
"weights": torch.rand((1152, 1152), dtype=torch.float16),
63+
"bias2": torch.rand((1152,), dtype=torch.float16),
64+
}
65+
np_feeds = {k: v.detach().numpy() for k, v in feeds.items()}
66+
kws = dict(with_shape=True, with_min_max=True, with_device=True)
67+
data = []
68+
69+
for provider in ["CPU", "CUDA"]:
70+
if provider == "CUDA":
71+
if not torch.cuda.is_available():
72+
continue
73+
tch_feeds = {k: v.to("cuda") for k, v in feeds.items()}
74+
ort_feeds = np_feeds
75+
else:
76+
tch_feeds = feeds.copy()
77+
tch_feeds["X"] = tch_feeds["X"][:2] # keep only the first two batches on CPU, the full input takes too long otherwise
78+
ort_feeds = np_feeds.copy()
79+
ort_feeds["X"] = ort_feeds["X"][:2]
80+
print()
81+
print(f"-- running on {provider}")
82+
print("-- running with torch")
83+
torch_sess = TorchOnnxEvaluator(model, providers=[f"{provider}ExecutionProvider"])
84+
expected = torch_sess.run(None, tch_feeds)
85+
print(f"-- torch: {string_type(expected, **kws)}")
86+
87+
print("-- running with ort")
88+
ort_sess = onnxruntime.InferenceSession(
89+
model.SerializeToString(), providers=[f"{provider}ExecutionProvider"]
90+
)
91+
got = ort_sess.run(None, ort_feeds)
92+
print(f"-- ort: {string_type(got, **kws)}")
93+
diff = max_diff(expected, got, hist=True)
94+
print(f"-- diff {string_diff(diff)}")
95+
96+
# store the discrepancies measured for this provider
97+
diff["provider"] = provider
98+
diff.update(diff["rep"])
99+
del diff["rep"]
100+
data.append(diff)
101+
102+
# %%
103+
df = pandas.DataFrame(data).set_index("provider")
104+
print(df)

_doc/technical/plot_parallelized_reduction.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""
2+
.. _l-plot-parallelized-reduction:
3+
24
Reproducible Parallelized Reduction is difficult
35
================================================
46

0 commit comments

Comments
 (0)