@@ -2,8 +2,9 @@
 import unittest
 import numpy as np
 import onnx
-from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, has_onnxruntime
 from onnx_diagnostic.reference import OnnxruntimeEvaluator
+from onnx_diagnostic.helpers import max_diff, string_diff
 
 
 class TestDiscrepancies(ExtTestCase):
@@ -44,83 +45,131 @@ def qwen_sdpa_attention(
             attn_output = torch.cat(attn_outputs, dim=1)
             return attn_output
 
-        model = onnx.load(
-            os.path.join(os.path.dirname(__file__), "data", "attention_loopa24.onnx")
-        )
-        sess = self.check_ort(model)
+        for model_name in ["attention_loopa24.onnx", "attention_loopmha.onnx"]:
+            if model_name == "attention_loopa24.onnx" and not has_onnxruntime("1.24"):
+                # not available
+                continue
+            with self.subTest(model=model_name):
+                model = onnx.load(os.path.join(os.path.dirname(__file__), "data", model_name))
+                sess = self.check_ort(model)
 
-        feeds = dict(
-            c_lifted_tensor_0=np.array([0], dtype=np.int64),
-            cat_2=np.array(
-                [
-                    0,
-                    64,
-                    128,
-                    192,
-                    256,
-                    304,
-                    368,
-                    432,
-                    496,
-                    560,
-                    608,
-                    672,
-                    736,
-                    800,
-                    864,
-                    912,
-                    976,
-                    1040,
-                    1104,
-                    1168,
-                    1216,
-                    1232,
-                    1248,
-                    1264,
-                    1280,
-                    1292,
-                ],
-                dtype=np.int64,
-            ),
-            unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32),
-            unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32),
-            unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32),
-        )
+                feeds = dict(
+                    c_lifted_tensor_0=np.array([0], dtype=np.int64),
+                    cat_2=np.array(
+                        [
+                            0,
+                            64,
+                            128,
+                            192,
+                            256,
+                            304,
+                            368,
+                            432,
+                            496,
+                            560,
+                            608,
+                            672,
+                            736,
+                            800,
+                            864,
+                            912,
+                            976,
+                            1040,
+                            1104,
+                            1168,
+                            1216,
+                            1232,
+                            1248,
+                            1264,
+                            1280,
+                            1292,
+                        ],
+                        dtype=np.int64,
+                    ),
+                    unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32),
+                    unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32),
+                    unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32),
+                )
 
-        dummy_inputs = os.path.join(
-            os.path.dirname(__file__),
-            "..",
-            "..",
-            "dump_test",
-            "replay",
-            "qwen_sdpa_attention_loopa24",
-            "onnx_inputs.pt",
-        )
-        if os.path.exists(dummy_inputs):
-            print("-- use dummy inputs")
-            feeds = {k: v.detach().cpu().numpy() for k, v in torch.load(dummy_inputs).items()}
-            for k, v in feeds.items():
-                print(f"-- {k}: {self.string_type(v, with_shape=True, with_min_max=True)}")
+                dummy_inputs = os.path.join(
+                    os.path.dirname(__file__),
+                    "..",
+                    "..",
+                    "dump_test",
+                    "replay",
+                    "qwen_sdpa_attention_loopmha",
+                    "onnx_inputs.pt",
+                )
+                if os.path.exists(dummy_inputs):
+                    print("-- use dummy inputs")
 
-        # feeds["cat_2"] = np.array([0, 1292], dtype=np.int64)
-        got = sess.run(None, feeds)
-        self.assertEqual(len(got), 1)
-        self.assertEqual((1, 1292, 16, 80), got[0].shape)
-        expected = qwen_sdpa_attention(
-            torch.from_numpy(feeds["unsqueeze_4"]),
-            torch.from_numpy(feeds["unsqueeze_5"]),
-            torch.from_numpy(feeds["unsqueeze_6"]),
-            torch.from_numpy(feeds["cat_2"]),
-            scaling=0.11180339753627777,
-            num_heads=16,
-        )
-        self.assertEqualArray(expected, got[0], atol=1e-5)
+                    feeds1 = torch.load(dummy_inputs)
+                    res1 = qwen_sdpa_attention(
+                        feeds1["unsqueeze_4"],
+                        feeds1["unsqueeze_5"],
+                        feeds1["unsqueeze_6"],
+                        feeds1["cat_2"],
+                        scaling=0.11180339753627777,
+                        num_heads=16,
+                    )
+                    feeds1o = {k: v.detach().cpu().numpy() for k, v in feeds1.items()}
+                    reso1 = sess.run(None, feeds1o)[0]
+                    dummy_inputs2 = dummy_inputs.replace("onnx_inputs", "torch_inputs")
+                    assert dummy_inputs != dummy_inputs2
+                    feeds2 = torch.load(dummy_inputs2)
+                    res2 = qwen_sdpa_attention(
+                        feeds2["unsqueeze_4"],
+                        feeds2["unsqueeze_5"],
+                        feeds2["unsqueeze_6"],
+                        feeds2["cat_2"],
+                        scaling=0.11180339753627777,
+                        num_heads=16,
+                    )
+                    feeds2o = {k: v.detach().cpu().numpy() for k, v in feeds2.items()}
+                    reso2 = sess.run(None, feeds2o)[0]
+                    diff = max_diff(res1, res2, hist=[0.1])
+                    print(f"-- diff torch inputs1-inputs2: {string_diff(diff)}")
+                    diff = max_diff(res2, reso2, hist=[0.1])
+                    print(f"-- diff torch-onnx inputs2: {string_diff(diff)}")
+                    diff = max_diff(res1, reso1, hist=[0.1])
+                    print(f"-- diff torch-onnx inputs1: {string_diff(diff)}")
+                    if diff["abs"] > 0.1:
+                        for k in feeds1:
+                            print(
+                                f"-- {k}: "
+                                f"{string_diff(max_diff(feeds1[k], feeds2[k], hist=[0.1]))}"
+                            )
+
+                    feeds = {
+                        k: v.detach().cpu().numpy()
+                        for k, v in torch.load(dummy_inputs).items()
+                    }
+
+                    for k, v in feeds.items():
+                        print(
+                            f"-- {k}: "
+                            f"{self.string_type(v, with_shape=True, with_min_max=True)}"
+                        )
+
+                # feeds["cat_2"] = np.array([0, 1292], dtype=np.int64)
+                got = sess.run(None, feeds)
+                self.assertEqual(len(got), 1)
+                self.assertEqual((1, 1292, 16, 80), got[0].shape)
+                expected = qwen_sdpa_attention(
+                    torch.from_numpy(feeds["unsqueeze_4"]),
+                    torch.from_numpy(feeds["unsqueeze_5"]),
+                    torch.from_numpy(feeds["unsqueeze_6"]),
+                    torch.from_numpy(feeds["cat_2"]),
+                    scaling=0.11180339753627777,
+                    num_heads=16,
+                )
+                self.assertEqualArray(expected, got[0], atol=1e-5)
 
-        tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()}
-        ev = OnnxruntimeEvaluator(model)
-        got2 = ev.run(None, tfeeds)
-        self.assertEqual(len(got2), 1)
-        self.assertEqualArray(got[0], got2[0], atol=1e-5)
+                tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()}
+                ev = OnnxruntimeEvaluator(model)
+                got2 = ev.run(None, tfeeds)
+                self.assertEqual(len(got2), 1)
+                self.assertEqualArray(got[0], got2[0], atol=1e-5)
 
 
 if __name__ == "__main__":
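Note: the new debug branch relies on max_diff and string_diff from onnx_diagnostic.helpers: max_diff returns a mapping of error statistics (the patch reads diff["abs"]) and string_diff turns it into a one-line summary. The snippet below is a minimal sketch of that pattern, not part of the patch; the tensor shape and the synthetic "got" output are illustrative only.

    import torch
    from onnx_diagnostic.helpers import max_diff, string_diff

    # Stand-ins for the eager qwen_sdpa_attention output and the onnxruntime output.
    expected = torch.randn(1, 1292, 16, 80)
    got = expected + 1e-6 * torch.randn_like(expected)

    diff = max_diff(expected, got, hist=[0.1])  # mapping of statistics, includes "abs"
    print(string_diff(diff))                    # compact one-line summary
    if diff["abs"] > 0.1:
        print("-- large discrepancy, dump the inputs for replay")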