From 97a2c46141ad9d3638a58090f4c95009e2063cf2 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 7 Apr 2025 17:51:35 +0200 Subject: [PATCH 1/3] add recipes --- _doc/conf.py | 2 + _doc/index.rst | 1 + _doc/recipes/README.txt | 2 + _doc/recipes/plot_dynamic_shapes_nonzero.py | 64 +++++++++++ .../recipes/plot_dynamic_shapes_python_int.py | 86 +++++++++++++++ .../ut_xrun_doc/test_documentation_recipes.py | 102 ++++++++++++++++++ onnx_diagnostic/__init__.py | 2 +- .../onnx_export_errors.py | 13 +++ .../patches/patch_torch.py | 26 +++++ pyproject.toml | 1 + 10 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 _doc/recipes/README.txt create mode 100644 _doc/recipes/plot_dynamic_shapes_nonzero.py create mode 100644 _doc/recipes/plot_dynamic_shapes_python_int.py create mode 100644 _unittests/ut_xrun_doc/test_documentation_recipes.py diff --git a/_doc/conf.py b/_doc/conf.py index f319ccbd..c1c3e749 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -147,10 +147,12 @@ # path to your examples scripts "examples_dirs": [ os.path.join(os.path.dirname(__file__), "examples"), + os.path.join(os.path.dirname(__file__), "recipes"), ], # path where to save gallery generated examples "gallery_dirs": [ "auto_examples", + "auto_recipes", ], # no parallelization to avoid conflict with environment variables "parallel": 1, diff --git a/_doc/index.rst b/_doc/index.rst index b32f8a8c..c7d2de9d 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -31,6 +31,7 @@ Source are `sdpython/onnx-diagnostic api/index cmds/index auto_examples/index + auto_recipes/index .. 
toctree:: :maxdepth: 1 diff --git a/_doc/recipes/README.txt b/_doc/recipes/README.txt new file mode 100644 index 00000000..a48db993 --- /dev/null +++ b/_doc/recipes/README.txt @@ -0,0 +1,2 @@ +Common Export Issues +==================== diff --git a/_doc/recipes/plot_dynamic_shapes_nonzero.py b/_doc/recipes/plot_dynamic_shapes_nonzero.py new file mode 100644 index 00000000..a8e7cca3 --- /dev/null +++ b/_doc/recipes/plot_dynamic_shapes_nonzero.py @@ -0,0 +1,64 @@ +""" +Half certain nonzero +==================== + +:func:`torch.nonzero` returns the indices of the nonzero values found +in a tensor. The output shape is unknown in the generic case +but... If you have a 2D tensor with at least a nonzero value +in every row, you can guess the dimension. But :func:`torch.export.export` +does not know what you know. + + +A Model ++++++++ +""" + +import torch +from onnx_diagnostic import doc + + +class Model(torch.nn.Module): + def adaptive_enc_mask(self, x_len, chunk_start_idx, left_window=0, right_window=0): + chunk_start_idx = torch.Tensor(chunk_start_idx).long() + start_pad = torch.cat((torch.tensor([0], dtype=torch.int64), chunk_start_idx), dim=0) + end_pad = torch.cat((chunk_start_idx, torch.tensor([x_len], dtype=torch.int64)), dim=0) + seq_range = torch.arange(0, x_len).unsqueeze(-1) + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] + seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + def forward(self, x): + return self.adaptive_enc_mask(x.shape[1], []) + + +model = Model() +x = torch.rand((5, 8)) +y = model(x)
+print(f"x.shape={x.shape}, y.shape={y.shape}") + +# %% +# Export +# ++++++ + +DYN = torch.export.Dim.DYNAMIC +ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),)) +print(ep) + + +# %% +# We can see the following line in the exported program. +# It tells what it cannot verify. +# ``torch.ops.aten._assert_scalar.default(eq,`` +# ``"Runtime assertion failed for expression Eq(s16, u0) on node 'eq'");`` + + +# %% +doc.plot_legend("dynamic shapes\nnonzero", "dynamic shapes", "yellow") diff --git a/_doc/recipes/plot_dynamic_shapes_python_int.py b/_doc/recipes/plot_dynamic_shapes_python_int.py new file mode 100644 index 00000000..e9c385a0 --- /dev/null +++ b/_doc/recipes/plot_dynamic_shapes_python_int.py @@ -0,0 +1,86 @@ +""" +Do not use python int with dynamic shape +========================================= + +:func:`torch.export.export` uses :class:`torch.SymInt` to operate on shapes and +optimizes the graph it produces. It checks if two tensors share the same dimension, +if the shapes can be broadcast, ... To do that, python types must not be used +or the algorithm loses information. + +Wrong Model ++++++++++++ +""" + +import math +import torch +from onnx_diagnostic import doc +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors + + +class Model(torch.nn.Module): + def dim(self, i, divisor): + return int(math.ceil(i / divisor)) # noqa: RUF046 + + def forward(self, x): + new_shape = (self.dim(x.shape[0], 8), x.shape[1]) + return torch.zeros(new_shape) + + +model = Model() +x = torch.rand((10, 15)) +y = model(x) +print(f"x.shape={x.shape}, y.shape={y.shape}") + +# %% +# Export +# ++++++ + +DYN = torch.export.Dim.DYNAMIC +ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),)) +print(ep) + +# %% +# The last dimension became static. We must not use int. +# :func:`math.ceil` should be avoided as well since it is a python operation. +# The exporter may fail to detect it is operating on shapes.
+# +# Rewrite +# +++++++ + + +class RewrittenModel(torch.nn.Module): + def dim(self, i, divisor): + return (i + divisor - 1) // divisor + + def forward(self, x): + new_shape = (self.dim(x.shape[0], 8), x.shape[1]) + return torch.zeros(new_shape) + + +rewritten_model = RewrittenModel() +y = rewritten_model(x) +print(f"x.shape={x.shape}, y.shape={y.shape}") + +# %% +# Export +# ++++++ + +ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=((DYN, DYN),)) +print(ep) + + +# %% +# Find the error +# ++++++++++++++ +# +# Function :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors` +# has a parameter ``stop_if_static`` which patches torch to raise exception +# when something like that is happening. + + +with bypass_export_some_errors(stop_if_static=True): + ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),)) + print(ep) + +# %% +doc.plot_legend("dynamic shapes\ndo not cast to\npython int", "dynamic shapes", "yellow") diff --git a/_unittests/ut_xrun_doc/test_documentation_recipes.py b/_unittests/ut_xrun_doc/test_documentation_recipes.py new file mode 100644 index 00000000..4ea8c171 --- /dev/null +++ b/_unittests/ut_xrun_doc/test_documentation_recipes.py @@ -0,0 +1,102 @@ +import unittest +import os +import sys +import importlib.util +import subprocess +import time +from onnx_diagnostic import __file__ as onnx_diagnostic_file +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + is_windows, + has_torch, + ignore_errors, +) + + +VERBOSE = 0 +ROOT = os.path.realpath(os.path.abspath(os.path.join(onnx_diagnostic_file, "..", ".."))) + + +def import_source(module_file_path, module_name): + if not os.path.exists(module_file_path): + raise FileNotFoundError(module_file_path) + module_spec = importlib.util.spec_from_file_location(module_name, module_file_path) + if module_spec is None: + raise FileNotFoundError( + "Unable to find '{}' in '{}'.".format(module_name, module_file_path) + ) + module = 
importlib.util.module_from_spec(module_spec) + return module_spec.loader.exec_module(module) + + +class TestDocumentationRecipes(ExtTestCase): + def run_test(self, fold: str, name: str, verbose=0) -> int: + ppath = os.environ.get("PYTHONPATH", "") + if not ppath: + os.environ["PYTHONPATH"] = ROOT + elif ROOT not in ppath: + sep = ";" if is_windows() else ":" + os.environ["PYTHONPATH"] = ppath + sep + ROOT + perf = time.perf_counter() + try: + mod = import_source(fold, os.path.splitext(name)[0]) + assert mod is not None + except FileNotFoundError: + # try another way + cmds = [sys.executable, "-u", os.path.join(fold, name)] + p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + res = p.communicate() + out, err = res + st = err.decode("ascii", errors="ignore") + if st and "Traceback" in st: + if '"dot" not found in path.' in st: + # dot not installed, this part + # is tested in onnx framework + if verbose: + print(f"failed: {name!r} due to missing dot.") + return 0 + raise AssertionError( # noqa: B904 + "Example '{}' (cmd: {} - exec_prefix='{}') " + "failed due to\n{}" + "".format(name, cmds, sys.exec_prefix, st) + ) + dt = time.perf_counter() - perf + if verbose: + print(f"{dt:.3f}: run {name!r}") + return 1 + + @classmethod + def add_test_methods(cls): + this = os.path.abspath(os.path.dirname(__file__)) + fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "recipes")) + found = os.listdir(fold) + for name in found: + if not name.endswith(".py") or not name.startswith("plot_"): + continue + reason = None + + if not reason and not has_torch("4.7"): + reason = "torch<2.7" + + if reason: + + @unittest.skip(reason) + def _test_(self, name=name): + res = self.run_test(fold, name, verbose=VERBOSE) + self.assertTrue(res) + + else: + + @ignore_errors(OSError) # connectivity issues + def _test_(self, name=name): + res = self.run_test(fold, name, verbose=VERBOSE) + self.assertTrue(res) + + short_name = 
os.path.split(os.path.splitext(name)[0])[-1] + setattr(cls, f"test_{short_name}", _test_) + + +TestDocumentationRecipes.add_test_methods() + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py index 6b0d3c8e..ce77dbc9 100644 --- a/onnx_diagnostic/__init__.py +++ b/onnx_diagnostic/__init__.py @@ -3,5 +3,5 @@ Functions, classes to dig into a model when this one is right, slow, wrong... """ -__version__ = "0.3.0" +__version__ = "0.4.0" __author__ = "Xavier Dupré" diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 9964a4da..2b53f5e2 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -429,6 +429,8 @@ def bypass_export_some_errors( from torch.fx.experimental.symbolic_shapes import ShapeEnv from .patches.patch_torch import patched_ShapeEnv + ShapeEnv._log_guard_remember = ShapeEnv._log_guard + if verbose: print( "[bypass_export_some_errors] assert when a dynamic dimension turns static" @@ -438,6 +440,11 @@ def bypass_export_some_errors( f_shape_env__set_replacement = ShapeEnv._set_replacement ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement + if verbose: + print("[bypass_export_some_errors] replaces ShapeEnv._log_guard") + f_shape_env__log_guard = ShapeEnv._log_guard + ShapeEnv._log_guard = patched_ShapeEnv._log_guard + if stop_if_static > 1: if verbose: print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen") @@ -517,6 +524,12 @@ def bypass_export_some_errors( print("[bypass_export_some_errors] restored ShapeEnv._set_replacement") ShapeEnv._set_replacement = f_shape_env__set_replacement + + if verbose: + print("[bypass_export_some_errors] restored ShapeEnv._log_guard") + + ShapeEnv._log_guard = f_shape_env__log_guard + if stop_if_static > 1: if verbose: print("[bypass_export_some_errors] restored 
ShapeEnv._check_frozen") diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index b487e372..1e030e78 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -1,10 +1,25 @@ import inspect import os +import traceback from typing import Any, Callable, Dict, List, Sequence, Tuple, Union import torch from torch._subclasses.fake_tensor import FakeTensorMode +def retrieve_stacktrace(): + """Retrieves the current stack trace as a string, skipping every torch file.""" + rows = [] + stack_frames = traceback.extract_stack() + for frame in stack_frames: + filename, lineno, function_name, code_line = frame + if "/torch/" in filename: + continue + rows.append(f"File: {filename}, Line {lineno}, in {function_name}") + if code_line: + rows.append(f" {code_line}") + return "\n".join(rows) + + def _catch_produce_guards_and_solve_constraints( previous_function: Callable, fake_mode: "FakeTensorMode", # noqa: F821 @@ -339,3 +354,14 @@ def _set_replacement( # When specializing 'a == tgt', the equality should be also conveyed to # Z3, in case an expression uses 'a'.
self._add_target_expr(sympy.Eq(a, tgt, evaluate=False)) + + def _log_guard( + self, prefix: str, g: "SympyBoolean", forcing_spec: bool # noqa: F821 + ) -> None: + self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec) + sloc, _maybe_extra_debug = self._get_stack_summary(True) + raise AssertionError( + f"A guard was added, prefix={prefix!r}, g={g!r}, " + f"forcing_spec={forcing_spec}, location=\n{sloc}\n" + f"--stack trace--\n{retrieve_stacktrace()}" + ) diff --git a/pyproject.toml b/pyproject.toml index ba920b32..79fdfd4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ ignore_missing_imports = true packages = ["onnx_diagnostic"] exclude = [ "^_doc/auto_examples", # skips examples in the documentation + "^_doc/auto_recipes", # skips examples in the documentation "^_doc/conf.py", "^_doc/examples", "^_unittests", # skips unit tests From 78b99c3d9f10b0fa34627189c3106a4df798c9f9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 7 Apr 2025 18:12:26 +0200 Subject: [PATCH 2/3] stacklevel --- .../torch_export_patches/patches/patch_torch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 1e030e78..62f3b7f3 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -1,6 +1,7 @@ import inspect import os import traceback +import warnings from typing import Any, Callable, Dict, List, Sequence, Tuple, Union import torch from torch._subclasses.fake_tensor import FakeTensorMode @@ -360,8 +361,10 @@ def _log_guard( ) -> None: self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec) sloc, _maybe_extra_debug = self._get_stack_summary(True) - raise AssertionError( + warnings.warn( f"A guard was added, prefix={prefix!r}, g={g!r}, " f"forcing_spec={forcing_spec}, location=\n{sloc}\n" - f"--stack 
trace--\n{retrieve_stacktrace()}" + f"--stack trace--\n{retrieve_stacktrace()}", + RuntimeWarning, + stacklevel=0, ) From 5ccd38c9f8f3232bd8a175fdd734ce91ad4ba6a1 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 7 Apr 2025 18:14:30 +0200 Subject: [PATCH 3/3] fix issue --- .../patches/patch_torch.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 62f3b7f3..0df84c00 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -1,7 +1,6 @@ import inspect import os import traceback -import warnings from typing import Any, Callable, Dict, List, Sequence, Tuple, Union import torch from torch._subclasses.fake_tensor import FakeTensorMode @@ -360,11 +359,12 @@ def _log_guard( self, prefix: str, g: "SympyBoolean", forcing_spec: bool # noqa: F821 ) -> None: self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec) - sloc, _maybe_extra_debug = self._get_stack_summary(True) - warnings.warn( - f"A guard was added, prefix={prefix!r}, g={g!r}, " - f"forcing_spec={forcing_spec}, location=\n{sloc}\n" - f"--stack trace--\n{retrieve_stacktrace()}", - RuntimeWarning, - stacklevel=0, - ) + # It happens too often to be relevant. + # sloc, _maybe_extra_debug = self._get_stack_summary(True) + # warnings.warn( + # f"A guard was added, prefix={prefix!r}, g={g!r}, " + # f"forcing_spec={forcing_spec}, location=\n{sloc}\n" + # f"--stack trace--\n{retrieve_stacktrace()}", + # RuntimeWarning, + # stacklevel=0, + # )