
Commit d6cbc6f

Commit message: export

1 parent 2968c2c commit d6cbc6f

File tree

2 files changed: +75 -4 lines

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 71 additions & 0 deletions
@@ -1,4 +1,5 @@
 import unittest
+from typing import Callable
 import torch
 import transformers
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
@@ -18,6 +19,9 @@
     convert_dynamic_axes_into_dynamic_shapes,
 )
 from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+    patched__vmap_for_bhqkv,
+)
 
 
 class TestCacheHelpers(ExtTestCase):
@@ -254,6 +258,73 @@ def test_unflatten_flatten_hybrid_cache(self):
             self.string_type(unflat, with_shape=True),
         )
 
+    def test_cache_update_padding_mask_function(self):
+        from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
+
+        def causal_mask_function(
+            batch_idx: int, head_idx: int, q_idx: int, kv_idx: int
+        ) -> bool:
+            return kv_idx <= q_idx
+
+        def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+            def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+                return padding_mask[batch_idx, kv_idx]
+
+            return inner_mask
+
+        def and_masks(*mask_functions: list[Callable]) -> Callable:
+            if not all(callable(arg) for arg in mask_functions):
+                raise RuntimeError(
+                    f"All inputs should be callable mask_functions: {mask_functions}"
+                )
+
+            def and_mask(batch_idx, head_idx, q_idx, kv_idx):
+                result = q_idx.new_ones((), dtype=torch.bool)
+                for mask in mask_functions:
+                    result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(
+                        result.device
+                    )
+                return result
+
+            return and_mask
+
+        def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+            dimensions = [(None, None, None, 0), (None, None, 0, None)]
+            if bh_indices:
+                dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+            for dims in dimensions:
+                mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
+            return mask_function
+
+        class Model(torch.nn.Module):
+            def forward(self, x, mask):
+                mask_function = and_masks(causal_mask_function, padding_mask_function(mask))
+                batch_arange = torch.arange(x.shape[0])
+                head_arange = torch.arange(x.shape[3])
+                kv_arange = torch.arange(x.shape[1])
+                cache_position = torch.arange(x.shape[2])
+                with TransformGetItemToIndex():
+                    causal_mask = patched__vmap_for_bhqkv(mask_function)(
+                        batch_arange, head_arange, cache_position, kv_arange
+                    )
+                return x + causal_mask.to(x.dtype)
+
+        inputs = {
+            "x": torch.rand((4, 4, 4, 4), dtype=torch.float32),
+            "mask": torch.ones((4, 4), dtype=torch.int64),
+        }
+        model = Model()
+        expected = model(**inputs)
+        self.assertNotEmpty(expected)
+        DYN = torch.export.Dim.DYNAMIC
+        ep = torch.export.export(
+            model,
+            (),
+            kwargs=inputs,
+            dynamic_shapes={"x": {0: DYN, 1: DYN, 2: DYN, 3: DYN}, "mask": {0: DYN, 1: DYN}},
+        )
+        self.assertNotEmpty(ep)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
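For readers skimming the diff: the new test exercises patched__vmap_for_bhqkv, whose job (like the local _vmap_for_bhqkv defined in the test) is to lift a scalar mask rule into a full (batch, head, q_len, kv_len) boolean mask by nesting torch.vmap over each index dimension. A minimal standalone sketch of that idea, independent of the patched implementation (the names causal and vmap_for_bhqkv below are illustrative only):

import torch

def causal(b, h, q, kv):
    # scalar rule: a query position attends only to keys at or before it
    return kv <= q

def vmap_for_bhqkv(fn):
    # each torch.vmap call adds one batched index: kv first, batch last,
    # so the final output is indexed as [batch, head, q, kv]
    for dims in [(None, None, None, 0), (None, None, 0, None),
                 (None, 0, None, None), (0, None, None, None)]:
        fn = torch.vmap(fn, in_dims=dims, out_dims=0)
    return fn

mask = vmap_for_bhqkv(causal)(
    torch.arange(2), torch.arange(3), torch.arange(4), torch.arange(5)
)
print(mask.shape)  # torch.Size([2, 3, 4, 5]); mask[b, h, q, kv] == (kv <= q)

The test combines such a causal rule with a padding mask through and_masks and then checks that the resulting model still goes through torch.export.export with every dimension marked dynamic.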

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 4 additions & 4 deletions
@@ -270,7 +270,7 @@ def __init__(self):
                 self.num_attention_heads = key_value_pairs[0][0].shape[1]
                 self.num_hidden_layers = len(key_value_pairs)
 
-            def get_text_config(self):
+            def get_text_config(self, *args, **kwargs):
                 return self
 
         assert max_cache_len is not None, (
@@ -366,7 +366,7 @@ def __init__(self):
                 self.num_hidden_layers = len(key_value_pairs)
                 self.dtype = dtype
 
-            def get_text_config(self):
+            def get_text_config(self, *args, **kwargs):
                 return self
 
         cache = MambaCache(
@@ -409,7 +409,7 @@ def __init__(self):
                 self.num_hidden_layers = len(key_value_pairs)
                 self.sliding_window = key_value_pairs[0][0].shape[2]
 
-            def get_text_config(self):
+            def get_text_config(self, *args, **kwargs):
                 return self
 
         cache = transformers.cache_utils.SlidingWindowCache(
@@ -577,7 +577,7 @@ class _config:
             sliding_window = _sliding_window
             num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
 
-            def get_text_config(self):
+            def get_text_config(self, *args, **kwargs):
                 return self
 
         if layer_types:
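A note on the cache_helper.py change: the dummy config classes used to define get_text_config(self) with no extra parameters, which raises a TypeError as soon as a caller passes an argument; recent transformers versions do pass arguments here (PretrainedConfig.get_text_config accepts a decoder keyword, e.g. get_text_config(decoder=True)), which is presumably what this widening accommodates. A small illustrative sketch of the behaviour (MiniConfig is hypothetical, not the module's code):

class MiniConfig:
    # hypothetical stand-in for the dummy configs defined in cache_helper.py
    def get_text_config(self, *args, **kwargs):
        # the dummy config acts as its own text config, whatever arguments are given
        return self

cfg = MiniConfig()
assert cfg.get_text_config() is cfg
assert cfg.get_text_config(decoder=True) is cfg  # would raise TypeError with get_text_config(self)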
