2 changes: 2 additions & 0 deletions _doc/conf.py
@@ -147,10 +147,12 @@
# path to your examples scripts
"examples_dirs": [
os.path.join(os.path.dirname(__file__), "examples"),
os.path.join(os.path.dirname(__file__), "recipes"),
],
# path where to save gallery generated examples
"gallery_dirs": [
"auto_examples",
"auto_recipes",
],
# no parallelization to avoid conflict with environment variables
"parallel": 1,
1 change: 1 addition & 0 deletions _doc/index.rst
@@ -31,6 +31,7 @@ Source are `sdpython/onnx-diagnostic
api/index
cmds/index
auto_examples/index
auto_recipes/index

.. toctree::
:maxdepth: 1
2 changes: 2 additions & 0 deletions _doc/recipes/README.txt
@@ -0,0 +1,2 @@
Common Export Issues
====================
64 changes: 64 additions & 0 deletions _doc/recipes/plot_dynamic_shapes_nonzero.py
@@ -0,0 +1,64 @@
"""
Half certain nonzero
====================

:func:`torch.nonzero` returns the indices of the nonzero values
in a tensor. In the generic case the output shape is unknown
because it depends on the data. But if you have a 2D tensor with at least
one nonzero value in every row, you can guess the dimension yourself.
:func:`torch.export.export`, however, does not know what you know.


A Model
+++++++
"""

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
def adaptive_enc_mask(self, x_len, chunk_start_idx, left_window=0, right_window=0):
chunk_start_idx = torch.Tensor(chunk_start_idx).long()
start_pad = torch.cat((torch.tensor([0], dtype=torch.int64), chunk_start_idx), dim=0)
end_pad = torch.cat((chunk_start_idx, torch.tensor([x_len], dtype=torch.int64)), dim=0)
seq_range = torch.arange(0, x_len).unsqueeze(-1)
idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
idx_left = idx - left_window
idx_left[idx_left < 0] = 0
boundary_left = start_pad[idx_left]
mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
idx_right = idx + right_window
idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
boundary_right = end_pad[idx_right]
mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
return mask_left & mask_right

def forward(self, x):
return self.adaptive_enc_mask(x.shape[1], [])


model = Model()
x = torch.rand((5, 8))
y = model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")

# %%
# Export
# ++++++

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
print(ep)


# %%
# We can see the following line in the exported program.
# It shows what the exporter could not verify at export time.
# ``torch.ops.aten._assert_scalar.default(eq,``
# ``"Runtime assertion failed for expression Eq(s16, u0) on node 'eq'");``


# %%
doc.plot_legend("dynamic shapes\nnonzero", "dynamic shapes", "yellow")
86 changes: 86 additions & 0 deletions _doc/recipes/plot_dynamic_shapes_python_int.py
@@ -0,0 +1,86 @@
"""
Do not use Python int with dynamic shapes
==========================================

:func:`torch.export.export` uses :class:`torch.SymInt` to operate on shapes and
to optimize the graph it produces. It checks whether two tensors share the same
dimension, whether shapes can be broadcast, ... For that to work, Python types
must not be used or the algorithm loses information.

Wrong Model
+++++++++++
"""

import math
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors


class Model(torch.nn.Module):
def dim(self, i, divisor):
return int(math.ceil(i / divisor)) # noqa: RUF046

def forward(self, x):
new_shape = (self.dim(x.shape[0], 8), x.shape[1])
return torch.zeros(new_shape)


model = Model()
x = torch.rand((10, 15))
y = model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")

# %%
# Export
# ++++++

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
print(ep)

# %%
# The first dimension became static. We must not use a plain Python ``int``.
# :func:`math.ceil` should be avoided as well since it is a Python operation.
# The exporter may fail to detect that it is operating on shapes.
#
# Rewrite
# +++++++


class RewrittenModel(torch.nn.Module):
def dim(self, i, divisor):
return (i + divisor - 1) // divisor

def forward(self, x):
new_shape = (self.dim(x.shape[0], 8), x.shape[1])
return torch.zeros(new_shape)


rewritten_model = RewrittenModel()
y = rewritten_model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")

# %%
# Export
# ++++++

ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=((DYN, DYN),))
print(ep)


# %%
# Find the error
# ++++++++++++++
#
# Function :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
# has a parameter ``stop_if_static`` which patches torch to raise an exception
# when a dynamic dimension turns static.


with bypass_export_some_errors(stop_if_static=True):
ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
print(ep)

# %%
doc.plot_legend("dynamic shapes\ndo not cast to\npython int", "dynamic shapes", "yellow")
102 changes: 102 additions & 0 deletions _unittests/ut_xrun_doc/test_documentation_recipes.py
@@ -0,0 +1,102 @@
import unittest
import os
import sys
import importlib.util
import subprocess
import time
from onnx_diagnostic import __file__ as onnx_diagnostic_file
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
is_windows,
has_torch,
ignore_errors,
)


VERBOSE = 0
ROOT = os.path.realpath(os.path.abspath(os.path.join(onnx_diagnostic_file, "..", "..")))


def import_source(module_file_path, module_name):
if not os.path.exists(module_file_path):
raise FileNotFoundError(module_file_path)
module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
if module_spec is None:
raise FileNotFoundError(
"Unable to find '{}' in '{}'.".format(module_name, module_file_path)
)
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
return module


class TestDocumentationRecipes(ExtTestCase):
def run_test(self, fold: str, name: str, verbose=0) -> int:
ppath = os.environ.get("PYTHONPATH", "")
if not ppath:
os.environ["PYTHONPATH"] = ROOT
elif ROOT not in ppath:
sep = ";" if is_windows() else ":"
os.environ["PYTHONPATH"] = ppath + sep + ROOT
perf = time.perf_counter()
try:
mod = import_source(fold, os.path.splitext(name)[0])
assert mod is not None
except FileNotFoundError:
# try another way
cmds = [sys.executable, "-u", os.path.join(fold, name)]
p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
res = p.communicate()
out, err = res
st = err.decode("ascii", errors="ignore")
if st and "Traceback" in st:
if '"dot" not found in path.' in st:
# dot not installed, this part
# is tested in onnx framework
if verbose:
print(f"failed: {name!r} due to missing dot.")
return 0
raise AssertionError( # noqa: B904
"Example '{}' (cmd: {} - exec_prefix='{}') "
"failed due to\n{}"
"".format(name, cmds, sys.exec_prefix, st)
)
dt = time.perf_counter() - perf
if verbose:
print(f"{dt:.3f}: run {name!r}")
return 1

@classmethod
def add_test_methods(cls):
this = os.path.abspath(os.path.dirname(__file__))
fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "recipes"))
found = os.listdir(fold)
for name in found:
if not name.endswith(".py") or not name.startswith("plot_"):
continue
reason = None

if not reason and not has_torch("2.7"):
reason = "torch<2.7"

if reason:

@unittest.skip(reason)
def _test_(self, name=name):
res = self.run_test(fold, name, verbose=VERBOSE)
self.assertTrue(res)

else:

@ignore_errors(OSError) # connectivity issues
def _test_(self, name=name):
res = self.run_test(fold, name, verbose=VERBOSE)
self.assertTrue(res)

short_name = os.path.split(os.path.splitext(name)[0])[-1]
setattr(cls, f"test_{short_name}", _test_)


TestDocumentationRecipes.add_test_methods()

if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@
Functions, classes to dig into a model when this one is right, slow, wrong...
"""

__version__ = "0.3.0"
__version__ = "0.4.0"
__author__ = "Xavier Dupré"
13 changes: 13 additions & 0 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -429,6 +429,8 @@ def bypass_export_some_errors(
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from .patches.patch_torch import patched_ShapeEnv

ShapeEnv._log_guard_remember = ShapeEnv._log_guard

if verbose:
print(
"[bypass_export_some_errors] assert when a dynamic dimension turns static"
@@ -438,6 +440,11 @@
f_shape_env__set_replacement = ShapeEnv._set_replacement
ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement

if verbose:
print("[bypass_export_some_errors] replaces ShapeEnv._log_guard")
f_shape_env__log_guard = ShapeEnv._log_guard
ShapeEnv._log_guard = patched_ShapeEnv._log_guard

if stop_if_static > 1:
if verbose:
print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen")
@@ -517,6 +524,12 @@ def bypass_export_some_errors(
print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")

ShapeEnv._set_replacement = f_shape_env__set_replacement

if verbose:
print("[bypass_export_some_errors] restored ShapeEnv._log_guard")

ShapeEnv._log_guard = f_shape_env__log_guard

if stop_if_static > 1:
if verbose:
print("[bypass_export_some_errors] restored ShapeEnv._check_frozen")
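
The change above follows the same save/patch/restore pattern already used for
``ShapeEnv._set_replacement``: keep a reference to the original method, install the patched
one, and put the original back when the context manager exits. A simplified, generic sketch
of that pattern (not the actual implementation, the helper name ``swap_method`` is hypothetical):

from contextlib import contextmanager


@contextmanager
def swap_method(cls, name, replacement):
    # keep the original so it can always be restored, even if an exception is raised
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield
    finally:
        setattr(cls, name, original)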
29 changes: 29 additions & 0 deletions onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -1,10 +1,25 @@
import inspect
import os
import traceback
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


def retrieve_stacktrace():
"""Retrieves and prints the current stack trace, avoids every torch file."""
rows = []
stack_frames = traceback.extract_stack()
for frame in stack_frames:
filename, lineno, function_name, code_line = frame
if "/torch/" in filename:
continue
rows.append(f"File: {filename}, Line {lineno}, in {function_name}")
if code_line:
rows.append(f" {code_line}")
return "\n".join(rows)


def _catch_produce_guards_and_solve_constraints(
previous_function: Callable,
fake_mode: "FakeTensorMode", # noqa: F821
@@ -339,3 +354,17 @@ def _set_replacement(
# When specializing 'a == tgt', the equality should be also conveyed to
# Z3, in case an expression uses 'a'.
self._add_target_expr(sympy.Eq(a, tgt, evaluate=False))

def _log_guard(
self, prefix: str, g: "SympyBoolean", forcing_spec: bool # noqa: F821
) -> None:
self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
# It happens too often to be relevant.
# sloc, _maybe_extra_debug = self._get_stack_summary(True)
# warnings.warn(
# f"A guard was added, prefix={prefix!r}, g={g!r}, "
# f"forcing_spec={forcing_spec}, location=\n{sloc}\n"
# f"--stack trace--\n{retrieve_stacktrace()}",
# RuntimeWarning,
# stacklevel=0,
# )
1 change: 1 addition & 0 deletions pyproject.toml
@@ -7,6 +7,7 @@ ignore_missing_imports = true
packages = ["onnx_diagnostic"]
exclude = [
"^_doc/auto_examples", # skips examples in the documentation
"^_doc/auto_recipes", # skips examples in the documentation
"^_doc/conf.py",
"^_doc/examples",
"^_unittests", # skips unit tests