@@ -1,4 +1,5 @@
import unittest
+from typing import Callable
import torch
import transformers
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
@@ -18,6 +19,9 @@
    convert_dynamic_axes_into_dynamic_shapes,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+    patched__vmap_for_bhqkv,
+)


class TestCacheHelpers(ExtTestCase):
@@ -254,6 +258,73 @@ def test_unflatten_flatten_hybrid_cache(self): |
            self.string_type(unflat, with_shape=True),
        )

+    def test_cache_update_padding_mask_function(self):
+        from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
+
+        def causal_mask_function(
+            batch_idx: int, head_idx: int, q_idx: int, kv_idx: int
+        ) -> bool:
+            return kv_idx <= q_idx
+
+        def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+            def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+                return padding_mask[batch_idx, kv_idx]
+
+            return inner_mask
+
+        def and_masks(*mask_functions: Callable) -> Callable:
+            if not all(callable(arg) for arg in mask_functions):
+                raise RuntimeError(
+                    f"All inputs should be callable mask_functions: {mask_functions}"
+                )
+
+            def and_mask(batch_idx, head_idx, q_idx, kv_idx):
+                result = q_idx.new_ones((), dtype=torch.bool)
+                for mask in mask_functions:
+                    result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(
+                        result.device
+                    )
+                return result
+
+            return and_mask
+
+        def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+            dimensions = [(None, None, None, 0), (None, None, 0, None)]
+            if bh_indices:
+                dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+            for dims in dimensions:
+                mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
+            return mask_function
+
+        class Model(torch.nn.Module):
+            def forward(self, x, mask):
+                mask_function = and_masks(causal_mask_function, padding_mask_function(mask))
+                batch_arange = torch.arange(x.shape[0])
+                head_arange = torch.arange(x.shape[3])
+                kv_arange = torch.arange(x.shape[1])
+                cache_position = torch.arange(x.shape[2])
+                with TransformGetItemToIndex():
+                    causal_mask = patched__vmap_for_bhqkv(mask_function)(
+                        batch_arange, head_arange, cache_position, kv_arange
+                    )
+                return x + causal_mask.to(x.dtype)
+
+        inputs = {
+            "x": torch.rand((4, 4, 4, 4), dtype=torch.float32),
+            "mask": torch.ones((4, 4), dtype=torch.int64),
+        }
+        model = Model()
+        expected = model(**inputs)
+        self.assertNotEmpty(expected)
+        DYN = torch.export.Dim.DYNAMIC
+        ep = torch.export.export(
+            model,
+            (),
+            kwargs=inputs,
+            dynamic_shapes={"x": {0: DYN, 1: DYN, 2: DYN, 3: DYN}, "mask": {0: DYN, 1: DYN}},
+        )
+        self.assertNotEmpty(ep)
+

if __name__ == "__main__":
    unittest.main(verbosity=2)
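Side note on the mechanism the new test relies on (not part of the commit): stacking torch.vmap over the kv, q, head and batch index dimensions, as _vmap_for_bhqkv does above, turns a scalar predicate into a function that returns a full (batch, head, q, kv) boolean mask. The sketch below illustrates that recipe on its own, using stock torch.vmap, only the causal predicate, and made-up sizes; the test goes further by and-ing in the padding predicate, running under TransformGetItemToIndex() (presumably so the padding_mask[batch_idx, kv_idx] lookup works while tracing under vmap), and exporting through the patched helper patched__vmap_for_bhqkv.

# Standalone sketch, not the committed test: the nested-vmap recipe applied
# to the plain causal predicate, with arbitrary illustrative sizes.
import torch

def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    # True where the key/value position does not lie ahead of the query position.
    return kv_idx <= q_idx

def _vmap_for_bhqkv(mask_function, bh_indices=True):
    # Vectorize the scalar predicate over kv, then q, then head and batch.
    dimensions = [(None, None, None, 0), (None, None, 0, None)]
    if bh_indices:
        dimensions.extend([(None, 0, None, None), (0, None, None, None)])
    for dims in dimensions:
        mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
    return mask_function

batch = torch.arange(2)   # illustrative batch size
heads = torch.arange(3)   # illustrative number of heads
q_pos = torch.arange(4)   # query positions
kv_pos = torch.arange(5)  # key/value positions
mask = _vmap_for_bhqkv(causal_mask_function)(batch, heads, q_pos, kv_pos)
print(mask.shape)  # torch.Size([2, 3, 4, 5]); mask[b, h, q, k] == (k <= q)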