Skip to content

Commit 179c9cf

Browse files
committed
fix
1 parent 4f6603f commit 179c9cf

File tree

1 file changed: +18 additions, −6 deletions

_doc/technical/plot_gemm_or_matmul_add.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
5353
oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
5454
oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"),
5555
oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]),
56+
oh.make_node("Cast", ["A"], ["Afloat"], to=onnx.TensorProto.FLOAT),
57+
oh.make_node("Cast", ["B"], ["Bfloat"], to=onnx.TensorProto.FLOAT),
58+
oh.make_node("Cast", ["X"], ["Xfloat"], to=onnx.TensorProto.FLOAT),
59+
oh.make_node("Gemm", ["Afloat", "Xfloat"], ["gmmfloat"]),
60+
oh.make_node("Add", ["gmmfloat", "Bfloat"], ["gemmaddfloat"]),
61+
oh.make_node("Cast", ["gemmaddfloat"], ["CastGemmAddCast"], to=itype),
62+
oh.make_node("Gemm", ["Afloat", "Xfloat", "Bfloat"], ["GemmOnlyfloat"]),
63+
oh.make_node("Cast", ["GemmOnlyfloat"], ["CastGemmOnlyCast"], to=itype),
5664
],
5765
"test",
5866
[
@@ -65,6 +73,8 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
6573
oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]),
6674
oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]),
6775
oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
76+
oh.make_tensor_value_info("CastGemmAddCast", itype, ["a", "c"]),
77+
oh.make_tensor_value_info("CastGemmOnlyCast", itype, ["a", "c"]),
6878
],
6979
),
7080
opset_imports=[oh.make_opsetid("", 22)],
@@ -85,7 +95,7 @@ def matrix_diff(tensors):
8595
dtype = np.float16
8696
model = make_model_gemm(itype)
8797

88-
A = np.random.randn(512, 256).astype(dtype)
98+
A = np.random.randn(1280, 256).astype(dtype)
8999
X = np.random.randn(256, 256).astype(dtype)
90100
B = np.random.randn(256).astype(dtype)
91101
feeds = dict(A=A, X=X, B=B)
@@ -112,9 +122,9 @@ def matrix_diff(tensors):
112122
# %%
113123
# Let's try with CUDA and float32 if it is available.
114124

115-
A = torch.randn((512, 512), dtype=torch.float32)
116-
X = torch.randn((512, 512), dtype=torch.float32)
117-
B = torch.randn((512), dtype=torch.float32)
125+
A = torch.randn((1280, 1280), dtype=torch.float32)
126+
X = torch.randn((1280, 1280), dtype=torch.float32)
127+
B = torch.randn((1280), dtype=torch.float32)
118128

119129
for itype, dtype, device in [
120130
(onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
@@ -144,7 +154,9 @@ def matrix_diff(tensors):
144154
# are similar to the others coefficients. What if we make them
145155
# a lot higher.
146156

147-
B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384
157+
A = A / A.max()
158+
X = X / X.max()
159+
B = (torch.arange(1280, dtype=torch.float32) + 1) / 1280 * 16
148160
labels = ["F.linear", *[o.name for o in model.graph.output], "a @ x + b"]
149161
all_results = {}
150162

@@ -199,7 +211,7 @@ def make_figure_axis(all_results, i, j):
199211
print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
200212
expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
201213
ax[pos, 0].plot(
202-
B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), "."
214+
B.tolist(), (diff.detach().cpu() + torch.rand(1280) * expand).tolist(), "."
203215
)
204216
ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}", fontsize=10)
205217

0 commit comments

Comments (0)