Commit b19dcd4

Introduces patched expression to get around annoying export issues (#93)
* Add patched expression
* doc
* 2.8
* fix a few things
* fix bug
1 parent 600a02f commit b19dcd4
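
The test files added below exercise the new helpers register_patched_expressions, patched_selector and patched_float_arange, plus the custom operator they register under torch.ops.patched. The patch_expressions module itself is not shown in this commit view, so the snippet below is only a minimal sketch of the mechanism, assuming PyTorch 2.4 or later and using a hypothetical demo_patched namespace and demo_selector helper so it cannot be mistaken for the real implementation:

import torch


# Sketch only: wrap the data-dependent expression in a custom op whose inputs are
# 0-D tensors, so torch.export never traces through the .item() calls.
@torch.library.custom_op("demo_patched::float_arange", mutates_args=())
def demo_float_arange(
    start: torch.Tensor, end: torch.Tensor, step: torch.Tensor
) -> torch.Tensor:
    return torch.arange(start.item(), end.item(), step.item(), dtype=torch.float32)


@demo_float_arange.register_fake
def _(start, end, step):
    # The output length depends on the data, so the fake kernel reports a dynamic size.
    length = torch.library.get_ctx().new_dynamic_size()
    return torch.empty((length,), dtype=torch.float32)


def demo_selector(eager_fn, patched_op):
    # The real patched_selector relies on the library's is_torchdynamo_exporting();
    # torch.compiler.is_compiling() serves as a stand-in here.
    return patched_op if torch.compiler.is_compiling() else eager_fn


def eager_arange(start, end, step):
    return torch.arange(start.item(), end.item(), step.item())


t = torch.tensor
fn = demo_selector(eager_arange, torch.ops.demo_patched.float_arange)
print(fn(t(4.0), t(6.0), t(0.234)))  # eager path: plain torch.arange

In eager mode the original expression runs unchanged; during export the selector hands out the custom op, whose fake kernel only promises a dynamic 1-D float output, which is what lets the arange survive torch.export.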

File tree: 13 files changed (+377, -13 lines)


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.5.0
 +++++
 
+* :pr:`93`: introduce patched expression to get around annoying export issues
 * :pr:`92`: support errors distribution in max_diff
 * :pr:`91`: enable strings in ``guess_dynamic_shapes``
 * :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models

_doc/api/helpers/helper.rst

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ onnx_diagnostic.helpers.helper
 
 .. automodule:: onnx_diagnostic.helpers.helper
     :no-undoc-members:
-    :exclude-members: max_diff, string_diff, string_sig, string_type
+    :exclude-members: flatten_object, max_diff, string_diff, string_sig, string_type

_doc/api/helpers/index.rst

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ onnx_diagnostic.helpers
     rt_helper
     torch_test_helper
 
+.. autofunction:: onnx_diagnostic.helpers.flatten_object
+
 .. autofunction:: onnx_diagnostic.helpers.max_diff
 
 .. autofunction:: onnx_diagnostic.helpers.string_diff

_doc/api/torch_export_patches/index.rst

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ onnx_diagnostic.torch_export_patches
     :caption: submodules
 
     patches/index
+    patch_expressions
     patch_inputs
     patch_module
 

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.torch_export_patches.patch_expressions
+======================================================
+
+.. automodule:: onnx_diagnostic.torch_export_patches.patch_expressions
+    :members:
+    :no-undoc-members:

_unittests/ut_helpers/test_mini_onnx_builder.py

Lines changed: 42 additions & 0 deletions
@@ -151,6 +151,48 @@ def test_mini_onnx_builder_transformers_sep(self):
         restored = create_input_tensors_from_onnx_model(model, sep="#")
         self.assertEqualAny(inputs, restored)
 
+    def test_specific_data(self):
+        data = {
+            ("amain", 0, "I"): (
+                (
+                    torch.rand((2, 16, 3, 448, 448), dtype=torch.float16),
+                    torch.rand((2, 16, 32, 32), dtype=torch.float16),
+                    torch.rand((2, 2)).to(torch.int64),
+                ),
+                {},
+            ),
+        }
+        model = create_onnx_model_from_input_tensors(data)
+        shapes = [
+            tuple(d.dim_value for d in i.type.tensor_type.shape.dim)
+            for i in model.graph.output
+        ]
+        self.assertEqual(shapes, [(2, 16, 3, 448, 448), (2, 16, 32, 32), (2, 2), (0,)])
+        names = [i.name for i in model.graph.output]
+        self.assertEqual(
+            [
+                "dict._((amain,0,I))___tuple_0___tuple_0___tensor",
+                "dict._((amain,0,I))___tuple_0___tuple_1___tensor",
+                "dict._((amain,0,I))___tuple_0___tuple_2.___tensor",
+                "dict._((amain,0,I))___tuple_1.___dict.___empty",
+            ],
+            names,
+        )
+        shapes = [tuple(i.dims) for i in model.graph.initializer]
+        self.assertEqual(shapes, [(2, 16, 3, 448, 448), (2, 16, 32, 32), (2, 2), (0,)])
+        names = [i.name for i in model.graph.initializer]
+        self.assertEqual(
+            [
+                "t_dict._((amain,0,I))___tuple_0___tuple_0___tensor",
+                "t_dict._((amain,0,I))___tuple_0___tuple_1___tensor",
+                "t_dict._((amain,0,I))___tuple_0___tuple_2.___tensor",
+                "t_dict._((amain,0,I))___tuple_1.___dict.___empty",
+            ],
+            names,
+        )
+        restored = create_input_tensors_from_onnx_model(model)
+        self.assertEqualAny(data, restored)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.torch_export_patches.patch_expressions import (
+    _iterate_patched_expressions,
+    register_patched_expressions,
+    patched_selector,
+    patched_float_arange,
+)
+from onnx_diagnostic.helpers.torch_test_helper import fake_torchdynamo_exporting
+
+
+class TestOnnxExportErrors(ExtTestCase):
+
+    @classmethod
+    def setUp(cls):
+        register_patched_expressions()
+
+    def test_patched_expressions(self):
+        res = list(_iterate_patched_expressions())
+        names = {_[0] for _ in res}
+        self.assertIn("float_arange", names)
+
+    def test_float_arange(self):
+        _T = torch.tensor
+        res = torch.arange(4, 6, 0.234)
+        got = torch.arange(4, 6, 0.234, dtype=torch.float32, device=torch.device("cpu"))
+        self.assertEqualArray(res, got)
+        got = torch.ops.patched.float_arange(_T(4.0), _T(6.0), _T(0.234))
+        self.assertEqualArray(res, got, atol=1e-5)
+        got = patched_selector(
+            (lambda a, b, c: torch.arange(a.item(), b.item(), c.item())),
+            torch.ops.patched.float_arange,
+        )(_T(4.0), _T(6.0), _T(0.234))
+        self.assertEqualArray(res, got, atol=1e-5)
+        got = patched_float_arange(_T(4.0), _T(6.0), _T(0.234))
+        self.assertEqualArray(res, got, atol=1e-5)
+        with fake_torchdynamo_exporting():
+            got = patched_selector(None, torch.ops.patched.float_arange)(
+                _T(4.0), _T(6.0), _T(0.234)
+            )
+            self.assertEqualArray(res, got, atol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
+from onnx_diagnostic.helpers.torch_test_helper import (
+    is_torchdynamo_exporting,
+    fake_torchdynamo_exporting,
+)
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.torch_export_patches.patch_expressions import (
+    _iterate_patched_expressions,
+    register_patched_expressions,
+    patched_float_arange,
+)
+
+
+class TestOnnxExportErrors(ExtTestCase):
+
+    def test_patched_expressions(self):
+        res = list(_iterate_patched_expressions())
+        names = {_[0] for _ in res}
+        self.assertIn("float_arange", names)
+
+    @requires_torch("2.8")
+    def test_filter_position_ids(self):
+
+        def filter_position_ids(
+            patch_attention_mask: torch.Tensor,
+            position_ids: torch.Tensor,
+            boundaries: torch.Tensor,
+            num_patches_per_side: int,
+        ):
+            for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+                fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum())
+                fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum())
+
+                bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+                bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+                pos_ids = (
+                    bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
+                ).flatten()
+                position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
+            return position_ids
+
+        def float_arange(start, end, step):
+            length = torch.sym_int((end - start) / step + (step * (1 - 1e-6)))
+            torch._check(length > 0)
+            res = torch.arange(0, length)
+            torch._check(res.is_contiguous())
+            fres = res.to(torch.float32)
+            fstart = torch.tensor(start, dtype=torch.float32)
+            return fres + fstart
+
+        def scan_filter_position_ids(
+            patch_attention_mask: torch.Tensor,
+            position_ids: torch.Tensor,
+            boundaries: torch.Tensor,
+            num_patches_per_side: int,
+        ):
+
+            def body(p_attn_mask, position_ids_row):
+                h_len = torch.tensor(1) / p_attn_mask[:, 0].sum()
+                w_len = torch.tensor(1) / p_attn_mask[0].sum()
+                fractional_coords_h = patched_float_arange(
+                    torch.tensor(0.0), torch.tensor(1 - 1e-6), h_len
+                )
+                fractional_coords_w = patched_float_arange(
+                    torch.tensor(0.0), torch.tensor(1 - 1e-6), w_len
+                )
+
+                # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum().item())
+                # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum().item())
+
+                bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+                bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+                pos_ids = (
+                    bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
+                ).flatten()
+
+                row = position_ids_row.clone()
+                row[p_attn_mask.view(-1)] = pos_ids
+                return [row]
+
+            return torch.ops.higher_order.scan(
+                body, [], [patch_attention_mask, position_ids], additional_inputs=[]
+            )
+
+        class Model(torch.nn.Module):
+            def forward(self, patch_attention_mask, position_ids, boundaries):
+                if is_torchdynamo_exporting():
+                    res = scan_filter_position_ids(
+                        patch_attention_mask, position_ids, boundaries, 32
+                    )
+                    return res[0]
+                return filter_position_ids(patch_attention_mask, position_ids, boundaries, 32)
+
+        # 32
+        # T9s32x32x32[False,True:A0.978515625],
+        # T7s32x1024[0,0:A0.0],
+        # T1s31[0.03125,0.96875:A0.5]]
+        register_patched_expressions()
+        patch_attention_mask = torch.randint(0, 20, (32, 32, 32)) >= 1
+        patch_attention_mask[:, :, :] = True
+        position_ids = torch.zeros((32, 1024), dtype=torch.int64)
+        boundaries = (torch.arange(33).to(torch.float32) / 33)[1:-1]
+        inputs = (patch_attention_mask, position_ids, boundaries)
+        model = Model()
+        expected = model(*inputs)
+        with fake_torchdynamo_exporting():
+            got = model(*inputs)
+        self.assertEqual(type(expected), type(got))
+        self.assertEqual(
+            string_type(expected, with_shape=True), string_type(got, with_shape=True)
+        )
+        self.assertEqualArray(expected, got)
+
+        DYN = torch.export.Dim.DYNAMIC
+        ep = torch.export.export(model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN}))
+        self.assertEqualArray(expected, ep.module()(*inputs))
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from .helper import max_diff, string_diff, string_sig, string_type
+from .helper import flatten_object, max_diff, string_diff, string_sig, string_type
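
This re-export is what the documentation changes above refer to: flatten_object becomes importable from the package root. A quick check of what the one-line diff changes:

# Both import paths now resolve to the same function; before this commit only the
# second, fully qualified one worked.
from onnx_diagnostic.helpers import flatten_object
from onnx_diagnostic.helpers.helper import flatten_object as flatten_object_direct

assert flatten_object is flatten_object_direct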

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 7 additions & 4 deletions
@@ -139,6 +139,9 @@ def append_output_initializer(
             return
 
         init_name = f"t_{name}"
+        assert (
+            init_name not in self.initializers_dict
+        ), f"name={init_name!r} already in {sorted(self.initializers_dict)}"
         self.initializers_dict[init_name] = tensor
         shape = tuple(map(int, tensor.shape))
         self.outputs.append(
@@ -324,21 +327,21 @@ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
            for i, o in enumerate(obj):
                if i == len(obj) - 1:
                    for p, oo in _flatten_iterator(o, sep):
-                        yield f"tuple.{sep}{p}", oo
+                        yield f"tuple_{i}.{sep}{p}", oo
                else:
                    for p, oo in _flatten_iterator(o, sep):
-                        yield f"tuple{sep}{p}", oo
+                        yield f"tuple_{i}{sep}{p}", oo
    elif isinstance(obj, list):
        if not obj:
            yield f"list.{sep}empty", None
        else:
            for i, o in enumerate(obj):
                if i == len(obj) - 1:
                    for p, oo in _flatten_iterator(o, sep):
-                        yield f"list.{sep}{p}", oo
+                        yield f"list_{i}.{sep}{p}", oo
                else:
                    for p, oo in _flatten_iterator(o, sep):
-                        yield f"list{sep}{p}", oo
+                        yield f"list_{i}{sep}{p}", oo
    elif isinstance(obj, dict):
        if not obj:
            yield f"dict.{sep}empty", None
