Commit cf7dc28

fix patch for sdpa mask
1 parent d6cbc6f commit cf7dc28

4 files changed, +162 -14 lines changed

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 26 additions & 8 deletions
@@ -258,9 +258,7 @@ def test_unflatten_flatten_hybrid_cache(self):
             self.string_type(unflat, with_shape=True),
         )
 
-    def test_cache_update_padding_mask_function(self):
-        from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
-
+    def test_cache_update_padding_mask_function_vmap(self):
         def causal_mask_function(
             batch_idx: int, head_idx: int, q_idx: int, kv_idx: int
         ) -> bool:
@@ -303,11 +301,9 @@ def forward(self, x, mask):
                 head_arange = torch.arange(x.shape[3])
                 kv_arange = torch.arange(x.shape[1])
                 cache_position = torch.arange(x.shape[2])
-                with TransformGetItemToIndex():
-                    causal_mask = patched__vmap_for_bhqkv(mask_function)(
-                        batch_arange, head_arange, cache_position, kv_arange
-                    )
-                    return x + causal_mask.to(x.dtype)
+                f = patched__vmap_for_bhqkv(mask_function)
+                causal_mask = f(batch_arange, head_arange, cache_position, kv_arange)
+                return x + causal_mask.to(x.dtype)
 
         inputs = {
             "x": torch.rand((4, 4, 4, 4), dtype=torch.float32),
@@ -325,6 +321,28 @@ def forward(self, x, mask):
         )
         self.assertNotEmpty(ep)
 
+    def test_simple_indices(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, i, j):
+                return x[i, j]
+
+        inputs = (
+            torch.rand((4, 4), dtype=torch.float32),
+            torch.randint(0, 4, (4, 4, 4, 4), dtype=torch.int64),
+            torch.randint(0, 4, (4, 4, 4, 4), dtype=torch.int64),
+        )
+        model = Model()
+        expected = model(*inputs)
+        self.assertEqual(expected.shape, (4, 4, 4, 4))
+        DYN = torch.export.Dim.DYNAMIC
+        sh = {0: DYN, 1: DYN, 2: DYN, 3: DYN}
+        ep = torch.export.export(
+            model,
+            inputs,
+            dynamic_shapes=({0: DYN, 1: DYN}, sh, sh),
+        )
+        self.assertNotEmpty(ep)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
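
The new test_simple_indices exercises tensor-valued (advanced) indexing under torch.export with fully dynamic shapes. As a reminder of the semantics it relies on, here is a standalone sketch (not part of the commit): for x of shape (4, 4) and integer index tensors i, j of shape (4, 4, 4, 4), x[i, j] gathers element-wise, so the result takes the shape of the index tensors rather than the shape of x.

import torch

# Standalone sketch of the indexing rule checked by test_simple_indices:
# out[idx] = x[i[idx], j[idx]] for every multi-index idx, hence the output
# has the (broadcast) shape of the index tensors, not the shape of x.
x = torch.rand((4, 4))
i = torch.randint(0, 4, (4, 4, 4, 4))
j = torch.randint(0, 4, (4, 4, 4, 4))

out = x[i, j]
assert out.shape == (4, 4, 4, 4)
# Equivalent flat gather, element by element:
assert torch.equal(out, x[i.reshape(-1), j.reshape(-1)].reshape(i.shape))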

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def test_tiny_llm_export_dynamic(self):
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
-        with torch_export_patches(patch_transformers=True):
+        with torch_export_patches(patch_transformers=True, verbose=1):
             ep = torch.export.export(
                 model,
                 (),
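
The only change here is verbose=1, which makes torch_export_patches print the "[torch_export_patches] patches ... / restored ..." messages added in onnx_export_errors.py below. A minimal usage sketch follows; the import path is assumed to be the package-level re-export, and the model is a placeholder rather than the tiny-LLM fixture used by this test.

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed re-export

class Placeholder(torch.nn.Module):  # stand-in model, not the test fixture
    def forward(self, x):
        return x.sin() + 1

# verbose=1 logs which transformers.masking_utils functions get patched on
# entry and restored on exit; the export call itself is unchanged.
with torch_export_patches(patch_transformers=True, verbose=1):
    ep = torch.export.export(Placeholder(), (torch.rand(2, 4),))
print(type(ep))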

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 98 additions & 4 deletions
@@ -439,6 +439,28 @@ def torch_export_patches(
         f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
         masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
 
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+        f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
+        masking_utils.sdpa_mask_recent_torch = (
+            patch_transformers_list.patched_sdpa_mask_recent_torch
+        )
+        if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.sdpa_mask"
+                )
+            f_transformers_sdpa_mask = masking_utils.sdpa_mask
+            masking_utils.sdpa_mask = (
+                patch_transformers_list.patched_sdpa_mask_recent_torch
+            )
+        else:
+            f_transformers_sdpa_mask = None
+
         if (
             masking_utils
             and patch_transformers_list.patch_masking_utils
@@ -456,10 +478,37 @@ def torch_export_patches(
             and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
             == f_transformers_eager_mask
         ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.eager_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
             masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
                 patch_transformers_list.patched_eager_mask
             )
 
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "sdpa_mask")
+            and f_transformers_sdpa_mask is not None
+        ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.sdpa_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
+            if (
+                "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+                and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+                == f_transformers_sdpa_mask
+            ):
+                masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                    patch_transformers_list.patched_sdpa_mask_recent_torch
+                )
+
         if custom_patches:
             if verbose:
                 print("[torch_export_patches] applies custom patches")
@@ -568,19 +617,43 @@ def torch_export_patches(
             and hasattr(masking_utils, "_vmap_for_bhqkv")
         ):
             masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
             if verbose:
                 print(
                     "[torch_export_patches] restored "
                     "transformers.masking_utils._vmap_for_bhqkv"
                 )
 
+            masking_utils.sdpa_mask_recent_torch = (
+                f_transformers_sdpa_mask_recent_torch
+            )
+
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.sdpa_mask_recent_torch"
+                )
+
+            if f_transformers_sdpa_mask is not None:
+                masking_utils.sdpa_mask = f_transformers_sdpa_mask
+                if verbose:
+                    print(
+                        "[torch_export_patches] restored "
+                        "transformers.masking_utils.sdpa_mask"
+                    )
+
         if (
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "eager_mask")
         ):
             f_transformers_eager_mask = masking_utils.eager_mask
             masking_utils.eager_mask = f_transformers_eager_mask
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.eager_mask"
+                )
             if (
                 "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
                 and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
@@ -589,11 +662,32 @@ def torch_export_patches(
                 masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
                     f_transformers_eager_mask
                 )
-            if verbose:
-                print(
-                    "[torch_export_patches] restored "
-                    "transformers.masking_utils.eager_mask"
+                if verbose:
+                    print(
+                        "[torch_export_patches] restored "
+                        "transformers.masking_utils.eager_mask "
+                        "in ALL_MASK_ATTENTION_FUNCTIONS"
+                    )
+
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "sdpa_mask")
+        ):
+            if (
+                "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+                and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+                == patch_transformers_list.patched_sdpa_mask_recent_torch
+            ):
+                masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                    f_transformers_sdpa_mask
                 )
+                if verbose:
+                    print(
+                        "[torch_export_patches] restored "
+                        "transformers.masking_utils.sdpa_mask "
+                        "in ALL_MASK_ATTENTION_FUNCTIONS"
+                    )
 
         ########
         # caches
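
All of the blocks above follow the same save/patch/restore discipline: keep a reference to the original attribute, install the patched callable, and put the original back when the context manager unwinds. sdpa_mask is only swapped when it is the very object sdpa_mask_recent_torch (the equality check in the first hunk); otherwise f_transformers_sdpa_mask stays None and the restore path skips it. A stripped-down sketch of that pattern, with generic names rather than the actual onnx_diagnostic code:

import contextlib
import types

@contextlib.contextmanager
def swap_attribute(module: types.ModuleType, name: str, replacement, verbose: int = 0):
    """Install ``replacement`` as ``module.<name>`` and restore the original on exit."""
    original = getattr(module, name)
    if verbose:
        print(f"[swap_attribute] patches {module.__name__}.{name}")
    setattr(module, name, replacement)
    try:
        yield original
    finally:
        setattr(module, name, original)
        if verbose:
            print(f"[swap_attribute] restored {module.__name__}.{name}")

# Usage sketch: patch math.cos only for the duration of the block.
import math
with swap_attribute(math, "cos", lambda x: 0.0, verbose=1):
    assert math.cos(0.0) == 0.0
assert math.cos(0.0) == 1.0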

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 37 additions & 1 deletion
@@ -37,7 +37,14 @@
 
 if patch_masking_utils:
     # Introduced in 4.52
-    from transformers.masking_utils import causal_mask_function, sdpa_mask
+    from transformers.masking_utils import (
+        causal_mask_function,
+        sdpa_mask,
+        padding_mask_function,
+        and_masks,
+        _ignore_causal_mask_sdpa,
+        prepare_padding_mask,
+    )
 
     def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
         """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
@@ -125,6 +132,35 @@ def patched_eager_mask(
         mask = (~mask).to(dtype) * min_dtype
         return mask
 
+    def patched_sdpa_mask_recent_torch(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        local_size: Optional[int] = None,
+        allow_is_causal_skip: bool = True,
+        **kwargs,
+    ) -> Optional[torch.Tensor]:
+        """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
+        q_length = cache_position.shape[0]
+        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+        if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+            padding_mask, q_length, kv_length, kv_offset, local_size
+        ):
+            return None
+        kv_arange = torch.arange(kv_length, device=cache_position.device)
+        kv_arange += kv_offset
+        if padding_mask is not None:
+            mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+        batch_arange = torch.arange(batch_size, device=cache_position.device)
+        head_arange = torch.arange(1, device=cache_position.device)
+        causal_mask = patched__vmap_for_bhqkv(mask_function)(
+            batch_arange, head_arange, cache_position, kv_arange
+        )
+        return causal_mask
+
 
 if patch_parse_processor_args:
 
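patched_sdpa_mask_recent_torch keeps the contract of the original sdpa_mask_recent_torch: it either returns None (when the causal mask can be skipped) or a boolean mask indexed by (batch, head, query position, key position), built from the batch/head/cache_position/kv aranges above. A standalone sketch of that output for a tiny case, assuming the default causal_mask_function reduces to kv_idx <= q_idx (the usual causal rule, not shown in this diff):

import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, kv_length, kv_offset = 2, 6, 2
cache_position = torch.arange(2, 4)              # q_length = 2, query positions 2 and 3
kv_arange = torch.arange(kv_length) + kv_offset  # key positions 2..7

# Broadcast the query/key index grids to (batch, head=1, q_length, kv_length),
# the layout produced by mapping mask_function over (b, h, q_idx, kv_idx).
causal_mask = kv_arange[None, None, None, :] <= cache_position[None, None, :, None]
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
print(causal_mask.shape)        # torch.Size([2, 1, 2, 6])
print(causal_mask[0, 0].int())  # causal pattern shifted by kv_offset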
