Skip to content

Commit d7e5976

Browse files
authored
Add recipes (#44)
* add recipes * stacklevel * fix issue
1 parent 84817ab commit d7e5976

File tree

10 files changed

+301
-1
lines changed

10 files changed

+301
-1
lines changed

_doc/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,12 @@
147147
# path to your examples scripts
148148
"examples_dirs": [
149149
os.path.join(os.path.dirname(__file__), "examples"),
150+
os.path.join(os.path.dirname(__file__), "recipes"),
150151
],
151152
# path where to save gallery generated examples
152153
"gallery_dirs": [
153154
"auto_examples",
155+
"auto_recipes",
154156
],
155157
# no parallelization to avoid conflict with environment variables
156158
"parallel": 1,

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Source are `sdpython/onnx-diagnostic
3131
api/index
3232
cmds/index
3333
auto_examples/index
34+
auto_recipes/index
3435

3536
.. toctree::
3637
:maxdepth: 1

_doc/recipes/README.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Common Export Issues
2+
====================
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Half certain nonzero
3+
====================
4+
5+
:func:`torch.nonzero` returns the indices of the nonzero values found
6+
in a tensor. The output shape is unknown in the generic case
7+
but... If you have a 2D tensor with at least a nonzero value
8+
in every row, you can guess the dimension. But :func:`torch.export.export`
9+
does not know what you know.
10+
11+
12+
A Model
13+
+++++++
14+
"""
15+
16+
import torch
17+
from onnx_diagnostic import doc
18+
19+
20+
class Model(torch.nn.Module):
    """Toy module whose mask computation relies on :func:`torch.nonzero`."""

    def adaptive_enc_mask(self, x_len, chunk_start_idx, left_window=0, right_window=0):
        """
        Builds a boolean mask from chunk boundaries.

        :param x_len: sequence length
        :param chunk_start_idx: sequence of chunk start positions
        :param left_window: number of chunks visible on the left
        :param right_window: number of chunks visible on the right
        :return: ``mask_left & mask_right``, presumably of shape
            ``(x_len, x_len)`` when every position belongs to exactly
            one chunk (this is what the exporter cannot infer)
        """
        chunk_starts = torch.Tensor(chunk_start_idx).long()
        # start_pad = [0, *starts] and end_pad = [*starts, x_len]
        first = torch.tensor([0], dtype=torch.int64)
        start_pad = torch.cat((first, chunk_starts), dim=0)
        last = torch.tensor([x_len], dtype=torch.int64)
        end_pad = torch.cat((chunk_starts, last), dim=0)

        # column index of every True entry: the chunk holding each position
        positions = torch.arange(0, x_len).unsqueeze(-1)
        in_chunk = (positions < end_pad) & (positions >= start_pad)
        idx = in_chunk.nonzero()[:, 1]
        columns = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)

        # left boundary, chunk indices below 0 are set to 0
        idx_left = idx - left_window
        idx_left[idx_left < 0] = 0
        lower = start_pad[idx_left]
        mask_left = columns >= lower.unsqueeze(-1)

        # right boundary, chunk indices above the number of chunks are capped
        n_chunks = len(chunk_starts)
        idx_right = idx + right_window
        idx_right[idx_right > n_chunks] = n_chunks
        upper = end_pad[idx_right]
        mask_right = columns < upper.unsqueeze(-1)

        return mask_left & mask_right

    def forward(self, x):
        """Computes the mask for ``x.shape[1]`` positions and no chunk start."""
        return self.adaptive_enc_mask(x.shape[1], [])
40+
41+
42+
model = Model()
43+
x = torch.rand((5, 8))
44+
y = model(x)
45+
print(f"x.shape={x.shape}, y.shape={y.shape}")
46+
47+
# %%
48+
# Export
49+
# ++++++
50+
51+
DYN = torch.export.Dim.DYNAMIC
52+
ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
53+
print(ep)
54+
55+
56+
# %%
57+
# We can see the following line in the exported program.
58+
# It tells what it cannot verify.
59+
# ``torch.ops.aten._assert_scalar.default(eq,``
60+
# ``"Runtime assertion failed for expression Eq(s16, u0) on node 'eq'");``
61+
62+
63+
# %%
64+
doc.plot_legend("dynamic shapes\nnonzero", "dynamic shapes", "yellow")
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
Do not use python int with dynamic shape
3+
=========================================
4+
5+
:func:`torch.export.export` uses :class:`torch.SymInt` to operate on shapes and
6+
optimizes the graph it produces. It checks if two tensors share the same dimension,
7+
if the shapes can be broadcast, ... To do that, python types must not be used
8+
or the algorithm loses information.
9+
10+
Wrong Model
11+
+++++++++++
12+
"""
13+
14+
import math
15+
import torch
16+
from onnx_diagnostic import doc
17+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
18+
19+
20+
class Model(torch.nn.Module):
    """Deliberately wrong model: a shape goes through python ``int``."""

    def dim(self, i, divisor):
        """Returns ``ceil(i / divisor)`` cast to a python int."""
        ratio = i / divisor
        return int(math.ceil(ratio))  # noqa: RUF046

    def forward(self, x):
        """Returns zeros of shape ``(dim(x.shape[0], 8), x.shape[1])``."""
        n_rows = self.dim(x.shape[0], 8)
        return torch.zeros((n_rows, x.shape[1]))
27+
28+
29+
model = Model()
30+
x = torch.rand((10, 15))
31+
y = model(x)
32+
print(f"x.shape={x.shape}, y.shape={y.shape}")
33+
34+
# %%
35+
# Export
36+
# ++++++
37+
38+
DYN = torch.export.Dim.DYNAMIC
39+
ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
40+
print(ep)
41+
42+
# %%
43+
# The first dimension became static. We must not use int.
44+
# :func:`math.ceil` should be avoided as well since it is a python operation.
45+
# The exporter may fail to detect it is operating on shapes.
46+
#
47+
# Rewrite
48+
# +++++++
49+
50+
51+
class RewrittenModel(torch.nn.Module):
    """Same model, the new dimension is computed with integer operators only."""

    def dim(self, i, divisor):
        """Returns the ceil division of *i* by *divisor* without any cast."""
        return (i + divisor - 1) // divisor

    def forward(self, x):
        """Returns zeros of shape ``(dim(x.shape[0], 8), x.shape[1])``."""
        n_rows = self.dim(x.shape[0], 8)
        return torch.zeros((n_rows, x.shape[1]))
58+
59+
60+
rewritten_model = RewrittenModel()
61+
y = rewritten_model(x)
62+
print(f"x.shape={x.shape}, y.shape={y.shape}")
63+
64+
# %%
65+
# Export
66+
# ++++++
67+
68+
ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=((DYN, DYN),))
69+
print(ep)
70+
71+
72+
# %%
73+
# Find the error
74+
# ++++++++++++++
75+
#
76+
# Function :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
77+
# has a parameter ``stop_if_static`` which patches torch to raise an exception
78+
# when something like that is happening.
79+
80+
81+
with bypass_export_some_errors(stop_if_static=True):
82+
ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
83+
print(ep)
84+
85+
# %%
86+
doc.plot_legend("dynamic shapes\ndo not cast to\npython int", "dynamic shapes", "yellow")
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import unittest
2+
import os
3+
import sys
4+
import importlib.util
5+
import subprocess
6+
import time
7+
from onnx_diagnostic import __file__ as onnx_diagnostic_file
8+
from onnx_diagnostic.ext_test_case import (
9+
ExtTestCase,
10+
is_windows,
11+
has_torch,
12+
ignore_errors,
13+
)
14+
15+
16+
VERBOSE = 0
17+
ROOT = os.path.realpath(os.path.abspath(os.path.join(onnx_diagnostic_file, "..", "..")))
18+
19+
20+
def import_source(module_file_path, module_name):
    """
    Imports a module from a file location and executes it.

    :param module_file_path: path to the python file to import
    :param module_name: name given to the imported module
    :return: the executed module
    :raises FileNotFoundError: if the path does not exist or if no
        module specification (or loader) can be built for it,
        for example when the path is a directory
    """
    if not os.path.exists(module_file_path):
        raise FileNotFoundError(module_file_path)
    module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
    if module_spec is None or module_spec.loader is None:
        raise FileNotFoundError(
            "Unable to find '{}' in '{}'.".format(module_name, module_file_path)
        )
    module = importlib.util.module_from_spec(module_spec)
    # exec_module runs the module in place and returns None,
    # the module itself must be returned.
    module_spec.loader.exec_module(module)
    return module
30+
31+
32+
class TestDocumentationRecipes(ExtTestCase):
    """Creates one unit test per ``plot_*.py`` script found in ``_doc/recipes``."""

    def run_test(self, fold: str, name: str, verbose=0) -> int:
        """
        Runs one recipe.

        :param fold: folder containing the recipe
        :param name: file name of the recipe
        :param verbose: prints the name and the duration if not null
        :return: 1 if the script ran, 0 if it failed because ``dot`` is missing
        :raises AssertionError: if the standard error of the script
            contains a traceback
        """
        # ROOT is added to PYTHONPATH, the child process inherits it.
        ppath = os.environ.get("PYTHONPATH", "")
        if not ppath:
            os.environ["PYTHONPATH"] = ROOT
        elif ROOT not in ppath:
            sep = ";" if is_windows() else ":"
            os.environ["PYTHONPATH"] = ppath + sep + ROOT
        perf = time.perf_counter()
        try:
            mod = import_source(fold, os.path.splitext(name)[0])
            assert mod is not None
        except FileNotFoundError:
            # try another way: runs the script in a separate process
            cmds = [sys.executable, "-u", os.path.join(fold, name)]
            proc = subprocess.run(cmds, capture_output=True, check=False)
            st = proc.stderr.decode("ascii", errors="ignore")
            if st and "Traceback" in st:
                if '"dot" not found in path.' in st:
                    # dot not installed, this part
                    # is tested in onnx framework
                    # NOTE(review): the generated tests call assertTrue on the
                    # returned value, 0 still makes them fail, confirm it is intended.
                    if verbose:
                        print(f"failed: {name!r} due to missing dot.")
                    return 0
                raise AssertionError(  # noqa: B904
                    "Example '{}' (cmd: {} - exec_prefix='{}') "
                    "failed due to\n{}"
                    "".format(name, cmds, sys.exec_prefix, st)
                )
        dt = time.perf_counter() - perf
        if verbose:
            print(f"{dt:.3f}: run {name!r}")
        return 1

    @classmethod
    def add_test_methods(cls):
        """
        Adds a method ``test_<script name>`` to the class for every
        ``plot_*.py`` file in ``_doc/recipes``. The test is skipped
        if a reason to skip it was found.
        """
        this = os.path.abspath(os.path.dirname(__file__))
        fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "recipes"))
        found = os.listdir(fold)
        for name in found:
            if not name.endswith(".py") or not name.startswith("plot_"):
                continue
            reason = None

            # The version must agree with the skip message (it used to be "4.7").
            # presumably has_torch(v) is True when torch>=v is installed, verify
            if not reason and not has_torch("2.7"):
                reason = "torch<2.7"

            if reason:

                @unittest.skip(reason)
                def _test_(self, name=name):
                    res = self.run_test(fold, name, verbose=VERBOSE)
                    self.assertTrue(res)

            else:

                @ignore_errors(OSError)  # connectivity issues
                def _test_(self, name=name):
                    res = self.run_test(fold, name, verbose=VERBOSE)
                    self.assertTrue(res)

            short_name = os.path.split(os.path.splitext(name)[0])[-1]
            setattr(cls, f"test_{short_name}", _test_)
97+
98+
99+
TestDocumentationRecipes.add_test_methods()
100+
101+
if __name__ == "__main__":
102+
unittest.main(verbosity=2)

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
Functions, classes to dig into a model when this one is right, slow, wrong...
44
"""
55

6-
__version__ = "0.3.0"
6+
__version__ = "0.4.0"
77
__author__ = "Xavier Dupré"

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,8 @@ def bypass_export_some_errors(
429429
from torch.fx.experimental.symbolic_shapes import ShapeEnv
430430
from .patches.patch_torch import patched_ShapeEnv
431431

432+
ShapeEnv._log_guard_remember = ShapeEnv._log_guard
433+
432434
if verbose:
433435
print(
434436
"[bypass_export_some_errors] assert when a dynamic dimension turns static"
@@ -438,6 +440,11 @@ def bypass_export_some_errors(
438440
f_shape_env__set_replacement = ShapeEnv._set_replacement
439441
ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
440442

443+
if verbose:
444+
print("[bypass_export_some_errors] replaces ShapeEnv._log_guard")
445+
f_shape_env__log_guard = ShapeEnv._log_guard
446+
ShapeEnv._log_guard = patched_ShapeEnv._log_guard
447+
441448
if stop_if_static > 1:
442449
if verbose:
443450
print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen")
@@ -517,6 +524,12 @@ def bypass_export_some_errors(
517524
print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")
518525

519526
ShapeEnv._set_replacement = f_shape_env__set_replacement
527+
528+
if verbose:
529+
print("[bypass_export_some_errors] restored ShapeEnv._log_guard")
530+
531+
ShapeEnv._log_guard = f_shape_env__log_guard
532+
520533
if stop_if_static > 1:
521534
if verbose:
522535
print("[bypass_export_some_errors] restored ShapeEnv._check_frozen")

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
11
import inspect
22
import os
3+
import traceback
34
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
45
import torch
56
from torch._subclasses.fake_tensor import FakeTensorMode
67

78

9+
def retrieve_stacktrace():
    """
    Retrieves the current stack trace and returns it as a string
    (nothing is printed), every frame coming from a torch file is skipped.

    :return: one row ``File: <filename>, Line <lineno>, in <function>``
        per kept frame, followed by the source line when it is available,
        rows are joined with line breaks
    """
    rows = []
    for frame in traceback.extract_stack():
        # Separators are normalized so that torch files
        # are skipped on Windows as well.
        if "/torch/" in frame.filename.replace("\\", "/"):
            continue
        rows.append(f"File: {frame.filename}, Line {frame.lineno}, in {frame.name}")
        if frame.line:
            rows.append(f" {frame.line}")
    return "\n".join(rows)
21+
22+
823
def _catch_produce_guards_and_solve_constraints(
924
previous_function: Callable,
1025
fake_mode: "FakeTensorMode", # noqa: F821
@@ -339,3 +354,17 @@ def _set_replacement(
339354
# When specializing 'a == tgt', the equality should be also conveyed to
340355
# Z3, in case an expression uses 'a'.
341356
self._add_target_expr(sympy.Eq(a, tgt, evaluate=False))
357+
358+
def _log_guard(
359+
self, prefix: str, g: "SympyBoolean", forcing_spec: bool # noqa: F821
360+
) -> None:
361+
self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
362+
# It happens too often to be relevant.
363+
# sloc, _maybe_extra_debug = self._get_stack_summary(True)
364+
# warnings.warn(
365+
# f"A guard was added, prefix={prefix!r}, g={g!r}, "
366+
# f"forcing_spec={forcing_spec}, location=\n{sloc}\n"
367+
# f"--stack trace--\n{retrieve_stacktrace()}",
368+
# RuntimeWarning,
369+
# stacklevel=0,
370+
# )

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ ignore_missing_imports = true
77
packages = ["onnx_diagnostic"]
88
exclude = [
99
"^_doc/auto_examples", # skips examples in the documentation
10+
"^_doc/auto_recipes", # skips recipes in the documentation
1011
"^_doc/conf.py",
1112
"^_doc/examples",
1213
"^_unittests", # skips unit tests

0 commit comments

Comments
 (0)