diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 2f02ddd7..42bc5036 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.5.0
 +++++
 
+* :pr:`93`: introduce patched expressions to work around export issues
 * :pr:`92`: support errors distribution in max_diff
 * :pr:`91`: enable strings in ``guess_dynamic_shapes``
 * :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models
diff --git a/_doc/api/helpers/helper.rst b/_doc/api/helpers/helper.rst
index 2def9207..8c69e27d 100644
--- a/_doc/api/helpers/helper.rst
+++ b/_doc/api/helpers/helper.rst
@@ -4,4 +4,4 @@ onnx_diagnostic.helpers.helper
 
 .. automodule:: onnx_diagnostic.helpers.helper
     :no-undoc-members:
-    :exclude-members: max_diff, string_diff, string_sig, string_type
+    :exclude-members: flatten_object, max_diff, string_diff, string_sig, string_type
diff --git a/_doc/api/helpers/index.rst b/_doc/api/helpers/index.rst
index 48195bb9..ede2174e 100644
--- a/_doc/api/helpers/index.rst
+++ b/_doc/api/helpers/index.rst
@@ -18,6 +18,8 @@ onnx_diagnostic.helpers
     rt_helper
     torch_test_helper
 
+.. autofunction:: onnx_diagnostic.helpers.flatten_object
+
 .. autofunction:: onnx_diagnostic.helpers.max_diff
 
 .. autofunction:: onnx_diagnostic.helpers.string_diff
diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst
index d47ec2f1..60c51d1d 100644
--- a/_doc/api/torch_export_patches/index.rst
+++ b/_doc/api/torch_export_patches/index.rst
@@ -6,6 +6,7 @@ onnx_diagnostic.torch_export_patches
     :caption: submodules
 
     patches/index
+    patch_expressions
     patch_inputs
     patch_module
 
diff --git a/_doc/api/torch_export_patches/patch_expressions.rst b/_doc/api/torch_export_patches/patch_expressions.rst
new file mode 100644
index 00000000..f8050ec7
--- /dev/null
+++ b/_doc/api/torch_export_patches/patch_expressions.rst
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.torch_export_patches.patch_expressions
+======================================================
+
+.. automodule:: onnx_diagnostic.torch_export_patches.patch_expressions
+    :members:
+    :no-undoc-members:
diff --git a/_unittests/ut_helpers/test_mini_onnx_builder.py b/_unittests/ut_helpers/test_mini_onnx_builder.py
index 78005203..19ae1d18 100644
--- a/_unittests/ut_helpers/test_mini_onnx_builder.py
+++ b/_unittests/ut_helpers/test_mini_onnx_builder.py
@@ -151,6 +151,48 @@ def test_mini_onnx_builder_transformers_sep(self):
         restored = create_input_tensors_from_onnx_model(model, sep="#")
         self.assertEqualAny(inputs, restored)
 
+    def test_specific_data(self):
+        data = {
+            ("amain", 0, "I"): (
+                (
+                    torch.rand((2, 16, 3, 448, 448), dtype=torch.float16),
+                    torch.rand((2, 16, 32, 32), dtype=torch.float16),
+                    torch.rand((2, 2)).to(torch.int64),
+                ),
+                {},
+            ),
+        }
+        model = create_onnx_model_from_input_tensors(data)
+        shapes = [
+            tuple(d.dim_value for d in i.type.tensor_type.shape.dim)
+            for i in model.graph.output
+        ]
+        self.assertEqual(shapes, [(2, 16, 3, 448, 448), (2, 16, 32, 32), (2, 2), (0,)])
+        names = [i.name for i in model.graph.output]
+        self.assertEqual(
+            [
+                "dict._((amain,0,I))___tuple_0___tuple_0___tensor",
+                "dict._((amain,0,I))___tuple_0___tuple_1___tensor",
+                "dict._((amain,0,I))___tuple_0___tuple_2.___tensor",
+                "dict._((amain,0,I))___tuple_1.___dict.___empty",
+            ],
+            names,
+        )
+        shapes = [tuple(i.dims) for i in model.graph.initializer]
+        self.assertEqual(shapes, [(2, 16, 3, 448, 448), (2, 16, 32, 32), (2, 2), (0,)])
+        names = [i.name for i in model.graph.initializer]
+        self.assertEqual(
+            [
+                "t_dict._((amain,0,I))___tuple_0___tuple_0___tensor",
+                "t_dict._((amain,0,I))___tuple_0___tuple_1___tensor",
+                "t_dict._((amain,0,I))___tuple_0___tuple_2.___tensor",
+                "t_dict._((amain,0,I))___tuple_1.___dict.___empty",
+            ],
+            names,
+        )
+        restored = create_input_tensors_from_onnx_model(model)
+        self.assertEqualAny(data, restored)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_torch_export_patches/test_patch_expressions.py b/_unittests/ut_torch_export_patches/test_patch_expressions.py
new file mode 100644
index 00000000..610d9757
--- /dev/null
+++ b/_unittests/ut_torch_export_patches/test_patch_expressions.py
@@ -0,0 +1,46 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.torch_export_patches.patch_expressions import (
+    _iterate_patched_expressions,
+    register_patched_expressions,
+    patched_selector,
+    patched_float_arange,
+)
+from onnx_diagnostic.helpers.torch_test_helper import fake_torchdynamo_exporting
+
+
+class TestOnnxExportErrors(ExtTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        register_patched_expressions()
+
+    def test_patched_expressions(self):
+        res = list(_iterate_patched_expressions())
+        names = {_[0] for _ in res}
+        self.assertIn("float_arange", names)
+
+    def test_float_arange(self):
+        _T = torch.tensor
+        res = torch.arange(4, 6, 0.234)
+        got = torch.arange(4, 6, 0.234, dtype=torch.float32, device=torch.device("cpu"))
+        self.assertEqualArray(res, got)
+        got = torch.ops.patched.float_arange(_T(4.0), _T(6.0), _T(0.234))
+        self.assertEqualArray(res, got, atol=1e-5)
+        got = patched_selector(
+            (lambda a, b, c: torch.arange(a.item(), b.item(), c.item())),
+            torch.ops.patched.float_arange,
+        )(_T(4.0), _T(6.0), _T(0.234))
+        self.assertEqualArray(res, got, atol=1e-5)
+        got = patched_float_arange(_T(4.0), _T(6.0), _T(0.234))
+        self.assertEqualArray(res, got, atol=1e-5)
+        with fake_torchdynamo_exporting():
+            got = patched_selector(None, torch.ops.patched.float_arange)(
+                _T(4.0), _T(6.0), _T(0.234)
+            )
+            self.assertEqualArray(res, got, atol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/_unittests/ut_torch_export_patches/test_patch_loops.py b/_unittests/ut_torch_export_patches/test_patch_loops.py
new file mode 100644
index 00000000..346d9cb4
--- /dev/null
+++ b/_unittests/ut_torch_export_patches/test_patch_loops.py
@@ -0,0 +1,124 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
+from onnx_diagnostic.helpers.torch_test_helper import (
+    is_torchdynamo_exporting,
+    fake_torchdynamo_exporting,
+)
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.torch_export_patches.patch_expressions import (
+    _iterate_patched_expressions,
+    register_patched_expressions,
+    patched_float_arange,
+)
+
+
+class TestOnnxExportErrors(ExtTestCase):
+
+    def test_patched_expressions(self):
+        res = list(_iterate_patched_expressions())
+        names = {_[0] for _ in res}
+        self.assertIn("float_arange", names)
+
+    @requires_torch("2.8")
+    def test_filter_position_ids(self):
+
+        def filter_position_ids(
+            patch_attention_mask: torch.Tensor,
+            position_ids: torch.Tensor,
+            boundaries: torch.Tensor,
+            num_patches_per_side: int,
+        ):
+            for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+                fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum())
+                fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum())
+
+                bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+                bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+                pos_ids = (
+                    bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
+                ).flatten()
+                position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
+            return position_ids
+
+        def float_arange(start, end, step):
+            length = torch.sym_int((end - start) / step + (step * (1 - 1e-6)))
+            torch._check(length > 0)
+            res = torch.arange(0, length)
+            torch._check(res.is_contiguous())
+            fres = res.to(torch.float32)
+            fstart = torch.tensor(start, dtype=torch.float32)
+            return fres * step + fstart
+
+        def scan_filter_position_ids(
+            patch_attention_mask: torch.Tensor,
+            position_ids: torch.Tensor,
+            boundaries: torch.Tensor,
+            num_patches_per_side: int,
+        ):
+
+            def body(p_attn_mask, position_ids_row):
+                h_len = torch.tensor(1) / p_attn_mask[:, 0].sum()
+                w_len = torch.tensor(1) / p_attn_mask[0].sum()
+                fractional_coords_h = patched_float_arange(
+                    torch.tensor(0.0), torch.tensor(1 - 1e-6), h_len
+                )
+                fractional_coords_w = patched_float_arange(
+                    torch.tensor(0.0), torch.tensor(1 - 1e-6), w_len
+                )
+
+                # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum().item())
+                # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum().item())
+
+                bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+                bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+                pos_ids = (
+                    bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
+                ).flatten()
+
+                row = position_ids_row.clone()
+                row[p_attn_mask.view(-1)] = pos_ids
+                return [row]
+
+            return torch.ops.higher_order.scan(
+                body, [], [patch_attention_mask, position_ids], additional_inputs=[]
+            )
+
+        class Model(torch.nn.Module):
+            def forward(self, patch_attention_mask, position_ids, boundaries):
+                if is_torchdynamo_exporting():
+                    res = scan_filter_position_ids(
+                        patch_attention_mask, position_ids, boundaries, 32
+                    )
+                    return res[0]
+                return filter_position_ids(patch_attention_mask, position_ids, boundaries, 32)
+
+        # observed inputs (string_type), num_patches_per_side=32:
+        # T9s32x32x32[False,True:A0.978515625],
+        # T7s32x1024[0,0:A0.0],
+        # T1s31[0.03125,0.96875:A0.5]]
+        register_patched_expressions()
+        patch_attention_mask = torch.randint(0, 20, (32, 32, 32)) >= 1
+        patch_attention_mask[:, :, :] = True  # uniform mask: identical shapes per scan step
+        position_ids = torch.zeros((32, 1024), dtype=torch.int64)
+        boundaries = (torch.arange(33).to(torch.float32) / 32)[1:-1]
+        inputs = (patch_attention_mask, position_ids, boundaries)
+        model = Model()
+        expected = model(*inputs)
+        with fake_torchdynamo_exporting():
+            got = model(*inputs)
+        self.assertEqual(type(expected), type(got))
+        self.assertEqual(
+            string_type(expected, with_shape=True), string_type(got, with_shape=True)
+        )
+        self.assertEqualArray(expected, got)
+
+        DYN = torch.export.Dim.DYNAMIC
+        ep = torch.export.export(model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN}))
+        self.assertEqualArray(expected, ep.module()(*inputs))
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/helpers/__init__.py b/onnx_diagnostic/helpers/__init__.py
index 6467d287..2eb5c8b0 100644
--- a/onnx_diagnostic/helpers/__init__.py
+++ b/onnx_diagnostic/helpers/__init__.py
@@ -1 +1 @@
-from .helper import max_diff, string_diff, string_sig, string_type
+from .helper import flatten_object, max_diff, string_diff, string_sig, string_type
diff --git a/onnx_diagnostic/helpers/mini_onnx_builder.py b/onnx_diagnostic/helpers/mini_onnx_builder.py
index 8d868d08..df3df7a8 100644
--- a/onnx_diagnostic/helpers/mini_onnx_builder.py
+++ b/onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -139,6 +139,9 @@ def append_output_initializer(
             return
 
         init_name = f"t_{name}"
+        assert (
+            init_name not in self.initializers_dict
+        ), f"name={init_name!r} already in {sorted(self.initializers_dict)}"
         self.initializers_dict[init_name] = tensor
         shape = tuple(map(int, tensor.shape))
         self.outputs.append(
@@ -324,10 +327,10 @@ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
         for i, o in enumerate(obj):
             if i == len(obj) - 1:
                 for p, oo in _flatten_iterator(o, sep):
-                    yield f"tuple.{sep}{p}", oo
+                    yield f"tuple_{i}.{sep}{p}", oo
             else:
                 for p, oo in _flatten_iterator(o, sep):
-                    yield f"tuple{sep}{p}", oo
+                    yield f"tuple_{i}{sep}{p}", oo
     elif isinstance(obj, list):
         if not obj:
             yield f"list.{sep}empty", None
@@ -335,10 +338,10 @@ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
         for i, o in enumerate(obj):
             if i == len(obj) - 1:
                 for p, oo in _flatten_iterator(o, sep):
-                    yield f"list.{sep}{p}", oo
+                    yield f"list_{i}.{sep}{p}", oo
             else:
                 for p, oo in _flatten_iterator(o, sep):
-                    yield f"list{sep}{p}", oo
+                    yield f"list_{i}{sep}{p}", oo
     elif isinstance(obj, dict):
         if not obj:
             yield f"dict.{sep}empty", None
diff --git a/onnx_diagnostic/helpers/ort_session.py b/onnx_diagnostic/helpers/ort_session.py
index 39a0d45d..54a7874d 100644
--- a/onnx_diagnostic/helpers/ort_session.py
+++ b/onnx_diagnostic/helpers/ort_session.py
@@ -101,11 +101,27 @@ def __init__(
                 providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
            else:
                 raise ValueError(f"Unexpected value for providers={providers!r}")
-            sess = onnxruntime.InferenceSession(
-                sess if isinstance(sess, str) else sess.SerializeToString(),
-                session_options,
-                providers=providers,
-            )
+            try:
+                sess = onnxruntime.InferenceSession(
+                    sess if isinstance(sess, str) else sess.SerializeToString(),
+                    session_options,
+                    providers=providers,
+                )
+            except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
+                if isinstance(sess, onnx.ModelProto):
+                    debug_path = "_debug_onnxruntime_evaluator_failure.onnx"
+                    onnx.save(
+                        sess,
+                        debug_path,
+                        save_as_external_data=True,
+                        all_tensors_to_one_file=True,
+                    )
+                else:
+                    debug_path = sess
+                raise RuntimeError(
+                    f"Unable to create a session from the model stored in {debug_path!r}, "
+                    f"providers={providers}"
+                ) from e
         else:
             assert (
                 session_options is None
diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py
index aa349eb7..fd69f38b 100644
--- a/onnx_diagnostic/helpers/torch_test_helper.py
+++ b/onnx_diagnostic/helpers/torch_test_helper.py
@@ -151,11 +151,25 @@ def forward(self, x, y):
     onnx.save(
         proto,
         dump_file,
-        save_as_external_data=False,
-        all_tensors_to_one_file=True,
+        save_as_external_data=True,
+        all_tensors_to_one_file=False,
     )
 
 
+@contextlib.contextmanager
+def fake_torchdynamo_exporting():
+    """
+    Sets ``torch.compiler._is_exporting_flag`` to True to trigger
+    pieces of code only enabled during export.
+    """
+    memorize = torch.compiler._is_exporting_flag
+    torch.compiler._is_exporting_flag = True
+    try:
+        yield
+    finally:
+        torch.compiler._is_exporting_flag = memorize
+
+
 def is_torchdynamo_exporting() -> bool:
     """
     Tells if :epkg:`torch` is exporting a model.
diff --git a/onnx_diagnostic/torch_export_patches/patch_expressions.py b/onnx_diagnostic/torch_export_patches/patch_expressions.py
new file mode 100644
index 00000000..a25f0d5f
--- /dev/null
+++ b/onnx_diagnostic/torch_export_patches/patch_expressions.py
@@ -0,0 +1,108 @@
+from typing import Callable, Set
+import torch
+from ..helpers.torch_test_helper import is_torchdynamo_exporting
+
+
+def make_undefined_dimension(i: int) -> torch.SymInt:
+    """
+    Used in the shape function of a custom op when a new dimension
+    must be introduced to bypass some verification. The example below
+    creates a dummy output with such a dimension.
+
+    .. code-block:: python
+
+        def symbolic_shape(x, y):
+            return torch.empty(
+                x.shape[0],
+                make_undefined_dimension(min(x.shape[1], y[0])),
+            )
+    """
+    try:
+        ti = int(i)
+    except:  # noqa: E722
+        ti = 10
+    t = torch.ones((ti * 2,))
+    t[:ti] = 0
+    res = torch.nonzero(t).shape[0]
+    return res
+
+
+def _patched_float_arange(
+    start: torch.Tensor, end: torch.Tensor, step: torch.Tensor
+) -> torch.Tensor:
+    """Float arange."""
+    return torch.arange(
+        float(start.item()),
+        float(end.item()),
+        float(step.item()),
+        dtype=start.dtype,
+        device=start.device,
+    )
+
+
+def _patched_float_arange_shape(start, end, step):
+    # Fails because:
+    # Did you accidentally call new_dynamic_size() or item()
+    # more times than you needed to in your fake implementation?
+    # try:
+    #     n = math.ceil(((end - start) / step).item())
+    # except:  # noqa: E722
+    #     n = 10
+    n = 10
+    return torch.empty((make_undefined_dimension(n),), dtype=start.dtype, device=start.device)
+
+
+def _iterate_patched_expressions():
+    glo = globals().copy()
+    for k, _v in glo.items():
+        if k.startswith("_patched_") and not k.endswith("_shape"):
+            name = k
+            yield k[len("_patched_") :], glo[name], glo[f"{name}_shape"]
+
+
+_registered: Set[str] = set()
+
+
+def _register_patched_expression(
+    fct: Callable, fct_shape: Callable, namespace: str, fname: str
+):
+    schema_str = torch.library.infer_schema(fct, mutates_args=())
+    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
+    custom_def.register_kernel("cpu")(fct)
+    custom_def._abstract_fn = fct_shape
+
+
+def register_patched_expressions(namespace: str = "patched"):
+    """
+    Registers known expressions failing due to dynamic shapes as custom ops.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        from onnx_diagnostic.torch_export_patches.patch_expressions import (
+            _iterate_patched_expressions,
+        )
+
+        pprint.pprint([name for name, _f, _fsh in _iterate_patched_expressions()])
+    """
+    for name, f, fsh in _iterate_patched_expressions():
+        if name not in _registered:
+            _register_patched_expression(f, fsh, namespace, name)
+            _registered.add(name)
+
+
+def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:
+    """
+    Returns **fct** if the model is being executed or
+    **patched_fct** if it is being exported.
+    """
+    return patched_fct if is_torchdynamo_exporting() else fct
+
+
+def patched_float_arange(start, end, step):
+    """Patched arange when start, end, step are float scalars (0-D tensors)."""
+    if is_torchdynamo_exporting():
+        return torch.ops.patched.float_arange(start, end, step)
+    else:
+        return torch.arange(start, end, step)
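
A minimal usage sketch of the new patched-expression API added above; the
``Ramp`` module is hypothetical, everything else comes from the new files.
``torch.export`` cannot trace a float arange whose bounds are data-dependent,
so the forward goes through the registered custom op while exporting and
through plain ``torch.arange`` in eager mode:

import torch
from onnx_diagnostic.torch_export_patches.patch_expressions import (
    patched_float_arange,
    register_patched_expressions,
)

register_patched_expressions()  # registers torch.ops.patched.float_arange


class Ramp(torch.nn.Module):  # hypothetical example model
    def forward(self, start, end, step):
        # eager: torch.arange(start, end, step); export: the custom op
        return patched_float_arange(start, end, step)


_T = torch.tensor
inputs = (_T(4.0), _T(6.0), _T(0.234))
model = Ramp()
expected = model(*inputs)
ep = torch.export.export(model, inputs)
torch.testing.assert_close(expected, ep.module()(*inputs))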
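
Similarly, a small sketch of how ``fake_torchdynamo_exporting`` pairs with
``patched_selector`` to test the export-only branch without running a full
export; ``eager_impl`` and ``export_impl`` are hypothetical placeholders, and
the context manager flips the flag read by ``is_torchdynamo_exporting`` and
restores it on exit:

import torch
from onnx_diagnostic.helpers.torch_test_helper import (
    fake_torchdynamo_exporting,
    is_torchdynamo_exporting,
)
from onnx_diagnostic.torch_export_patches.patch_expressions import patched_selector


def eager_impl(x):
    return x + 1  # stands for a data-dependent eager implementation


def export_impl(x):
    return x + 2  # stands for the export-friendly rewrite


x = torch.zeros(3)
assert not is_torchdynamo_exporting()
print(patched_selector(eager_impl, export_impl)(x))  # tensor([1., 1., 1.])
with fake_torchdynamo_exporting():
    print(patched_selector(eager_impl, export_impl)(x))  # tensor([2., 2., 2.])
assert not is_torchdynamo_exporting()  # flag restored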