@@ -1,4 +1,5 @@
 import unittest
+from typing import Callable
 import torch
 import transformers
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
@@ -19,6 +20,13 @@
 )
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 
+try:
+    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+        patched__vmap_for_bhqkv,
+    )
+except ImportError:
+    patched__vmap_for_bhqkv = None
+
 
 class TestCacheHelpers(ExtTestCase):
     def test_string_type(self):
@@ -69,7 +77,7 @@ def test_replace_by(self): |
         )
 
         DYN = torch.export.Dim.DYNAMIC
-        nargs, nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
+        _nargs, _nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
             None, args=tuple(), kwargs=kwargs, dynamic_axes=dynamic_shapes
         )
         self.assertEqual(dynamic_shapes, nds)
@@ -254,6 +262,92 @@ def test_unflatten_flatten_hybrid_cache(self): |
             self.string_type(unflat, with_shape=True),
         )
 
+    @unittest.skipIf(patched__vmap_for_bhqkv is None, "transformers too old")
+    def test_cache_update_padding_mask_function_vmap(self):
+        def causal_mask_function(
+            batch_idx: int, head_idx: int, q_idx: int, kv_idx: int
+        ) -> bool:
+            return kv_idx <= q_idx
+
+        def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+            def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+                return padding_mask[batch_idx, kv_idx]
+
+            return inner_mask
+
+        def and_masks(*mask_functions: list[Callable]) -> Callable:
+            if not all(callable(arg) for arg in mask_functions):
+                raise RuntimeError(
+                    f"All inputs should be callable mask_functions: {mask_functions}"
+                )
+
+            def and_mask(batch_idx, head_idx, q_idx, kv_idx):
+                result = q_idx.new_ones((), dtype=torch.bool)
+                for mask in mask_functions:
+                    result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(
+                        result.device
+                    )
+                return result
+
+            return and_mask
+
+        def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+            dimensions = [(None, None, None, 0), (None, None, 0, None)]
+            if bh_indices:
+                dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+            for dims in dimensions:
+                mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
+            return mask_function
+
+        class Model(torch.nn.Module):
+            def forward(self, x, mask):
+                mask_function = and_masks(causal_mask_function, padding_mask_function(mask))
+                batch_arange = torch.arange(x.shape[0])
+                head_arange = torch.arange(x.shape[3])
+                kv_arange = torch.arange(x.shape[1])
+                cache_position = torch.arange(x.shape[2])
+                f = patched__vmap_for_bhqkv(mask_function)
+                causal_mask = f(batch_arange, head_arange, cache_position, kv_arange)
+                return x + causal_mask.to(x.dtype)
+
+        inputs = {
+            "x": torch.rand((4, 4, 4, 4), dtype=torch.float32),
+            "mask": torch.ones((4, 4), dtype=torch.int64),
+        }
+        model = Model()
+        expected = model(**inputs)
+        self.assertNotEmpty(expected)
+        DYN = torch.export.Dim.DYNAMIC
+        ep = torch.export.export(
+            model,
+            (),
+            kwargs=inputs,
+            dynamic_shapes={"x": {0: DYN, 1: DYN, 2: DYN, 3: DYN}, "mask": {0: DYN, 1: DYN}},
+        )
+        self.assertNotEmpty(ep)
+
+    def test_simple_indices(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, i, j):
+                return x[i, j]
+
+        inputs = (
+            torch.rand((4, 4), dtype=torch.float32),
+            torch.randint(0, 4, (4, 4, 4, 4), dtype=torch.int64),
+            torch.randint(0, 4, (4, 4, 4, 4), dtype=torch.int64),
+        )
+        model = Model()
+        expected = model(*inputs)
+        self.assertEqual(expected.shape, (4, 4, 4, 4))
+        DYN = torch.export.Dim.DYNAMIC
+        sh = {0: DYN, 1: DYN, 2: DYN, 3: DYN}
+        ep = torch.export.export(
+            model,
+            inputs,
+            dynamic_shapes=({0: DYN, 1: DYN}, sh, sh),
+        )
+        self.assertNotEmpty(ep)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)