Skip to content

Commit ac1c871

Browse files
committed
documentation
1 parent 3ad5180 commit ac1c871

File tree

7 files changed

+51
-7
lines changed

7 files changed

+51
-7
lines changed

_doc/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@
158158
"ignore_repr_types": "matplotlib\\.(text|axes)",
159159
# robubstness
160160
"reset_modules_order": "both",
161-
"reset_modules": ("matplotlib", "onnx_diagnostic.reset_torch_transformers"),
161+
"reset_modules": ("matplotlib", "onnx_diagnostic.doc.reset_torch_transformers"),
162162
}
163163

164164
if int(os.environ.get("UNITTEST_GOING", "0")):

_doc/examples/plot_export_cond.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414

1515
import torch
16+
from onnx_diagnostic import doc
1617

1718

1819
# %%
@@ -84,3 +85,8 @@ def neg(x):
8485

8586
ep = torch.export.export(model, (x,))
8687
print(ep.graph)
88+
89+
90+
# %%
91+
92+
doc.plot_legend("If -> torch.cond", "torch.export.export", "tomato")

_doc/examples/plot_export_locate_issue.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
5555
# %%
5656
# The error shows:
5757
#
58-
# ::
58+
# .. code-block::
59+
#
5960
# Constraints violated (L['args'][0][0].size()[0])!
6061
# For more information, run with TORCH_LOGS="+dynamic".
6162
# - Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
@@ -68,7 +69,7 @@ def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
6869
# One way to find the exact location is to retrieve a stack trace
6970
# by inserting an assert such as the following:
7071
#
71-
# ::
72+
# .. code-block::
7273
#
7374
# assert msg != "range_refined_to_singleton", (
7475
# f"A dynamic dimension becomes static! "
@@ -94,7 +95,7 @@ def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
9495
# is the following one. It points out the line turing a dynamic dimension into
9596
# static.
9697
#
97-
# ::
98+
# .. code-block::
9899
#
99100
# File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 25, in forward
100101
# z = x * caty

_doc/examples/plot_export_tiny_llm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def _forward_(*args, _f=None, **kwargs):
8585
#
8686
# Let's create an untrained model using the config file provided
8787
# `config.json <https://huggingface.co/arnir0/Tiny-LLM/blob/main/config.json>`_
88-
# to create an untrained model: :func:`....get_tiny_llm`.
88+
# to create an untrained model:
89+
# :func:`onnx_diagnostic.torch_models.llms.get_tiny_llm`.
8990
# Then let's use it.
9091

9192
experiment = get_tiny_llm()
@@ -139,7 +140,7 @@ def _forward_(*args, _f=None, **kwargs):
139140
#
140141
# Let's use the same dummy inputs but we use the downloaded model.
141142
# Dummy inputs and dynamic shapes are created by function
142-
# :func:`....get_tiny_llm`.
143+
# :func:`onnx_diagnostic.torch_models.llms.get_tiny_llm`.
143144

144145
data = get_tiny_llm()
145146
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]

_doc/examples/plot_export_with_dynamic_shapes_auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def forward(self, x, y, z):
5757
},
5858
)
5959
print(ep)
60-
raise AssertionError("able to export this moel, please update the tutorial")
60+
raise AssertionError("able to export this model, please update the tutorial")
6161
except torch._dynamo.exc.UserError as e:
6262
print(f"unable to use Dim('dz') because {type(e)}, {e}")
6363

onnx_diagnostic/doc.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
def reset_torch_transformers(gallery_conf, fname):
2+
"Resets torch dynamo for :epkg:`sphinx-gallery`."
3+
import matplotlib.pyplot as plt
4+
import torch
5+
6+
plt.style.use("ggplot")
7+
torch._dynamo.reset()
8+
9+
10+
def plot_legend(
11+
text: str, text_bottom: str = "", color: str = "green", fontsize: int = 35
12+
) -> "matplotlib.axes.Axes": # noqa: F821
13+
import matplotlib.pyplot as plt
14+
15+
fig = plt.figure()
16+
ax = fig.add_subplot()
17+
ax.axis([0, 5, 0, 5])
18+
ax.text(2.5, 4, "END", fontsize=50, horizontalalignment="center")
19+
ax.text(
20+
2.5,
21+
2.5,
22+
text,
23+
fontsize=fontsize,
24+
bbox={"facecolor": color, "alpha": 0.5, "pad": 10},
25+
horizontalalignment="center",
26+
verticalalignment="center",
27+
)
28+
if text_bottom:
29+
ax.text(4.5, 0.5, text_bottom, fontsize=20, horizontalalignment="right")
30+
ax.grid(False)
31+
ax.set_axis_off()
32+
return ax

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ exclude = [
1414
"^dist", # skips dist
1515
]
1616

17+
[[tool.mypy.overrides]]
18+
module = ["onnx_diagnostic.doc"]
19+
disable_error_code = ["call-overload", "name-defined"]
20+
1721
[[tool.mypy.overrides]]
1822
module = ["onnx_diagnostic.args"]
1923
disable_error_code = ["arg-type", "call-overload", "index"]

0 commit comments

Comments
 (0)