Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.5.0
+++++

* :pr:`93`: introduce patched expression to get around annoying export issues
* :pr:`92`: support errors distribution in max_diff
* :pr:`91`: enable strings in ``guess_dynamic_shapes``
* :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models
Expand Down
2 changes: 1 addition & 1 deletion _doc/api/helpers/helper.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ onnx_diagnostic.helpers.helper

.. automodule:: onnx_diagnostic.helpers.helper
:no-undoc-members:
:exclude-members: max_diff, string_diff, string_sig, string_type
:exclude-members: flatten_object, max_diff, string_diff, string_sig, string_type
2 changes: 2 additions & 0 deletions _doc/api/helpers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ onnx_diagnostic.helpers
rt_helper
torch_test_helper

.. autofunction:: onnx_diagnostic.helpers.flatten_object

.. autofunction:: onnx_diagnostic.helpers.max_diff

.. autofunction:: onnx_diagnostic.helpers.string_diff
Expand Down
1 change: 1 addition & 0 deletions _doc/api/torch_export_patches/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ onnx_diagnostic.torch_export_patches
:caption: submodules

patches/index
patch_expressions
patch_inputs
patch_module

Expand Down
7 changes: 7 additions & 0 deletions _doc/api/torch_export_patches/patch_expressions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

onnx_diagnostic.torch_export_patches.patch_expressions
======================================================

.. automodule:: onnx_diagnostic.torch_export_patches.patch_expressions
:members:
:no-undoc-members:
42 changes: 42 additions & 0 deletions _unittests/ut_helpers/test_mini_onnx_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,48 @@ def test_mini_onnx_builder_transformers_sep(self):
restored = create_input_tensors_from_onnx_model(model, sep="#")
self.assertEqualAny(inputs, restored)

def test_specific_data(self):
data = {
("amain", 0, "I"): (
(
torch.rand((2, 16, 3, 448, 448), dtype=torch.float16),
torch.rand((2, 16, 32, 32), dtype=torch.float16),
torch.rand((2, 2)).to(torch.int64),
),
{},
),
}
model = create_onnx_model_from_input_tensors(data)
shapes = [
tuple(d.dim_value for d in i.type.tensor_type.shape.dim)
for i in model.graph.output
]
self.assertEqual(shapes, [(2, 16, 3, 448, 448), (2, 16, 32, 32), (2, 2), (0,)])
names = [i.name for i in model.graph.output]
self.assertEqual(
[
"dict._((amain,0,I))___tuple_0___tuple_0___tensor",
"dict._((amain,0,I))___tuple_0___tuple_1___tensor",
"dict._((amain,0,I))___tuple_0___tuple_2.___tensor",
"dict._((amain,0,I))___tuple_1.___dict.___empty",
],
names,
)
shapes = [tuple(i.dims) for i in model.graph.initializer]
self.assertEqual(shapes, [(2, 16, 3, 448, 448), (2, 16, 32, 32), (2, 2), (0,)])
names = [i.name for i in model.graph.initializer]
self.assertEqual(
[
"t_dict._((amain,0,I))___tuple_0___tuple_0___tensor",
"t_dict._((amain,0,I))___tuple_0___tuple_1___tensor",
"t_dict._((amain,0,I))___tuple_0___tuple_2.___tensor",
"t_dict._((amain,0,I))___tuple_1.___dict.___empty",
],
names,
)
restored = create_input_tensors_from_onnx_model(model)
self.assertEqualAny(data, restored)


if __name__ == "__main__":
unittest.main(verbosity=2)
46 changes: 46 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.torch_export_patches.patch_expressions import (
_iterate_patched_expressions,
register_patched_expressions,
patched_selector,
patched_float_arange,
)
from onnx_diagnostic.helpers.torch_test_helper import fake_torchdynamo_exporting


class TestOnnxExportErrors(ExtTestCase):

@classmethod
def setUp(cls):
register_patched_expressions()

