
Commit cddd3cf

add an example
1 parent 0e6de76 commit cddd3cf

1 file changed (+40, -21)

_doc/technical/plot_gemm_or_matmul_add.py

Lines changed: 40 additions & 21 deletions
@@ -41,13 +41,13 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
     return oh.make_model(
         oh.make_graph(
             [
-                oh.make_node("Gemm", ["A", "X", "B"], ["Ygemmfused"]),
+                oh.make_node("Gemm", ["A", "X", "B"], ["GemmOnly"]),
                 oh.make_node("Gemm", ["A", "X"], ["gmm"]),
-                oh.make_node("Add", ["gmm", "B"], ["Ygemm"]),
+                oh.make_node("Add", ["gmm", "B"], ["GemmAdd"]),
                 oh.make_node("MatMul", ["A", "X"], ["mm"]),
-                oh.make_node("Add", ["mm", "B"], ["Ymm"]),
+                oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
                 oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"),
-                oh.make_node("Add", ["fmm", "B"], ["Yfused"]),
+                oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]),
             ],
             "test",
             [
@@ -56,10 +56,10 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
                 oh.make_tensor_value_info("B", itype, ["c"]),
             ],
             [
-                oh.make_tensor_value_info("Ygemmfused", itype, ["a", "c"]),
-                oh.make_tensor_value_info("Yfused", itype, ["a", "c"]),
-                oh.make_tensor_value_info("Ygemm", itype, ["a", "c"]),
-                oh.make_tensor_value_info("Ymm", itype, ["a", "c"]),
+                oh.make_tensor_value_info("GemmOnly", itype, ["a", "c"]),
+                oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]),
+                oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]),
+                oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
             ],
         ),
         opset_imports=[oh.make_opsetid("", 22)],
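
All four renamed outputs compute the same quantity, a @ x + b; only the operator decomposition differs. For readers who want to reproduce the core comparison outside this script, here is a minimal, self-contained sketch (not part of the commit; it assumes opset 22 and a recent onnxruntime) that builds only the fused and unfused paths and measures their disagreement in float16:

# Minimal sketch, not part of the commit: fused Gemm vs MatMul+Add in float16.
import numpy as np
import onnx
import onnx.helper as oh
import onnxruntime

itype = onnx.TensorProto.FLOAT16
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Gemm", ["A", "X", "B"], ["GemmOnly"]),
            oh.make_node("MatMul", ["A", "X"], ["mm"]),
            oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
        ],
        "check",
        [
            oh.make_tensor_value_info("A", itype, ["a", "b"]),
            oh.make_tensor_value_info("X", itype, ["b", "c"]),
            oh.make_tensor_value_info("B", itype, ["c"]),
        ],
        [
            oh.make_tensor_value_info("GemmOnly", itype, ["a", "c"]),
            oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 22)],
)
opts = onnxruntime.SessionOptions()
# keep the graph as written: no MatMul+Add -> Gemm refusion by onnxruntime
opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = onnxruntime.InferenceSession(
    model.SerializeToString(), opts, providers=["CPUExecutionProvider"]
)
rng = np.random.default_rng(0)
feeds = {
    "A": rng.standard_normal((8, 512)).astype(np.float16),
    "X": rng.standard_normal((512, 512)).astype(np.float16),
    # a bias that is large relative to the product amplifies the rounding gap
    "B": (np.arange(512, dtype=np.float16) + 1) * 32,
}
gemm_only, matmul_add = sess.run(None, feeds)
print(np.abs(gemm_only.astype(np.float32) - matmul_add.astype(np.float32)).max())

Disabling graph optimizations matters here: at the default level onnxruntime is free to fuse MatMul followed by Add back into a single Gemm, which would hide the very difference being measured.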
@@ -140,13 +140,17 @@ def matrix_diff(tensors):
 # a lot higher.
 
 B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384
-labels = ["torch", *[o.name for o in model.graph.output]]
+labels = ["linear", *[o.name for o in model.graph.output], "a @ x + b"]
+all_results = {}
 
 for itype, dtype, device in [
     (onnx.TensorProto.FLOAT, torch.float32, "cpu"),
     (onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
+    # missing implementation in onnxruntime
+    # (onnx.TensorProto.BFLOAT16, torch.bfloat16, "cpu"),
     (onnx.TensorProto.FLOAT, torch.float32, "cuda"),
     (onnx.TensorProto.FLOAT16, torch.float16, "cuda"),
+    (onnx.TensorProto.BFLOAT16, torch.bfloat16, "cuda"),
 ]:
     if device == "cuda" and not torch.cuda.is_available():
         continue
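
The bias sweep above goes up to 16384 = 2**14, which deliberately stresses float16: at that magnitude consecutive float16 values are 16 apart, so once the product has been rounded, the final sum can only land on multiples of 16. A quick check, independent of the script:

import numpy as np

print(np.spacing(np.float16(16384.0)))  # 16.0: the float16 ULP at 2**14
print(np.finfo(np.float16).max)         # 65504.0: the largest bias gets close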
@@ -163,8 +167,9 @@ def matrix_diff(tensors):
         graph_optimization_level=GraphOptimizationLevel.ORT_DISABLE_ALL,
         optimized_model_filepath=filename,
     )
+    results = [torch.nn.functional.linear(a, x.T, b), *sess.run(None, feeds), a @ x + b]
+    all_results[device, dtype] = results
     has_cast = "Cast" in [n.op_type for n in onnx.load(filename).graph.node]
-    results = [a @ x + b, *sess.run(None, feeds)]
     diffs = matrix_diff(results)
     df = pandas.DataFrame(diffs, columns=labels, index=labels)
     print(f"------ has_cast={has_cast}, dtype={dtype}, device={device!r}, max(b)={b.max()}")
@@ -176,18 +181,32 @@ def matrix_diff(tensors):
 #
 # bias value vs discrepancies
 # ===========================
-
-
-m1, m2 = results[0:2]
-diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
-print(f"max(diff)={diff.max()}")
-
-fig, ax = plt.subplots(1, 1, figsize=(5, 3))
-ax.plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * 0.5).tolist(), ".")
-ax.set_title("Discrepancies (y) VS Bias (x)")
+#
+# Let's compare GemmOnly (bias fused in the kernel) with torch's unfused a @ x + b.
+
+i, j = 1, -1
+labs = labels[i], labels[j]
+
+fig, ax = plt.subplots(len(all_results), 2, figsize=(8, 2.5 * len(all_results)))
+for pos, ((device, dtype), results) in enumerate(all_results.items()):
+    m1, m2 = results[i], results[j]
+    diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
+    print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
+    expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
+    ax[pos, 0].plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), ".")
+    ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}")
+
+    corr = matrix_diff(results)
+    ax[pos, 1].imshow(corr, cmap="Blues", vmin=0, vmax=corr.max())
+    # ax[pos, 1].colorbar(label=f"Discrepancies {device}/{dtype}")
+    ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45)
+    ax[pos, 1].set_yticks(range(len(labels)), labels)
+    ax[pos, 1].set_title(f"max={diff.max()}")
+fig.tight_layout()
 fig.savefig("plot_gemm_or_matmul_add.png")
 
 # %%
 # Discrepancies do not happen all the time but it is very likely to happen.
-# Fused Gemm should be avoided when the bias is very different from the multiplied
-# matrix and avoided in the generic case.
+# Gemm with a non-null bias should only be used when torch fuses the bias the
+# same way, and whether it is safe seems to depend on the element type too.
+# The difference is even higher for bfloat16.
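
The new closing comment has a simple arithmetic root. A fused Gemm kernel typically accumulates a @ x + b in a wider type (float32) and rounds once at the end, while MatMul followed by Add stores the product in float16 first, so the result is rounded twice. A minimal numeric sketch with hypothetical values chosen near a rounding boundary (not taken from the script):

import numpy as np

bias = np.float32(16384.0)  # the float16 ULP at 2**14 is 16
prod = np.float32(8.002)    # stands for one exact matmul coefficient

# fused order: accumulate in float32, round to float16 once at the end
fused = np.float16(prod + bias)  # 16392.002 -> 16400.0
# unfused order: round the product to float16, add the bias, round again
unfused = np.float16(prod) + np.float16(bias)  # 8.0 + 16384.0 -> 16384.0
print(fused, unfused, float(fused) - float(unfused))  # one full ULP apart: 16.0

bfloat16 keeps float32's exponent range but has only 7 fraction bits, so its ULP at the same magnitude is 128 instead of 16, consistent with the larger differences the comment reports for bfloat16.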
