Commit 1b20294

improve
1 parent 6da835f commit 1b20294

File tree: 2 files changed (+171, -59)

_doc/technical/plot_layer_norm_discrepancies.py

Lines changed: 140 additions & 34 deletions
@@ -14,38 +14,48 @@
 +++++++++
 """
 
+import itertools
 import pandas
 import onnx
 import onnx.helper as oh
 import onnxruntime
 import torch
 from onnx_array_api.plotting.graphviz_helper import plot_dot
+from onnx_diagnostic.ext_test_case import unit_test_going
 from onnx_diagnostic.helpers import max_diff, string_diff, string_type
+from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name, onnx_dtype_to_np_dtype
+from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
+from onnx_diagnostic.helpers.doc_helper import LayerNormalizationOrt, MatMulOrt
 from onnx_diagnostic.reference import TorchOnnxEvaluator
 
+TFLOAT = onnx.TensorProto.FLOAT
 TFLOAT16 = onnx.TensorProto.FLOAT16
 
-model = oh.make_model(
-    oh.make_graph(
-        [
-            oh.make_node("LayerNormalization", ["X", "scale", "bias"], ["norm"], axis=-1),
-            oh.make_node("MatMul", ["norm", "weights"], ["mm"]),
-            oh.make_node("Add", ["mm", "bias2"], ["Z"]),
-        ],
-        "layer_norm_matmul_add",
-        [
-            oh.make_tensor_value_info("X", TFLOAT16, ["a", "b", "c"]),
-            oh.make_tensor_value_info("scale", TFLOAT16, ["c"]),
-            oh.make_tensor_value_info("bias", TFLOAT16, ["c"]),
-            oh.make_tensor_value_info("weights", TFLOAT16, ["c", "c"]),
-            oh.make_tensor_value_info("bias2", TFLOAT16, ["c"]),
-        ],
-        [oh.make_tensor_value_info("Z", TFLOAT16, ["a", "b", "c"])],
-    ),
-    ir_version=9,
-    opset_imports=[oh.make_opsetid("", 18)],
-)
 
+def get_model(itype: int = TFLOAT16):
+    return oh.make_model(
+        oh.make_graph(
+            [
+                oh.make_node("LayerNormalization", ["X", "scale", "bias"], ["norm"], axis=-1),
+                oh.make_node("MatMul", ["norm", "weights"], ["mm"]),
+                oh.make_node("Add", ["mm", "bias2"], ["Z"]),
+            ],
+            "layer_norm_matmul_add",
+            [
+                oh.make_tensor_value_info("X", itype, ["a", "b", "c"]),
+                oh.make_tensor_value_info("scale", itype, ["c"]),
+                oh.make_tensor_value_info("bias", itype, ["c"]),
+                oh.make_tensor_value_info("weights", itype, ["c", "c"]),
+                oh.make_tensor_value_info("bias2", itype, ["c"]),
+            ],
+            [oh.make_tensor_value_info("Z", itype, ["a", "b", "c"])],
+        ),
+        ir_version=9,
+        opset_imports=[oh.make_opsetid("", 18)],
+    )
+
+
+model = get_model()
 plot_dot(model)
 
 # %%
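
For reference, the graph assembled by get_model computes Z = LayerNorm(X) @ weights + bias2 over the last axis. A minimal PyTorch equivalent, useful as a mental model (a sketch only; the function name and the default eps are assumptions, not part of this commit):

    import torch

    def layer_norm_matmul_add(X, scale, bias, weights, bias2, eps=1e-5):
        # Same three steps as the ONNX graph: LayerNormalization over the
        # last axis, then MatMul, then Add.
        norm = torch.nn.functional.layer_norm(
            X, X.shape[-1:], weight=scale, bias=bias, eps=eps
        )
        return norm @ weights + bias2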
@@ -55,50 +65,146 @@
 # That will be :epkg:`onnxruntime` and
 # :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`.
 
-feeds = {
-    "X": (torch.rand((32, 1024, 1152), dtype=torch.float16) - 0.5) * 120,
-    "scale": torch.rand((1152,), dtype=torch.float16),
-    "bias": torch.rand((1152,), dtype=torch.float16),
-    "weights": torch.rand((1152, 1152), dtype=torch.float16),
-    "bias2": torch.rand((1152,), dtype=torch.float16),
-}
-np_feeds = {k: v.detach().numpy() for k, v in feeds.items()}
-kws = dict(with_shape=True, with_min_max=True, with_device=True)
-data = []
+last_dim = 64 if unit_test_going() else 1152
 
-for provider in ["CPU", "CUDA"]:
+
+def make_feeds(last_dim: int):
+    return {
+        "X": (torch.rand((32, 1024, last_dim), dtype=torch.float16) - 0.5) * 120,
+        "scale": torch.rand((last_dim,), dtype=torch.float16),
+        "bias": torch.rand((last_dim,), dtype=torch.float16),
+        "weights": torch.rand((last_dim, last_dim), dtype=torch.float16),
+        "bias2": torch.rand((last_dim,), dtype=torch.float16),
+    }
+
+
+def cast_feeds(itype, provider, feeds):
+    np_feeds = {k: v.detach().numpy() for k, v in feeds.items()}
     if provider == "CUDA":
         if not torch.cuda.is_available():
-            continue
+            return None, None
         tch_feeds = {k: v.to("cuda") for k, v in feeds.items()}
         ort_feeds = np_feeds
     else:
         tch_feeds = feeds.copy()
         tch_feeds["X"] = tch_feeds["X"][:2]  # too long otherwise
         ort_feeds = np_feeds.copy()
         ort_feeds["X"] = ort_feeds["X"][:2]
+    tch_feeds = {k: v.to(ttype) for k, v in tch_feeds.items()}
+    ort_feeds = {k: v.astype(np_dtype) for k, v in ort_feeds.items()}
+    return tch_feeds, ort_feeds
+
+
+feeds = make_feeds(last_dim)
+kws = dict(with_shape=True, with_min_max=True, with_device=True)
+data = []
+baseline = {}
+
+for provider, itype in itertools.product(["CPU", "CUDA"], [TFLOAT, TFLOAT16]):
+    ttype = onnx_dtype_to_torch_dtype(itype)
+    np_dtype = onnx_dtype_to_np_dtype(itype)
+    tch_feeds, ort_feeds = cast_feeds(itype, provider, feeds)
+    if tch_feeds is None:
+        continue
+
+    model = get_model(itype)
     print()
-    print(f"-- running on {provider}")
+    print(f"-- running on {provider} with {onnx_dtype_name(itype)}")
     print("-- running with torch")
     torch_sess = TorchOnnxEvaluator(model, providers=[f"{provider}ExecutionProvider"])
     expected = torch_sess.run(None, tch_feeds)
+    baseline[itype, provider, "torch"] = expected
     print(f"-- torch: {string_type(expected, **kws)}")
 
     print("-- running with ort")
     ort_sess = onnxruntime.InferenceSession(
         model.SerializeToString(), providers=[f"{provider}ExecutionProvider"]
     )
     got = ort_sess.run(None, ort_feeds)
+    baseline[itype, provider, "ort"] = got
     print(f"-- ort: {string_type(got, **kws)}")
     diff = max_diff(expected, got, hist=True)
     print(f"-- diff {string_diff(diff)}")
 
     # memorize the data
+    diff["dtype"] = onnx_dtype_name(itype)
     diff["provider"] = provider
     diff.update(diff["rep"])
     del diff["rep"]
+    del diff["dnan"]
+    del diff[">100.0"]
+    del diff[">10.0"]
     data.append(diff)
 
 # %%
-df = pandas.DataFrame(data).set_index("provider")
+df = pandas.DataFrame(data).set_index(["provider", "dtype"])
 print(df)
+
+# %%
+# Visually.
+
+df["abs"].plot(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
+
+# %%
+# The discrepancies are significant on CUDA, higher for float16.
+# Let's see which operator is responsible for them,
+# *LayerNormalization* or *MatMul*.
+
+# %%
+# Where do the discrepancies come from?
+# +++++++++++++++++++++++++++++++++++++
+#
+# We mix torch and onnxruntime to execute the kernels.
+
+data = []
+
+for mod, provider, itype in itertools.product(
+    ["ORT-TORCH", "TORCH-ORT"], ["CPU", "CUDA"], [TFLOAT, TFLOAT16]
+):
+    ttype = onnx_dtype_to_torch_dtype(itype)
+    np_dtype = onnx_dtype_to_np_dtype(itype)
+    tch_feeds, _ = cast_feeds(itype, provider, feeds)
+    if tch_feeds is None:
+        continue
+
+    custom_kernels = (
+        {("", "LayerNormalization"): LayerNormalizationOrt}
+        if mod == "ORT-TORCH"
+        else {("", "MatMul"): MatMulOrt}
+    )
+
+    model = get_model(itype)
+    print()
+    print(f"-- {mod} running on {provider} with {onnx_dtype_name(itype)}")
+    sess = TorchOnnxEvaluator(
+        model,
+        custom_kernels=custom_kernels,
+        providers=[f"{provider}ExecutionProvider"],
+    )
+    got = sess.run(None, tch_feeds)
+    print(f"-- {mod}: {string_type(got, **kws)}")
+
+    difft = max_diff(baseline[itype, provider, "torch"], got)
+    print(f"-- diff with torch {string_diff(difft)}")
+    diffo = max_diff(baseline[itype, provider, "ort"], got)
+    print(f"-- diff with ort {string_diff(diffo)}")
+
+    data.append(
+        dict(
+            model=mod,
+            dtype=onnx_dtype_name(itype),
+            provider=provider,
+            diff_ort=diffo["abs"],
+            diff_torch=difft["abs"],
+        )
+    )
+
+# %%
+df = pandas.DataFrame(data).set_index(["model", "provider", "dtype"])
+df = df.sort_index()
+print(df)
+
+# %%
+# Visually.
+
+df[["diff_ort", "diff_torch"]].plot(title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B")

onnx_diagnostic/helpers/doc_helper.py

Lines changed: 31 additions & 25 deletions
@@ -2,8 +2,9 @@
 import onnx
 import onnx.helper as oh
 import torch
-from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from ..reference.torch_ops import OpRunKernel, OpRunTensor
+from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
+from .ort_session import InferenceSessionForTorch
 
 
 class LayerNormalizationOrt(OpRunKernel):
@@ -36,50 +37,55 @@ def __init__(
         self._cache: Dict[Tuple[int, int], onnx.ModelProto] = {}
         self.is_cpu = torch.device("cpu") == self.device
 
-    def _make_model(self, itype: int, rank: int) -> onnx.ModelProto:
+    def _make_model(self, itype: int, rank: int, has_bias: bool) -> onnx.ModelProto:
         shape = [*["d{i}" for i in range(rank - 1)], "last"]
         layer_model = oh.make_model(
             oh.make_graph(
                 [
                     oh.make_node(
                         "LayerNormalization",
-                        ["X", "W", "B"],
+                        ["X", "W", "B"] if has_bias else ["X", "W"],
                         ["Z"],
                         axis=self.axis,
                         epsilon=self.epsilon,
                     )
                 ],
                 "dummy",
-                [
-                    oh.make_tensor_value_info("X", itype, shape),
-                    oh.make_tensor_value_info("W", itype, ["last"]),
-                    oh.make_tensor_value_info("B", itype, ["last"]),
-                ],
+                (
+                    [
+                        oh.make_tensor_value_info("X", itype, shape),
+                        oh.make_tensor_value_info("W", itype, ["last"]),
+                        oh.make_tensor_value_info("B", itype, ["last"]),
+                    ]
+                    if has_bias
+                    else [
+                        oh.make_tensor_value_info("X", itype, shape),
+                        oh.make_tensor_value_info("W", itype, ["last"]),
+                    ]
+                ),
                 [oh.make_tensor_value_info("Z", itype, shape)],
             ),
             ir_version=9,
             opset_imports=[oh.make_opsetid("", 18)],
         )
-        import onnxruntime
-
         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
-        return onnxruntime.InferenceSession(
-            layer_model.SerializeToString(), providers=[provider]
-        )
+        self._provider = provider
+        return InferenceSessionForTorch(layer_model, providers=[provider])
 
     def run(self, x, scale, bias=None):
         itype = torch_dtype_to_onnx_dtype(x.dtype)
         rank = len(x.shape)
         key = itype, rank
         if key not in self._cache:
-            self._cache[key] = self._make_model(itype, rank)
+            self._cache[key] = self._make_model(itype, rank, bias is not None)
         sess = self._cache[key]
-        feeds = dict(X=x, W=scale)
+        if self.verbose:
+            print(f"[LayerNormalizationOrt] running on {self._provider!r}")
+        feeds = dict(X=x.tensor, W=scale.tensor)
         if bias is not None:
-            feeds["B"] = bias
-        feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
+            feeds["B"] = bias.tensor
         got = sess.run(None, feeds)[0]
-        return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
+        return OpRunTensor(got)
 
 
 class MatMulOrt(OpRunKernel):
@@ -117,12 +123,11 @@ def _make_model(self, itype: int, ranka: int, rankb: int) -> onnx.ModelProto:
                 [oh.make_tensor_value_info("C", itype, shapec)],
             ),
             ir_version=9,
-            opset_imports=[oh.make_opsetid("", 17)],
+            opset_imports=[oh.make_opsetid("", 18)],
         )
-        import onnxruntime
-
         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
-        return onnxruntime.InferenceSession(model.SerializeToString(), providers=[provider])
+        self._provider = provider
+        return InferenceSessionForTorch(model, providers=[provider])
 
     def run(self, a, b):
         itype = torch_dtype_to_onnx_dtype(a.dtype)
@@ -131,7 +136,8 @@ def run(self, a, b):
         if key not in self._cache:
             self._cache[key] = self._make_model(itype, ranka, rankb)
         sess = self._cache[key]
-        feeds = dict(A=a, B=b)
-        feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
+        if self.verbose:
+            print(f"[MatMulOrt] running on {self._provider!r}")
+        feeds = dict(A=a.tensor, B=b.tensor)
         got = sess.run(None, feeds)[0]
-        return OpRunTensor(torch.from_numpy(got).to(a.dtype).to(a.device))
+        return OpRunTensor(got)
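
Both kernels are consumed through the custom_kernels argument of TorchOnnxEvaluator, exactly as the example script above does. A minimal sketch (reusing get_model and a dict of torch tensors tch_feeds from that script, both assumed to be in scope):

    from onnx_diagnostic.helpers.doc_helper import LayerNormalizationOrt
    from onnx_diagnostic.reference import TorchOnnxEvaluator

    # Route only LayerNormalization through onnxruntime; every other node
    # runs with the torch kernels of TorchOnnxEvaluator.
    sess = TorchOnnxEvaluator(
        get_model(),
        custom_kernels={("", "LayerNormalization"): LayerNormalizationOrt},
        providers=["CPUExecutionProvider"],
    )
    got = sess.run(None, tch_feeds)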