def test_patched_expressions(self):
res = list(_iterate_patched_expressions())
names = {_[0] for _ in res}
self.assertIn("float_arange", names)

def test_float_arange(self):
_T = torch.tensor
res = torch.arange(4, 6, 0.234)
got = torch.arange(4, 6, 0.234, dtype=torch.float32, device=torch.device("cpu"))
self.assertEqualArray(res, got)
got = torch.ops.patched.float_arange(_T(4.0), _T(6.0), _T(0.234))
self.assertEqualArray(res, got, atol=1e-5)
got = patched_selector(
(lambda a, b, c: torch.arange(a.item(), b.item(), c.item())),
torch.ops.patched.float_arange,
)(_T(4.0), _T(6.0), _T(0.234))
self.assertEqualArray(res, got, atol=1e-5)
got = patched_float_arange(_T(4.0), _T(6.0), _T(0.234))
self.assertEqualArray(res, got, atol=1e-5)
with fake_torchdynamo_exporting():
got = patched_selector(None, torch.ops.patched.float_arange)(
_T(4.0), _T(6.0), _T(0.234)
)
self.assertEqualArray(res, got, atol=1e-5)


if __name__ == "__main__":
unittest.main(verbosity=2)
124 changes: 124 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_loops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.helpers.torch_test_helper import (
is_torchdynamo_exporting,
fake_torchdynamo_exporting,
)
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches.patch_expressions import (
_iterate_patched_expressions,
register_patched_expressions,
patched_float_arange,
)


class TestOnnxExportErrors(ExtTestCase):

def test_patched_expressions(self):
res = list(_iterate_patched_expressions())
names = {_[0] for _ in res}
self.assertIn("float_arange", names)

@requires_torch("2.8")
def test_filter_position_ids(self):

def filter_position_ids(
patch_attention_mask: torch.Tensor,
position_ids: torch.Tensor,
boundaries: torch.Tensor,
num_patches_per_side: int,
):
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum())
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum())

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (
bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
).flatten()
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
return position_ids

def float_arange(start, end, step):
length = torch.sym_int((end - start) / step + (step * (1 - 1e-6)))
torch._check(length > 0)
res = torch.arange(0, length)
torch._check(res.is_contiguous())
fres = res.to(torch.float32)
fstart = torch.tensor(start, dtype=torch.float32)
return fres + fstart

def scan_filter_position_ids(
patch_attention_mask: torch.Tensor,
position_ids: torch.Tensor,
boundaries: torch.Tensor,
num_patches_per_side: int,
):

def body(p_attn_mask, position_ids_row):
h_len = torch.tensor(1) / p_attn_mask[:, 0].sum()
w_len = torch.tensor(1) / p_attn_mask[0].sum()
fractional_coords_h = patched_float_arange(
torch.tensor(0.0), torch.tensor(1 - 1e-6), h_len
)
fractional_coords_w = patched_float_arange(
torch.tensor(0.0), torch.tensor(1 - 1e-6), w_len
)

# torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum().item())
# torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum().item())

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (
bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
).flatten()

row = position_ids_row.clone()
row[p_attn_mask.view(-1)] = pos_ids
return [row]

return torch.ops.higher_order.scan(
body, [], [patch_attention_mask, position_ids], additional_inputs=[]
)

class Model(torch.nn.Module):
def forward(self, patch_attention_mask, position_ids, boundaries):
if is_torchdynamo_exporting():
res = scan_filter_position_ids(
patch_attention_mask, position_ids, boundaries, 32
)
return res[0]
return filter_position_ids(patch_attention_mask, position_ids, boundaries, 32)

