Skip to content

Commit f1f92f5

Browse files
authored
documentation (#137)
* documentation * fix graph
1 parent 625b028 commit f1f92f5

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

_doc/technical/plot_layer_norm_discrepancies.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def cast_feeds(itype, provider, feeds):
143143
# %%
144144
# Visually.
145145

146-
df["abs"].plot(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
146+
df["abs"].plot.bar(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
147147

148148
# %%
149149
# The discrepancies are significant on CUDA, higher for float16.
@@ -207,4 +207,6 @@ def cast_feeds(itype, provider, feeds):
207207
# %%
208208
# Visually.
209209

210-
df[["diff_ort", "diff_torch"]].plot(title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B")
210+
df[["diff_ort", "diff_torch"]].plot.bar(
211+
title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B"
212+
)

onnx_diagnostic/helpers/doc_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(
2020
node: onnx.NodeProto,
2121
version=None,
2222
device: Optional[torch.device] = None,
23-
verbose=0,
23+
verbose: int = 0,
2424
):
2525
super().__init__(node, version, verbose=verbose)
2626
self.axis = self.get_attribute_int(node, "axis", -1)
@@ -101,7 +101,7 @@ def __init__(
101101
node: onnx.NodeProto,
102102
version=None,
103103
device: Optional[torch.device] = None,
104-
verbose=0,
104+
verbose: int = 0,
105105
):
106106
super().__init__(node, version, verbose=verbose)
107107
self.device = device

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ class TorchOnnxEvaluator:
168168
class LayerNormalizationOrt(OpRunKernel):
169169
"LayerNormalization based on onnxruntime"
170170
171-
def __init__(self, node: onnx.NodeProto, version=None):
172-
super().__init__(node, version)
171+
def __init__(self, node: onnx.NodeProto, version=None, verbose=0):
172+
super().__init__(node, version, verbose=verbose)
173173
self.axis = self.get_attribute_int(node, "axis", -1)
174174
self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
175175
self.stash_type = onnx_dtype_to_torch_dtype(

0 commit comments

Comments
 (0)