
Commit 40334c9

fix issues
1 parent 8923c26 commit 40334c9

File tree: 3 files changed, +116 −3 lines


_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 70 additions & 2 deletions
@@ -4,6 +4,9 @@
 from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
 from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
+from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+    patched__vmap_for_bhqkv as _vmap_for_bhqkv2,
+)


 class TestPatchPatchTorch(ExtTestCase):
@@ -62,7 +65,7 @@ def test_vmap_tuple(self):
         got = patched_vmap(torch.dot, in_dims=(0, None))(x, y)
         self.assertEqualArray(expected, got)

-    def test_vmap_transformers_scenario(self):
+    def test_vmap_transformers_scenario_vmap(self):
         def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
             def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
                 return padding_mask[batch_idx, kv_idx]
@@ -127,12 +130,77 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange):
         got = Model()(*inputs)
         self.assertEqualArray(causal_mask, got)

+        if not requires_torch("4.10"):
+            DYN = torch.export.Dim.DYNAMIC
+            ds1 = {0: DYN}
+            ds2 = {0: DYN, 1: DYN}
+            ds = (ds2, ds1, ds1, ds1)
+            ep = torch.export.export(Model(), inputs, dynamic_shapes=ds)
+            self.assertEqualArray(causal_mask, ep.module()(*inputs))
+
+    def test_vmap_transformers_scenario_novmap(self):
+        def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+            def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
+                return padding_mask[batch_idx, kv_idx]
+
+            return inner_mask
+
+        def and_masks(*mask_functions: list[Callable]) -> Callable:
+            def and_mask(batch_idx, head_idx, q_idx, kv_idx):
+                result = q_idx.new_ones((), dtype=torch.bool)
+                for mask in mask_functions:
+                    result = result & mask(batch_idx, head_idx, q_idx, kv_idx)
+                return result
+
+            return and_mask
+
+        def causal_mask_function(
+            batch_idx: int, head_idx: int, q_idx: int, kv_idx: int
+        ) -> bool:
+            return kv_idx <= q_idx
+
+        def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+            dimensions = [(None, None, None, 0), (None, None, 0, None)]
+            if bh_indices:
+                dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+            for dims in dimensions:
+                mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
+            return mask_function
+
+        padding_mask = torch.ones((2, 33)).to(torch.bool)
+        batch_arange = torch.tensor([0, 1], dtype=torch.int64)
+        head_arange = torch.tensor([0, 1], dtype=torch.int64)
+        cache_position = torch.tensor([30, 31, 32], dtype=torch.int64)
+        kv_arange = torch.arange(33, dtype=torch.int64)
+        mask_function = and_masks(causal_mask_function, padding_mask_function(padding_mask))
+        with TransformGetItemToIndex():
+            causal_mask = _vmap_for_bhqkv(mask_function)(
+                batch_arange, head_arange, cache_position, kv_arange
+            )
+        with TransformGetItemToIndex():
+            causal_mask2 = _vmap_for_bhqkv2(mask_function)(
+                batch_arange, head_arange, cache_position, kv_arange
+            )
+        self.assertEqualArray(causal_mask, causal_mask2)
+
+        class Model(torch.nn.Module):
+            def forward(self, batch_arange, head_arange, cache_position, kv_arange):
+                with TransformGetItemToIndex():
+                    causal_mask2 = _vmap_for_bhqkv2(mask_function)(
+                        batch_arange, head_arange, cache_position, kv_arange
+                    )
+                return causal_mask2
+
+        inputs = batch_arange, head_arange, cache_position, kv_arange
+        got = Model()(*inputs)
+        self.assertEqualArray(causal_mask, got)
+
         DYN = torch.export.Dim.DYNAMIC
         ds1 = {0: DYN}
         ds2 = {0: DYN, 1: DYN}
         ds = (ds2, ds1, ds1, ds1)
         ep = torch.export.export(Model(), inputs, dynamic_shapes=ds)
-        self.assertEqualArray(causal_mask, ep.moule(*inputs))
+        self.assertEqualArray(causal_mask, ep.module()(*inputs))


 if __name__ == "__main__":
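
The export-then-run pattern this test exercises is easy to get wrong, as the fixed typo shows: torch.export.export returns an ExportedProgram, not a module, so the result must go through ep.module() before being called (two call pairs, ep.module()(*inputs)). A minimal standalone sketch of that pattern; the Double model here is illustrative only, not part of this repository:

import torch

class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2

x = torch.arange(4, dtype=torch.float32)
DYN = torch.export.Dim.DYNAMIC
# Mark dimension 0 of the single input as dynamic, as the test does.
ep = torch.export.export(Double(), (x,), dynamic_shapes=({0: DYN},))
# ep.module() rebuilds a runnable nn.Module from the exported program;
# note ep.module()(x), not ep.module(x).
assert torch.equal(ep.module()(x), x * 2)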

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 20 additions & 0 deletions
@@ -334,6 +334,8 @@ def torch_export_patches(
    ####################

    if patch_transformers:
+        import transformers.masking_utils
+
        if verbose:
            import transformers

@@ -345,6 +347,16 @@ def torch_export_patches(
            patch_transformers_list, verbose=verbose
        )

+        if hasattr(transformers.masking_utils, "_vmap_for_bhqkv"):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv"
+                )
+            f_transformers__vmap_for_bhqkv = transformers.masking_utils._vmap_for_bhqkv
+            transformers.masking_utils._vmap_for_bhqkv = (
+                patch_transformers_list.patched__vmap_for_bhqkv
+            )
+
    if custom_patches:
        if verbose:
            print("[torch_export_patches] applies custom patches")
@@ -443,6 +455,14 @@ def torch_export_patches(
            patch_transformers_list, revert_patches_info, verbose=verbose
        )

+        if hasattr(transformers.masking_utils, "_vmap_for_bhqkv"):
+            if verbose:
+                print(
+                    "[torch_export_patches] unpatch "
+                    "transformers.masking_utils._vmap_for_bhqkv"
+                )
+            transformers.masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
    ########
    # caches
    ########
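
Both hunks above follow the usual save/swap/restore discipline for monkey-patching a module attribute: keep a reference to the original function, install the replacement, and put the original back when the patch context ends. A minimal sketch of the pattern in isolation; the SimpleNamespace below stands in for the patched module and is not the transformers API:

import types

mod = types.SimpleNamespace(f=lambda x: x + 1)  # stand-in for a patched module

def patched_f(x):
    return x + 2

original_f = mod.f  # keep a reference so the patch can be reverted
mod.f = patched_f   # apply the patch
try:
    assert mod.f(1) == 3  # patched behaviour is active
finally:
    mod.f = original_f    # always restore the original
assert mod.f(1) == 2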

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 26 additions & 1 deletion
@@ -1,6 +1,6 @@
 import inspect
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -9,6 +9,31 @@
 from ...helpers.torch_helper import is_torchdynamo_exporting


+def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+    """Patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
+    from ...helpers import string_type
+
+    dimensions = [(None, None, None, 0), (None, None, 0, None)]
+    if bh_indices:
+        dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+    dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
+    dimensions = tuple(reversed(dimensions))
+    indices = tuple(shape.index(-1) for shape in dimensions)
+
+    def vector_mask_function(
+        *args, mask_function=mask_function, dimensions=dimensions, indices=indices
+    ):
+        assert len(args) == len(
+            dimensions
+        ), f"Mismatch between args={string_type(args)} and dimensions={dimensions}"
+        new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+        max_shape = tuple(args[i].shape[0] for i in indices)
+        expanded_args = [a.expand(max_shape) for a in new_args]
+        return mask_function(*expanded_args)
+
+    return vector_mask_function
+
+
 def _patch_make_causal_mask(
     input_ids_shape: torch.Size,
     dtype: torch.dtype,
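
patched__vmap_for_bhqkv trades the nested torch.vmap calls of the original for reshape-and-expand broadcasting: each 1-D index tensor is reshaped so its values occupy a distinct axis of a (batch, head, q, kv) layout, every tensor is expanded to the common shape, and the mask function is evaluated once over the whole grid. A minimal sketch of that idea with a causal mask and arbitrary sizes:

import torch

def causal(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

# Reshape each 1-D index tensor so its values live on a distinct axis.
b = torch.arange(2).reshape(-1, 1, 1, 1)   # (B, 1, 1, 1)
h = torch.arange(4).reshape(1, -1, 1, 1)   # (1, H, 1, 1)
q = torch.arange(3).reshape(1, 1, -1, 1)   # (1, 1, Q, 1)
kv = torch.arange(5).reshape(1, 1, 1, -1)  # (1, 1, 1, KV)

# Expand to the common shape, then evaluate the mask once; broadcasting
# covers every (batch, head, q, kv) combination in a single call.
target = (2, 4, 3, 5)
mask = causal(*(t.expand(target) for t in (b, h, q, kv)))
assert mask.shape == target  # full (B, H, Q, KV) boolean mask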
