Skip to content

Commit 4f6603f

Browse files
committed
update script
1 parent 7747de3 commit 4f6603f

File tree

1 file changed

+52
-22
lines changed

1 file changed

+52
-22
lines changed

_doc/technical/plot_gemm_or_matmul_add.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def matrix_diff(tensors):
145145
# a lot higher.
146146

147147
B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384
148-
labels = ["linear", *[o.name for o in model.graph.output], "a @ x + b"]
148+
labels = ["F.linear", *[o.name for o in model.graph.output], "a @ x + b"]
149149
all_results = {}
150150

151151
for itype, dtype, device in [
@@ -187,28 +187,58 @@ def matrix_diff(tensors):
187187
# bias value vs discrepancies
188188
# ===========================
189189
#
190-
# Let's compare GemmOnly (so bias is included) and Gemm+Add.
191-
192-
i, j = 1, -1
193-
labs = labels[i], labels[j]
194-
195-
fig, ax = plt.subplots(len(all_results), 2, figsize=(8, 2.5 * len(results)))
196-
for pos, ((device, dtype), results) in enumerate(all_results.items()):
197-
m1, m2 = results[i], results[j]
198-
diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
199-
print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
200-
expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
201-
ax[pos, 0].plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), ".")
202-
ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}")
203-
204-
corr = matrix_diff(results)
205-
ax[pos, 1].imshow(corr, cmap="Blues", vmin=0, vmax=corr.max())
206-
# ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}')
207-
ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45)
208-
ax[pos, 1].set_yticks(range(len(labels)), labels)
209-
ax[pos, 1].set_title(f"max={diff.max()}")
190+
# Let's compare torch linear with GemmOnly.
191+
192+
193+
def make_figure_axis(all_results, i, j):
    """Plot the discrepancies between two outputs for every configuration.

    One row of two subplots is drawn per entry of ``all_results``:

    * left: the maximum absolute difference (over the first axis) between
      ``results[i]`` and ``results[j]``, plotted against the module-level
      bias vector ``B`` with a little random jitter so that overlapping
      points remain visible;
    * right: the matrix returned by ``matrix_diff(results)`` rendered as an
      annotated heat map, with the module-level ``labels`` as tick labels.

    :param all_results: mapping ``(device, dtype) -> results`` where
        ``results`` is a sequence of tensors indexable by ``i`` and ``j``
    :param i: index of the first output to compare (also indexes ``labels``)
    :param j: index of the second output to compare (also indexes ``labels``)
    :return: tuple ``(fig, ax)``, ``ax`` is always a 2D array of axes
    """
    labs = labels[i], labels[j]
    n_rows = len(all_results)
    # squeeze=False keeps ``ax`` two-dimensional even with a single
    # configuration, otherwise ``ax[pos, 0]`` fails for one row.
    fig, ax = plt.subplots(n_rows, 2, figsize=(12, 4 * n_rows), squeeze=False)
    for pos, ((device, dtype), results) in enumerate(all_results.items()):
        left, right = ax[pos, 0], ax[pos, 1]

        m1, m2 = results[i], results[j]
        diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
        max_diff = diff.max()
        print(f"labels={labs}, {device}/{dtype}: max(diff)={max_diff}")

        # Random jitter scaled to the magnitude of the discrepancies,
        # it only spreads the points vertically to make them visible.
        expand = 0.5 if max_diff >= 1 else max_diff.detach().cpu() / 2
        diff_cpu = diff.detach().cpu()
        jitter = torch.rand(diff_cpu.shape) * expand
        left.plot(B.tolist(), (diff_cpu + jitter).tolist(), ".")
        left.set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}", fontsize=10)

        corr = matrix_diff(results)
        right.imshow(corr, cmap="Wistia", vmin=0, vmax=corr.max())
        right.set_xticks(range(len(labels)), labels, rotation=45, ha="right", fontsize=10)
        right.set_yticks(range(len(labels)), labels, fontsize=10)
        right.set_title(f"max={max_diff:1.2g}", fontsize=10)
        # Writes the value of every cell on top of the heat map.
        for _i in range(corr.shape[0]):
            for _j in range(corr.shape[1]):
                right.text(
                    _j,
                    _i,
                    f"{corr[_i, _j]:1.1g}",
                    ha="center",
                    va="center",
                    color="black",
                    fontsize=8,
                )
    fig.suptitle(
        f"Left column: discrepancies {labs[0]} VS {labs[1]}\n"
        "Right column: max absolute error, across all configurations\n"
        "white is good, orange is not"
    )
    return fig, ax
229+
230+
231+
fig, ax = make_figure_axis(all_results, 0, 1)
232+
fig.tight_layout()
233+
fig.savefig("plot_gemm_or_matmul_add1.png")
234+
235+
# %%
236+
# Let's compare with ``a @ x + b``.
237+
238+
fig, ax = make_figure_axis(all_results, -1, 1)
210239
fig.tight_layout()
211-
fig.savefig("plot_gemm_or_matmul_add.png")
240+
fig.savefig("plot_gemm_or_matmul_add2.png")
241+
212242

213243
# %%
214244
# Discrepancies do not happen all the time but they are very likely to happen.

0 commit comments

Comments
 (0)