
Commit a72398f

patches used sdpa_mask_recent_torch (#218)
* documentation
* lint
* export
* fix patch for sdpa mask
* fix import issues
* fixes
* disable a test
* fix
* other series of fixes
* fix patches
* fix
* another fix
* fix
* another try
* obvious
* fix tests
1 parent 05c3909 commit a72398f

31 files changed: +337 −75 lines
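
The core of the change: transformers builds its SDPA attention mask by stacking torch.vmap calls (sdpa_mask_recent_torch on top of _vmap_for_bhqkv), a pattern torch.export often struggles to trace, so the patch registers an export-friendly replacement. The actual implementation lives in onnx_diagnostic.torch_export_patches.patches.patch_transformers (patched_sdpa_mask_recent_torch, patched__vmap_for_bhqkv); the sketch below only illustrates the underlying idea with a hypothetical broadcast_for_bhqkv helper that swaps the vmap stack for plain broadcasting.

    # Illustration only: a hypothetical broadcast-based stand-in for transformers'
    # _vmap_for_bhqkv; the real patched functions are in
    # onnx_diagnostic.torch_export_patches.patches.patch_transformers.
    from typing import Callable

    import torch


    def broadcast_for_bhqkv(mask_function: Callable) -> Callable:
        # Evaluate mask_function on a full (batch, head, q, kv) index grid with
        # reshape/expand broadcasting instead of four nested torch.vmap calls.
        def wrapped(batch, head, q_idx, kv_idx):
            b = batch.reshape(-1, 1, 1, 1)
            h = head.reshape(1, -1, 1, 1)
            q = q_idx.reshape(1, 1, -1, 1)
            kv = kv_idx.reshape(1, 1, 1, -1)
            mask = mask_function(b, h, q, kv)
            # Expand to the full grid even when mask_function ignores some indices.
            return mask.expand(b.numel(), h.numel(), q.numel(), kv.numel())

        return wrapped


    def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
        return kv_idx <= q_idx


    mask = broadcast_for_bhqkv(causal_mask_function)(
        torch.arange(2), torch.arange(8), torch.arange(4), torch.arange(4)
    )
    print(mask.shape)  # torch.Size([2, 8, 4, 4])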

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ Change Logs
 0.7.10
 ++++++
 
+* :pr:`218`: patches used sdpa_mask_recent_torch used from _vmap_for_bhqkv
+
 0.7.9
 +++++
 
README.rst

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ Getting started
 
     git clone https://github.com/sdpython/onnx-diagnostic.git
    cd onnx-diagnostic
-    pip install -e .
+    pip install -e . -v
 
 or
 
_doc/examples/plot_dump_intermediate_results.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@
 # Let's create the ONNX model.
 
 ep = torch.export.export(model, inputs, dynamic_shapes=ds)
-epo = torch.onnx.export(ep, dynamo=True)
+epo = torch.onnx.export(ep)
 epo.optimize()
 epo.save("plot_dump_intermediate_results.onnx")
 
_doc/examples/plot_export_tiny_phi2.py

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@
 
 with torch_export_patches(patch_transformers=True):
     epo = torch.onnx.export(
-        ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
+        ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes
     )
 
 # %%
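
Both example scripts now call torch.onnx.export without the explicit dynamo=True flag. A minimal sketch of the updated pattern, assuming a torch version where the dynamo-based exporter is the default and an ExportedProgram can be passed directly; Tiny and tiny.onnx are placeholder names:

    import torch


    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x * 2


    # Export to an ExportedProgram first, then convert it to ONNX without the
    # explicit dynamo=True flag that the older examples used.
    ep = torch.export.export(Tiny(), (torch.rand(2, 3),))
    epo = torch.onnx.export(ep)
    epo.optimize()
    epo.save("tiny.onnx")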

_unittests/ut_export/test_jit.py

Lines changed: 1 addition & 16 deletions
@@ -7,7 +7,6 @@
     ignore_warnings,
     requires_onnxscript,
 )
-from onnx_diagnostic.reference import ExtendedReferenceEvaluator
 from onnx_diagnostic.helpers.torch_helper import is_torchdynamo_exporting
 
 try:
@@ -62,7 +61,7 @@ def test_dummy_loop(self):
 
     @hide_stdout()
     @ignore_warnings(UserWarning)
-    @requires_onnxscript("0.5")
+    @requires_onnxscript("0.7")
     def test_export_loop_onnxscript(self):
         class Model(torch.nn.Module):
             def forward(self, images, position):
@@ -75,19 +74,6 @@ def forward(self, images, position):
         y = torch.arange(5, dtype=torch.int64) + 1
         expected = model(x, y)
 
-        name = self.get_dump_file("test_export_loop_onnxscript.onnx")
-        torch.onnx.export(
-            model,
-            (x, y),
-            name,
-            dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
-            dynamo=False,
-        )
-        ref = ExtendedReferenceEvaluator(name)
-        feeds = dict(images=x.numpy(), position=y.numpy())
-        got = ref.run(None, feeds)[0]
-        self.assertEqualArray(expected, got)
-
         DYN = torch.export.Dim.DYNAMIC
         ep = torch.export.export(
             model,
@@ -103,7 +89,6 @@ def forward(self, images, position):
             (x, y),
             name2,
             dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
-            dynamo=True,
             fallback=False,
         )
         import onnxruntime

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 95 additions & 1 deletion
@@ -1,4 +1,5 @@
 import unittest
+from typing import Callable
 import torch
 import transformers
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
@@ -19,6 +20,13 @@
 )
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 
+try:
+    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+        patched__vmap_for_bhqkv,
+    )
+except ImportError:
+    patched__vmap_for_bhqkv = None
+
 
 class TestCacheHelpers(ExtTestCase):
     def test_string_type(self):
@@ -69,7 +77,7 @@ def test_replace_by(self):
         )
 
         DYN = torch.export.Dim.DYNAMIC
-        nargs, nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
+        _nargs, _nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
             None, args=tuple(), kwargs=kwargs, dynamic_axes=dynamic_shapes
         )
         self.assertEqual(dynamic_shapes, nds)
@@ -254,6 +262,92 @@ def test_unflatten_flatten_hybrid_cache(self):
             self.string_type(unflat, with_shape=True),
         )
 
+    @unittest.skipIf(patched__vmap_for_bhqkv is None, "transformers too old")
+    def test_cache_update_padding_mask_function_vmap(self):
+        def causal_mask_function(
+            batch_idx: int, head_idx: int, q_idx: int, kv_idx: int
+        ) -> bool:
+            return kv_idx <= q_idx
+
+        def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+            def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+                return padding_mask[batch_idx, kv_idx]
+
+            return inner_mask
+
+        def and_masks(*mask_functions: list[Callable]) -> Callable:
+            if not all(callable(arg) for arg in mask_functions):
+                raise RuntimeError(
+                    f"All inputs should be callable mask_functions: {mask_functions}"
+                )
+
+            def and_mask(batch_idx, head_idx, q_idx, kv_idx):
+                result = q_idx.new_ones((), dtype=torch.bool)
+                for mask in mask_functions:
+                    result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(
+                        result.device
+                    )
+                return result
+
+            return and_mask
+
+        def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+            dimensions = [(None, None, None, 0), (None, None, 0, None)]
+            if bh_indices:
+                dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+            for dims in dimensions:
+                mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
+            return mask_function
+
+        class Model(torch.nn.Module):
+            def forward(self, x, mask):
+                mask_function = and_masks(causal_mask_function, padding_mask_function(mask))
+                batch_arange = torch.arange(x.shape[0])
+                head_arange = torch.arange(x.shape[3])
+                kv_arange = torch.arange(x.shape[1])
+                cache_position = torch.arange(x.shape[2])
+                f = patched__vmap_for_bhqkv(mask_function)
+                causal_mask = f(batch_arange, head_arange, cache_position, kv_arange)
+                return x + causal_mask.to(x.dtype)
+
+        inputs = {
+            "x": torch.rand((4, 4, 4, 4), dtype=torch.float32),
+            "mask": torch.ones((4, 4), dtype=torch.int64),
+        }
+        model = Model()
+        expected = model(**inputs)
+        self.assertNotEmpty(expected)
+        DYN = torch.export.Dim.DYNAMIC
+        ep = torch.export.export(
+            model,
+            (),
+            kwargs=inputs,
+            dynamic_shapes={"x": {0: DYN, 1: DYN, 2: DYN, 3: DYN}, "mask": {0: DYN, 1: DYN}},
+        )
+        self.assertNotEmpty(ep)
+
+    def test_simple_indices(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, i, j):
+                return x[i, j]
+
+        inputs = (
+            torch.rand((4, 4), dtype=torch.float32),
+            torch.randint(0, 4, (4, 4, 4, 4), dtype=torch.int64),
+            torch.randint(0, 4, (4, 4, 4, 4), dtype=torch.int64),
+        )
+        model = Model()
+        expected = model(*inputs)
+        self.assertEqual(expected.shape, (4, 4, 4, 4))
+        DYN = torch.export.Dim.DYNAMIC
+        sh = {0: DYN, 1: DYN, 2: DYN, 3: DYN}
+        ep = torch.export.export(
+            model,
+            inputs,
+            dynamic_shapes=({0: DYN, 1: DYN}, sh, sh),
+        )
+        self.assertNotEmpty(ep)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_helpers/test_ort_session_tinyllm.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def test_check_allruntimes_on_tiny_llm(self):
             proto = to_onnx(model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds)
         else:
             proto = torch.onnx.export(
-                model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds, dynamo=True
+                model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds
            ).model_proto
 
         self.dump_onnx("test_check_allruntimes_on_tiny_llm.onnx", proto)

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 16 additions & 3 deletions
@@ -3,6 +3,11 @@
 import unittest
 from typing import Any, Dict, List, Tuple
 import torch
+
+try:
+    import transformers.masking_utils as masking_utils
+except ImportError:
+    masking_utils = None
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
     ignore_warnings,
@@ -14,7 +19,9 @@
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
     torch_export_patches,
 )
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
 
 
 class TestOnnxExportErrors(ExtTestCase):
@@ -305,7 +312,7 @@ def test_phi2_export_module(self):
             model,
             (),
             kwargs=inputs,
-            dynamic_shapes=dyn_shapes,
+            dynamic_shapes=use_dyn_not_str(dyn_shapes),
             strict=False,  # True works but then the it fails during the execution
         )
         # ep = ep.run_decompositions()
@@ -319,6 +326,7 @@ def test_phi2_export_module(self):
 
     @ignore_warnings(UserWarning)
     @requires_torch("2.9")
+    @hide_stdout()
     def test_phi2_export_interpreter(self):
         data = get_untrained_model_with_inputs("microsoft/phi-2")
         model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
@@ -338,12 +346,17 @@
             str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
         )
 
-        with torch_export_patches(patch_transformers=True):
+        with torch_export_patches(patch_transformers=True, verbose=1):
+            if masking_utils is not None:
+                self.assertEqual(
+                    masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
+                    patch_transformers.patched_sdpa_mask_recent_torch,
+                )
             ep = torch.export.export(
                 model,
                 (),
                 kwargs=inputs,
-                dynamic_shapes=dyn_shapes,
+                dynamic_shapes=use_dyn_not_str(dyn_shapes),
                 strict=False,  # True works but then the it fails during the execution
             )
         # ep = ep.run_decompositions()
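
The new assertion in test_phi2_export_interpreter pins down what the patch does: within torch_export_patches(patch_transformers=True), the "sdpa" entry of transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS is replaced by patched_sdpa_mask_recent_torch. A minimal standalone sketch of the same check, assuming a transformers version that ships masking_utils:

    import transformers.masking_utils as masking_utils

    import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    with torch_export_patches(patch_transformers=True):
        # Mirrors the assertEqual added in the test: the registered "sdpa" mask
        # builder should be the patched, export-friendly implementation.
        fct = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
        print(fct == patch_transformers.patched_sdpa_mask_recent_torch)  # expected: True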

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 2 additions & 2 deletions
@@ -604,8 +604,8 @@ def loop_body_1(z, iv, x, y):
         rewritten_expected2 = RewrittenModel2()(x, y)
         self.assertEqualArray(expected, rewritten_expected2)
 
-        if not has_torch("2.9"):
-            raise unittest.SkipTest("skipped export, torch must be >= 2.9")
+        if not has_torch("2.10"):
+            raise unittest.SkipTest("skipped export, torch must be >= 2.10")
 
         torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
         ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds, strict=False)

_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ def forward(self, cache):
     def test_base_model_output_unflatten_flatten(self):
         bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
         with torch_export_patches(patch_transformers=True):
-            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            _flat, _spec = torch.utils._pytree.tree_flatten(bo)
             unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
             self.assertIsInstance(unflat, list)
             self.assertEqual("#1[T1r3]", self.string_type(unflat))
