|
5 | 5 | import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers |
6 | 6 | from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers |
7 | 7 | from onnx_diagnostic.helpers.torch_helper import torch_deepcopy |
| 8 | +from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions |
| 9 | +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str |
8 | 10 |
|
9 | 11 |
|
10 | 12 | class TestPatchPatchTransformers(ExtTestCase): |
@@ -121,6 +123,57 @@ def test_causal_mask_in_scaled_dot_product_attention(self): |
121 | 123 | attn_causal_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) |
122 | 124 | self.assertEqual(attn_causal_bias.min().item(), -float("inf")) |
123 | 125 |
|
| 126 | + # @ignore_warnings(UserWarning) |
| 127 | + def test_causal_mask_in_scaled_dot_product_attention_export(self): |
| 128 | + sdpa_attention_forward = sdpa_attention.sdpa_attention_forward |
| 129 | + patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward |
| 130 | + kwargs = { |
| 131 | + "module": None, |
| 132 | + "query": torch.rand((1, 2, 1, 96), dtype=torch.float32), |
| 133 | + "key": torch.rand((1, 2, 4, 96), dtype=torch.float32), |
| 134 | + "value": torch.rand((1, 2, 4, 96), dtype=torch.float32), |
| 135 | + "attention_mask": None, |
| 136 | + "attention_dropout": 0, |
| 137 | + "scaling": 0.10206207261596575, |
| 138 | + "is_causal": True, |
| 139 | + } |
| 140 | + expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0] |
| 141 | + got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0] |
| 142 | + self.assertEqualArray(expected, got) |
| 143 | + |
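| | + # small wrapper module so the patched attention can be traced by torch.export |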
| 144 | + class Model(torch.nn.Module): |
| 145 | + def forward(self, query, key, value): |
| 146 | + kwargs = { |
| 147 | + "module": None, |
| 148 | + "query": query, |
| 149 | + "key": key, |
| 150 | + "value": value, |
| 151 | + "attention_mask": None, |
| 152 | + "attention_dropout": 0, |
| 153 | + "scaling": 0.10206207261596575, |
| 154 | + "is_causal": True, |
| 155 | + } |
| 156 | + return patched_sdpa_attention_forward(**kwargs)[0] |
| 157 | + |
| 158 | + query, key, value = kwargs["query"], kwargs["key"], kwargs["value"] |
| 159 | + model = Model() |
| 160 | + got = model(query, key, value) |
| 161 | + self.assertEqualArray(expected, got) |
| 162 | + |
| 163 | + # static export |
| 164 | + ep = torch.export.export(model, (query, key, value)) |
| 165 | + got = ep.module()(query, key, value) |
| 166 | + self.assertEqualArray(expected, got) |
| 167 | + |
| 168 | + # dynamic export |
| 169 | + ds = ({0: "batch", 2: "seq1"}, {0: "batch", 2: "seq2"}, {0: "batch", 2: "seq2"}) |
| 170 | + fake_inputs, _ = make_fake_with_dynamic_dimensions((query, key, value), ds) |
| 171 | + print("****", fake_inputs) |
| 172 | + epd = torch.export.export(model, fake_inputs) # , dynamic_shapes=use_dyn_not_str(ds)) |
| 173 | + print(epq) |
| 174 | + got = epd.module()(query, key, value) |
| 175 | + self.assertEqualArray(expected, got) |
| 176 | + |
124 | 177 |
|
125 | 178 | if __name__ == "__main__": |
126 | 179 | unittest.main(verbosity=2) |