|
4 | 4 | from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex |
5 | 5 | from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch |
6 | 6 | from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap |
| 7 | +from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( |
| 8 | + patched__vmap_for_bhqkv as _vmap_for_bhqkv2, |
| 9 | +) |
7 | 10 |
|
8 | 11 |
|
9 | 12 | class TestPatchPatchTorch(ExtTestCase): |
@@ -62,7 +65,7 @@ def test_vmap_tuple(self): |
62 | 65 | got = patched_vmap(torch.dot, in_dims=(0, None))(x, y) |
63 | 66 | self.assertEqualArray(expected, got) |
64 | 67 |
|
65 | | - def test_vmap_transformers_scenario(self): |
| 68 | + def test_vmap_transformers_scenario_vmap(self): |
66 | 69 | def padding_mask_function(padding_mask: torch.Tensor) -> Callable: |
67 | 70 | def inner_mask(batch_idx, head_idx, q_idx, kv_idx): |
68 | 71 | return padding_mask[batch_idx, kv_idx] |
@@ -127,12 +130,77 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange): |
127 | 130 | got = Model()(*inputs) |
128 | 131 | self.assertEqualArray(causal_mask, got) |
129 | 132 |
|
# Optional export check: trace the model with dynamic dimensions and verify
# the exported program reproduces the eager mask.
# NOTE(review): if `requires_torch` returns a decorator (as its use elsewhere
# as a test decorator suggests), `not requires_torch(...)` is always False and
# this branch never runs -- confirm the intended gating.
if not requires_torch("4.10"):
    DYN = torch.export.Dim.DYNAMIC
    ds1 = {0: DYN}
    ds2 = {0: DYN, 1: DYN}
    ds = (ds2, ds1, ds1, ds1)
    ep = torch.export.export(Model(), inputs, dynamic_shapes=ds)
    # The runnable graph is obtained through ``module()``; the previous
    # ``ep.moule(...)`` was a typo for this call.
    self.assertEqualArray(causal_mask, ep.module()(*inputs))
| 140 | + |
def test_vmap_transformers_scenario_novmap(self):
    # Builds a mask by nesting ``torch.vmap`` and checks that
    # ``patched__vmap_for_bhqkv`` (imported as ``_vmap_for_bhqkv2``) yields the
    # same result: called directly, wrapped in a module, and after
    # ``torch.export``.

    def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
        def lookup(batch_idx, head_idx, q_idx, kv_idx):
            return padding_mask[batch_idx, kv_idx]

        return lookup

    def and_masks(*mask_functions: list[Callable]) -> Callable:
        def conjunction(batch_idx, head_idx, q_idx, kv_idx):
            # Start from a scalar True and AND every mask into it.
            keep = q_idx.new_ones((), dtype=torch.bool)
            for fn in mask_functions:
                keep = keep & fn(batch_idx, head_idx, q_idx, kv_idx)
            return keep

        return conjunction

    def causal_mask_function(
        batch_idx: int, head_idx: int, q_idx: int, kv_idx: int
    ) -> bool:
        return kv_idx <= q_idx

    def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
        # One vmap level per argument, applied in the listed order; the
        # head and batch levels are only added when bh_indices is set.
        levels = [(None, None, None, 0), (None, None, 0, None)]
        if bh_indices:
            levels += [(None, 0, None, None), (0, None, None, None)]
        wrapped = mask_function
        for in_dims in levels:
            wrapped = torch.vmap(wrapped, in_dims=in_dims, out_dims=0)
        return wrapped

    padding_mask = torch.ones((2, 33), dtype=torch.bool)
    batch_arange = torch.tensor([0, 1], dtype=torch.int64)
    head_arange = torch.tensor([0, 1], dtype=torch.int64)
    cache_position = torch.arange(30, 33, dtype=torch.int64)
    kv_arange = torch.arange(33, dtype=torch.int64)
    inputs = batch_arange, head_arange, cache_position, kv_arange

    mask_function = and_masks(causal_mask_function, padding_mask_function(padding_mask))

    # Reference result from nested torch.vmap.
    with TransformGetItemToIndex():
        causal_mask = _vmap_for_bhqkv(mask_function)(*inputs)
    # Same computation through the patched implementation.
    with TransformGetItemToIndex():
        causal_mask2 = _vmap_for_bhqkv2(mask_function)(*inputs)
    self.assertEqualArray(causal_mask, causal_mask2)

    class Model(torch.nn.Module):
        def forward(self, batch_arange, head_arange, cache_position, kv_arange):
            with TransformGetItemToIndex():
                mask = _vmap_for_bhqkv2(mask_function)(
                    batch_arange, head_arange, cache_position, kv_arange
                )
            return mask

    # The module wrapper must give the same mask when run eagerly.
    got = Model()(*inputs)
    self.assertEqualArray(causal_mask, got)

    # Export with dynamic dimensions and compare the exported module's output.
    dynamic = torch.export.Dim.DYNAMIC
    first_axis = {0: dynamic}
    two_axes = {0: dynamic, 1: dynamic}
    shapes = (two_axes, first_axis, first_axis, first_axis)
    ep = torch.export.export(Model(), inputs, dynamic_shapes=shapes)
    self.assertEqualArray(causal_mask, ep.module()(*inputs))
136 | 204 |
|
137 | 205 |
|
138 | 206 | if __name__ == "__main__": |
|
0 commit comments