Skip to content

Commit 8a78172

Browse files
committed
better graphs
1 parent 97eb9e5 commit 8a78172

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

_doc/technical/plot_layer_norm_discrepancies.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import onnxruntime
2222
import torch
2323
from onnx_array_api.plotting.graphviz_helper import plot_dot
24+
from onnx_diagnostic.doc import rotate_align, save_fig
2425
from onnx_diagnostic.ext_test_case import unit_test_going
2526
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
2627
from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name, onnx_dtype_to_np_dtype
@@ -143,7 +144,12 @@ def cast_feeds(itype, provider, feeds):
143144
# %%
144145
# Visually.
145146

146-
df["abs"].plot.bar(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
147+
save_fig(
148+
rotate_align(
149+
df[["abs"]].plot.bar(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
150+
),
151+
"plot_layer_norm_discrepancies_1.png",
152+
)
147153

148154
# %%
149155
# The discrepancies are significant on CUDA, higher for float16.
@@ -207,6 +213,11 @@ def cast_feeds(itype, provider, feeds):
207213
# %%
208214
# Visually.
209215

210-
df[["diff_ort", "diff_torch"]].plot.bar(
211-
title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B"
216+
save_fig(
217+
rotate_align(
218+
df[["diff_ort", "diff_torch"]].plot.bar(
219+
title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B"
220+
)
221+
),
222+
"plot_layer_norm_discrepancies_2.png",
212223
)

_doc/technical/plot_parallelized_reduction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def make_value(base, value):
188188
# Visually.
189189

190190
ax = df.plot.bar(logy=True)
191+
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
191192
fig = ax.get_figure()
192193
fig.savefig("plot_parallelized_reduction.png")
193194

onnx_diagnostic/doc.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,21 @@ def plot_legend(
3030
ax.grid(False)
3131
ax.set_axis_off()
3232
return ax
33+
34+
35+
def rotate_align(ax, angle=15, align="right"):
36+
"""Rotates x-label and align them to thr right. Returns ax."""
37+
for label in ax.get_xticklabels():
38+
label.set_rotation(angle)
39+
label.set_horizontalalignment(align)
40+
return ax
41+
42+
43+
def save_fig(ax, name: str):
44+
"""Applies ``tight_layout`` and saves the figures. Returns ax."""
45+
import matplotlib.pyplot as plt
46+
47+
plt.tight_layout()
48+
fig = ax.get_figure()
49+
fig.savefig(name)
50+
return ax

0 commit comments

Comments
 (0)