|
2 | 2 | import unittest |
3 | 3 | import numpy as np |
4 | 4 | import onnx |
5 | | -from onnx_diagnostic.ext_test_case import ExtTestCase |
| 5 | +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings |
| 6 | +from onnx_diagnostic.reference import OnnxruntimeEvaluator |
6 | 7 |
|
7 | 8 |
|
8 | 9 | class TestDiscrepancies(ExtTestCase): |
| 10 | + @ignore_warnings(DeprecationWarning) |
9 | 11 | def test_attention_opset15_in_a_loop(self): |
| 12 | + import torch |
| 13 | + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_attention import ( # noqa: E501 |
| 14 | + patched_sdpa_attention_forward, |
| 15 | + ) |
| 16 | + |
| 17 | + def qwen_sdpa_attention( |
| 18 | + query_states: torch.Tensor, |
| 19 | + key_states: torch.Tensor, |
| 20 | + value_states: torch.Tensor, |
| 21 | + cu_seqlens: torch.Tensor, |
| 22 | + scaling: float = 0, |
| 23 | + num_heads: int = 16, |
| 24 | + ) -> torch.Tensor: |
| 25 | + lengths = cu_seqlens[1:] - cu_seqlens[:-1] |
| 26 | + splits = [ |
| 27 | + torch.split(tensor, lengths.tolist(), dim=2) |
| 28 | + for tensor in (query_states, key_states, value_states) |
| 29 | + ] |
| 30 | + |
| 31 | + attn_outputs = [ |
| 32 | + patched_sdpa_attention_forward( |
| 33 | + None, |
| 34 | + q, |
| 35 | + k, |
| 36 | + v, |
| 37 | + attention_mask=None, |
| 38 | + scaling=scaling, |
| 39 | + dropout=0.0, |
| 40 | + is_causal=False, |
| 41 | + )[0] |
| 42 | + for q, k, v in zip(*splits) |
| 43 | + ] |
| 44 | + attn_output = torch.cat(attn_outputs, dim=1) |
| 45 | + return attn_output |
| 46 | + |
10 | 47 | model = onnx.load( |
11 | 48 | os.path.join(os.path.dirname(__file__), "data", "attention_loopa24.onnx") |
12 | 49 | ) |
13 | 50 | sess = self.check_ort(model) |
| 51 | + |
14 | 52 | feeds = dict( |
15 | 53 | c_lifted_tensor_0=np.array([0], dtype=np.int64), |
16 | 54 | cat_2=np.array( |
@@ -48,9 +86,41 @@ def test_attention_opset15_in_a_loop(self): |
48 | 86 | unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32), |
49 | 87 | unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32), |
50 | 88 | ) |
| 89 | + |
| 90 | + dummy_inputs = os.path.join( |
| 91 | + os.path.dirname(__file__), |
| 92 | + "..", |
| 93 | + "..", |
| 94 | + "dump_test", |
| 95 | + "replay", |
| 96 | + "qwen_sdpa_attention_loopa24", |
| 97 | + "onnx_inputs.pt", |
| 98 | + ) |
| 99 | + if os.path.exists(dummy_inputs): |
| 100 | + print("-- use dummy inputs") |
| 101 | + feeds = {k: v.detach().cpu().numpy() for k, v in torch.load(dummy_inputs).items()} |
| 102 | + for k, v in feeds.items(): |
| 103 | + print(f"-- {k}: {self.string_type(v, with_shape=True, with_min_max=True)}") |
| 104 | + |
| 105 | + # feeds["cat_2"] = np.array([0, 1292], dtype=np.int64) |
51 | 106 | got = sess.run(None, feeds) |
52 | 107 | self.assertEqual(len(got), 1) |
53 | 108 | self.assertEqual((1, 1292, 16, 80), got[0].shape) |
| 109 | + expected = qwen_sdpa_attention( |
| 110 | + torch.from_numpy(feeds["unsqueeze_4"]), |
| 111 | + torch.from_numpy(feeds["unsqueeze_5"]), |
| 112 | + torch.from_numpy(feeds["unsqueeze_6"]), |
| 113 | + torch.from_numpy(feeds["cat_2"]), |
| 114 | + scaling=0.11180339753627777, |
| 115 | + num_heads=16, |
| 116 | + ) |
| 117 | + self.assertEqualArray(expected, got[0], atol=1e-5) |
| 118 | + |
| 119 | + tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()} |
| 120 | + ev = OnnxruntimeEvaluator(model) |
| 121 | + got2 = ev.run(None, tfeeds) |
| 122 | + self.assertEqual(len(got2), 1) |
| 123 | + self.assertEqualArray(got[0], got2[0], atol=1e-5) |
54 | 124 |
|
55 | 125 |
|
56 | 126 | if __name__ == "__main__": |
|
0 commit comments