Commit e52bc52

Add patched expression
1 parent 600a02f commit e52bc52

File tree

4 files changed, +292 -0 lines changed

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.torch_export_patches.patch_expressions import (
    _iterate_patched_expressions,
    register_patched_expressions,
    patched_selector,
    patched_float_arange,
)
from onnx_diagnostic.helpers.torch_test_helper import fake_torchdynamo_exporting


class TestOnnxExportErrors(ExtTestCase):

    @classmethod
    def setUp(cls):
        register_patched_expressions()

    def test_patched_expressions(self):
        res = list(_iterate_patched_expressions())
        names = {_[0] for _ in res}
        self.assertIn("float_arange", names)

    def test_float_arange(self):
        _T = torch.tensor
        res = torch.arange(4, 6, 0.234)
        got = torch.arange(4, 6, 0.234, dtype=torch.float32, device=torch.device("cpu"))
        self.assertEqualArray(res, got)
        got = torch.ops.patched.float_arange(_T(4.0), _T(6.0), _T(0.234))
        self.assertEqualArray(res, got, atol=1e-5)
        got = patched_selector(
            (lambda a, b, c: torch.arange(a.item(), b.item(), c.item())),
            torch.ops.patched.float_arange,
        )(_T(4.0), _T(6.0), _T(0.234))
        self.assertEqualArray(res, got, atol=1e-5)
        got = patched_float_arange(_T(4.0), _T(6.0), _T(0.234))
        self.assertEqualArray(res, got, atol=1e-5)
        with fake_torchdynamo_exporting():
            got = patched_selector(None, torch.ops.patched.float_arange)(
                _T(4.0), _T(6.0), _T(0.234)
            )
            self.assertEqualArray(res, got, atol=1e-5)


if __name__ == "__main__":
    unittest.main(verbosity=2)
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.helpers.torch_test_helper import (
    is_torchdynamo_exporting,
    fake_torchdynamo_exporting,
)
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches.patch_expressions import (
    _iterate_patched_expressions,
    register_patched_expressions,
    patched_float_arange,
)


class TestOnnxExportErrors(ExtTestCase):

    def test_patched_expressions(self):
        res = list(_iterate_patched_expressions())
        names = {_[0] for _ in res}
        self.assertIn("float_arange", names)

    @requires_torch("2.7")
    def test_filter_position_ids(self):

        def filter_position_ids(
            patch_attention_mask: torch.Tensor,
            position_ids: torch.Tensor,
            boundaries: torch.Tensor,
            num_patches_per_side: int,
        ):
            for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
                fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum())
                fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum())

                bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
                bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

                pos_ids = (
                    bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
                ).flatten()
                position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
            return position_ids

        def float_arange(start, end, step):
            length = torch.sym_int((end - start) / step + (step * (1 - 1e-6)))
            torch._check(length > 0)
            res = torch.arange(0, length)
            torch._check(res.is_contiguous())
            fres = res.to(torch.float32)
            fstart = torch.tensor(start, dtype=torch.float32)
            return fres + fstart

        def scan_filter_position_ids(
            patch_attention_mask: torch.Tensor,
            position_ids: torch.Tensor,
            boundaries: torch.Tensor,
            num_patches_per_side: int,
        ):

            def body(p_attn_mask, position_ids_row):
                h_len = torch.tensor(1) / p_attn_mask[:, 0].sum()
                w_len = torch.tensor(1) / p_attn_mask[0].sum()
                fractional_coords_h = patched_float_arange(
                    torch.tensor(0.0), torch.tensor(1 - 1e-6), h_len
                )
                fractional_coords_w = patched_float_arange(
                    torch.tensor(0.0), torch.tensor(1 - 1e-6), w_len
                )

                # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum().item())
                # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum().item())

                bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
                bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

                pos_ids = (
                    bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
                ).flatten()

                row = position_ids_row.clone()
                row[p_attn_mask.view(-1)] = pos_ids
                return [row]

            return torch.ops.higher_order.scan(
                body, [], [patch_attention_mask, position_ids], additional_inputs=[]
            )

        class Model(torch.nn.Module):
            def forward(self, patch_attention_mask, position_ids, boundaries):
                if is_torchdynamo_exporting():
                    res = scan_filter_position_ids(
                        patch_attention_mask, position_ids, boundaries, 32
                    )
                    return res[0]
                return filter_position_ids(patch_attention_mask, position_ids, boundaries, 32)

        # 32
        # T9s32x32x32[False,True:A0.978515625],
        # T7s32x1024[0,0:A0.0],
        # T1s31[0.03125,0.96875:A0.5]]
        register_patched_expressions()
        patch_attention_mask = torch.randint(0, 20, (32, 32, 32)) >= 1
        patch_attention_mask[:, :, :] = True
        position_ids = torch.zeros((32, 1024), dtype=torch.int64)
        boundaries = (torch.arange(33).to(torch.float32) / 33)[1:-1]
        inputs = (patch_attention_mask, position_ids, boundaries)
        model = Model()
        expected = model(*inputs)
        with fake_torchdynamo_exporting():
            got = model(*inputs)
        self.assertEqual(type(expected), type(got))
        self.assertEqual(
            string_type(expected, with_shape=True), string_type(got, with_shape=True)
        )
        self.assertEqualArray(expected, got)

        DYN = torch.export.Dim.DYNAMIC
        ep = torch.export.export(model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN}))
        self.assertEqualArray(expected, ep.module()(*inputs))


if __name__ == "__main__":
    unittest.main(verbosity=2)

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 14 additions & 0 deletions
@@ -156,6 +156,20 @@ def forward(self, x, y):
         )
 
 
+@contextlib.contextmanager
+def fake_torchdynamo_exporting():
+    """
+    Sets ``torch.compiler._is_exporting_flag`` to True to trigger
+    pieces of code only enabled during export.
+    """
+    memorize = torch.compiler._is_exporting_flag
+    torch.compiler._is_exporting_flag = True
+    try:
+        yield
+    finally:
+        torch.compiler._is_exporting_flag = memorize
+
+
 def is_torchdynamo_exporting() -> bool:
     """
     Tells if :epkg:`torch` is exporting a model.
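
A quick illustration of the new helper (a usage sketch added here, not code from the commit): inside the context manager, is_torchdynamo_exporting() reports True, so export-only branches such as the second argument of patched_selector run eagerly, and the previous flag value is restored on exit.

from onnx_diagnostic.helpers.torch_test_helper import (
    fake_torchdynamo_exporting,
    is_torchdynamo_exporting,
)

print(is_torchdynamo_exporting())      # False in a plain eager run
with fake_torchdynamo_exporting():
    # the fake flag makes export-only code paths active
    print(is_torchdynamo_exporting())  # True
print(is_torchdynamo_exporting())      # flag restored by the finally block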
onnx_diagnostic/torch_export_patches/patch_expressions.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
from typing import Callable, Set
import torch
from ..helpers.torch_test_helper import is_torchdynamo_exporting


def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Used for a custom op when a new dimension must be introduced to bypass
    some verification. The following function creates a dummy output
    with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )
    """
    try:
        ti = int(i)
    except:  # noqa: E722
        ti = 10
    t = torch.ones((ti * 2,))
    t[:ti] = 0
    res = torch.nonzero(t).shape[0]
    return res


def _patched_float_arange(
    start: torch.Tensor, end: torch.Tensor, step: torch.Tensor
) -> torch.Tensor:
    """Float arange."""
    return torch.arange(
        float(start.item()),
        float(end.item()),
        float(step.item()),
        dtype=start.dtype,
        device=start.device,
    )


def _patched_float_arange_shape(start, end, step):
    # Fails because:
    # Did you accidentally call new_dynamic_size() or item()
    # more times than you needed to in your fake implementation?
    # try:
    #     n = math.ceil(((end - start) / step).item())
    # except:  # noqa: E722
    #     n = 10
    n = 10
    return torch.empty((make_undefined_dimension(n),), dtype=start.dtype, device=start.device)


def _iterate_patched_expressions():
    glo = globals().copy()
    for k, _v in glo.items():
        if k.startswith("_patched_") and not k.endswith("_shape"):
            name = k
            yield k[len("_patched_") :], glo[name], glo[f"{name}_shape"]


_registered: Set[str] = set()


def _register_patched_expression(
    fct: Callable, fct_shape: Callable, namespace: str, fname: str
):
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    custom_def.register_kernel("cpu")(fct)
    custom_def._abstract_fn = fct_shape


def register_patched_expressions(namespace: str = "patched"):
    """
    Registers as custom ops known expressions failing due to dynamic shapes.

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_export_patches.patch_expressions import (
            _iterate_patched_expressions,
        )

        pprint.pprint([name for name, _f, _fsh in _iterate_patched_expressions()])
    """
    for name, f, fsh in _iterate_patched_expressions():
        if name not in _registered:
            _register_patched_expression(f, fsh, namespace, name)
            _registered.add(name)


def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:
    """
    Returns **fct** if the model is being executed or
    **patched_fct** if it is being exported.
    """
    return patched_fct if is_torchdynamo_exporting() else fct


def patched_float_arange(start, end, step):
    """Patched arange when start, end, step are floats."""
    if is_torchdynamo_exporting():
        return torch.ops.patched.float_arange(start, end, step)
    else:
        return torch.arange(start, end, step)
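
For context, a minimal, untested sketch of how the pieces are meant to compose outside the scan-based test above (the ToyArange module is hypothetical, not part of the commit): once register_patched_expressions() has run, patched_float_arange dispatches to torch.ops.patched.float_arange while torch.export traces, and _patched_float_arange_shape supplies an undefined output dimension so the data-dependent arange no longer blocks export.

import torch
from onnx_diagnostic.torch_export_patches.patch_expressions import (
    register_patched_expressions,
    patched_float_arange,
)

register_patched_expressions()  # registers torch.ops.patched.float_arange


class ToyArange(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, start, end, step):
        # eager mode: plain torch.arange; under export: the registered custom op
        return patched_float_arange(start, end, step)


inputs = (torch.tensor(4.0), torch.tensor(6.0), torch.tensor(0.234))
ep = torch.export.export(ToyArange(), inputs)
print(ep.graph)  # expected to show a call to patched.float_arange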