# 32
# T9s32x32x32[False,True:A0.978515625],
# T7s32x1024[0,0:A0.0],
# T1s31[0.03125,0.96875:A0.5]]
register_patched_expressions()
patch_attention_mask = torch.randint(0, 20, (32, 32, 32)) >= 1
patch_attention_mask[:, :, :] = True
position_ids = torch.zeros((32, 1024), dtype=torch.int64)
boundaries = (torch.arange(33).to(torch.float32) / 33)[1:-1]
inputs = (patch_attention_mask, position_ids, boundaries)
model = Model()
expected = model(*inputs)
with fake_torchdynamo_exporting():
got = model(*inputs)
self.assertEqual(type(expected), type(got))
self.assertEqual(
string_type(expected, with_shape=True), string_type(got, with_shape=True)
)
self.assertEqualArray(expected, got)

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN}))
self.assertEqualArray(expected, ep.module()(*inputs))


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion onnx_diagnostic/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .helper import max_diff, string_diff, string_sig, string_type
from .helper import flatten_object, max_diff, string_diff, string_sig, string_type
11 changes: 7 additions & 4 deletions onnx_diagnostic/helpers/mini_onnx_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def append_output_initializer(
return

init_name = f"t_{name}"
assert (
init_name not in self.initializers_dict
), f"name={init_name!r} already in {sorted(self.initializers_dict)}"
self.initializers_dict[init_name] = tensor
shape = tuple(map(int, tensor.shape))
self.outputs.append(
Expand Down Expand Up @@ -324,21 +327,21 @@ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
for i, o in enumerate(obj):
if i == len(obj) - 1:
for p, oo in _flatten_iterator(o, sep):
yield f"tuple.{sep}{p}", oo
yield f"tuple_{i}.{sep}{p}", oo
else:
for p, oo in _flatten_iterator(o, sep):
yield f"tuple{sep}{p}", oo
yield f"tuple_{i}{sep}{p}", oo
elif isinstance(obj, list):
if not obj:
yield f"list.{sep}empty", None
else:
for i, o in enumerate(obj):
if i == len(obj) - 1:
for p, oo in _flatten_iterator(o, sep):
yield f"list.{sep}{p}", oo
yield f"list_{i}.{sep}{p}", oo
else:
for p, oo in _flatten_iterator(o, sep):
yield f"list{sep}{p}", oo
yield f"list_{i}{sep}{p}", oo
elif isinstance(obj, dict):
if not obj:
yield f"dict.{sep}empty", None
Expand Down
26 changes: 21 additions & 5 deletions onnx_diagnostic/helpers/ort_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,27 @@ def __init__(
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
raise ValueError(f"Unexpected value for providers={providers!r}")
sess = onnxruntime.InferenceSession(
sess if isinstance(sess, str) else sess.SerializeToString(),
session_options,
providers=providers,
)
try:
sess = onnxruntime.InferenceSession(
sess if isinstance(sess, str) else sess.SerializeToString(),
session_options,
providers=providers,
)
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
if isinstance(sess, onnx.ModelProto):
debug_path = "_debug_onnxruntine_evaluator_failure.onnx"
onnx.save(
sess,
debug_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
else:
debug_path = sess
raise RuntimeError(
f"Unable to create a session stored in {debug_path!r}), "
f"providers={providers}"
) from e
else:
assert (
session_options is None
Expand Down
18 changes: 16 additions & 2 deletions onnx_diagnostic/helpers/torch_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,25 @@ def forward(self, x, y):
onnx.save(
proto,
dump_file,
save_as_external_data=False,
all_tensors_to_one_file=True,
save_as_external_data=True,
all_tensors_to_one_file=False,
)


@contextlib.contextmanager
def fake_torchdynamo_exporting():
"""
Sets ``torch.compiler._is_exporting_flag`` to True to trigger
pieces of code only enabled during export.
"""
memorize = torch.compiler._is_exporting_flag
torch.compiler._is_exporting_flag = True
try:
yield
finally:
torch.compiler._is_exporting_flag = memorize


def is_torchdynamo_exporting() -> bool:
"""
Tells if :epkg:`torch` is exporting a model.
Expand Down
Loading
Loading