Skip to content

Commit 320aae8

Browse files
committed
improve unittest
1 parent a9d482c commit 320aae8

File tree

2 files changed

+123
-75
lines changed

2 files changed

+123
-75
lines changed

_unittests/ut_torch_onnx/test_discrepancies.py

Lines changed: 123 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import unittest
33
import numpy as np
44
import onnx
5-
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
5+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, has_onnxruntime
66
from onnx_diagnostic.reference import OnnxruntimeEvaluator
7+
from onnx_diagnostic.helpers import max_diff, string_diff
78

89

910
class TestDiscrepancies(ExtTestCase):
@@ -44,83 +45,131 @@ def qwen_sdpa_attention(
4445
attn_output = torch.cat(attn_outputs, dim=1)
4546
return attn_output
4647

47-
model = onnx.load(
48-
os.path.join(os.path.dirname(__file__), "data", "attention_loopa24.onnx")
49-
)
50-
sess = self.check_ort(model)
48+
for model_name in ["attention_loopa24.onnx", "attention_loopmha.onnx"]:
49+
if model_name == "attention_loopa24.onnx" and not has_onnxruntime("1.24"):
50+
# not available
51+
continue
52+
with self.subTest(model=model_name):
53+
model = onnx.load(os.path.join(os.path.dirname(__file__), "data", model_name))
54+
sess = self.check_ort(model)
5155

52-
feeds = dict(
53-
c_lifted_tensor_0=np.array([0], dtype=np.int64),
54-
cat_2=np.array(
55-
[
56-
0,
57-
64,
58-
128,
59-
192,
60-
256,
61-
304,
62-
368,
63-
432,
64-
496,
65-
560,
66-
608,
67-
672,
68-
736,
69-
800,
70-
864,
71-
912,
72-
976,
73-
1040,
74-
1104,
75-
1168,
76-
1216,
77-
1232,
78-
1248,
79-
1264,
80-
1280,
81-
1292,
82-
],
83-
dtype=np.int64,
84-
),
85-
unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32),
86-
unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32),
87-
unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32),
88-
)
56+
feeds = dict(
57+
c_lifted_tensor_0=np.array([0], dtype=np.int64),
58+
cat_2=np.array(
59+
[
60+
0,
61+
64,
62+
128,
63+
192,
64+
256,
65+
304,
66+
368,
67+
432,
68+
496,
69+
560,
70+
608,
71+
672,
72+
736,
73+
800,
74+
864,
75+
912,
76+
976,
77+
1040,
78+
1104,
79+
1168,
80+
1216,
81+
1232,
82+
1248,
83+
1264,
84+
1280,
85+
1292,
86+
],
87+
dtype=np.int64,
88+
),
89+
unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32),
90+
unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32),
91+
unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32),
92+
)
8993

90-
dummy_inputs = os.path.join(
91-
os.path.dirname(__file__),
92-
"..",
93-
"..",
94-
"dump_test",
95-
"replay",
96-
"qwen_sdpa_attention_loopa24",
97-
"onnx_inputs.pt",
98-
)
99-
if os.path.exists(dummy_inputs):
100-
print("-- use dummy inputs")
101-
feeds = {k: v.detach().cpu().numpy() for k, v in torch.load(dummy_inputs).items()}
102-
for k, v in feeds.items():
103-
print(f"-- {k}: {self.string_type(v, with_shape=True, with_min_max=True)}")
94+
dummy_inputs = os.path.join(
95+
os.path.dirname(__file__),
96+
"..",
97+
"..",
98+
"dump_test",
99+
"replay",
100+
"qwen_sdpa_attention_loopmha",
101+
"onnx_inputs.pt",
102+
)
103+
if os.path.exists(dummy_inputs):
104+
print("-- use dummy inputs")
104105

105-
# feeds["cat_2"] = np.array([0, 1292], dtype=np.int64)
106-
got = sess.run(None, feeds)
107-
self.assertEqual(len(got), 1)
108-
self.assertEqual((1, 1292, 16, 80), got[0].shape)
109-
expected = qwen_sdpa_attention(
110-
torch.from_numpy(feeds["unsqueeze_4"]),
111-
torch.from_numpy(feeds["unsqueeze_5"]),
112-
torch.from_numpy(feeds["unsqueeze_6"]),
113-
torch.from_numpy(feeds["cat_2"]),
114-
scaling=0.11180339753627777,
115-
num_heads=16,
116-
)
117-
self.assertEqualArray(expected, got[0], atol=1e-5)
106+
feeds1 = torch.load(dummy_inputs)
107+
res1 = qwen_sdpa_attention(
108+
feeds1["unsqueeze_4"],
109+
feeds1["unsqueeze_5"],
110+
feeds1["unsqueeze_6"],
111+
feeds1["cat_2"],
112+
scaling=0.11180339753627777,
113+
num_heads=16,
114+
)
115+
feeds1o = {k: v.detach().cpu().numpy() for k, v in feeds1.items()}
116+
reso1 = sess.run(None, feeds1o)[0]
117+
dummy_inputs2 = dummy_inputs.replace("onnx_inputs", "torch_inputs")
118+
assert dummy_inputs != dummy_inputs2
119+
feeds2 = torch.load(dummy_inputs2)
120+
res2 = qwen_sdpa_attention(
121+
feeds2["unsqueeze_4"],
122+
feeds2["unsqueeze_5"],
123+
feeds2["unsqueeze_6"],
124+
feeds2["cat_2"],
125+
scaling=0.11180339753627777,
126+
num_heads=16,
127+
)
128+
feeds2o = {k: v.detach().cpu().numpy() for k, v in feeds2.items()}
129+
reso2 = sess.run(None, feeds2o)[0]
130+
diff = max_diff(res1, res2, hist=[0.1])
131+
print(f"-- diff torch-onnx: {string_diff(diff)}")
132+
diff = max_diff(res2, reso2, hist=[0.1])
133+
print(f"-- diff torch-onnxo1: {string_diff(diff)}")
134+
diff = max_diff(res1, reso1, hist=[0.1])
135+
print(f"-- diff torch-onnxo2: {string_diff(diff)}")
136+
if diff["abs"] > 0.1:
137+
for k in feeds1:
138+
print(
139+
f"-- {k}: "
140+
f"{string_diff(max_diff(feeds1[k], feeds2[k], hist=[0.1]))}"
141+
)
142+
143+
feeds = {
144+
k: v.detach().cpu().numpy()
145+
for k, v in torch.load(dummy_inputs).items()
146+
}
147+
148+
for k, v in feeds.items():
149+
print(
150+
f"-- {k}: "
151+
f"{self.string_type(v, with_shape=True, with_min_max=True)}"
152+
)
153+
154+
# feeds["cat_2"] = np.array([0, 1292], dtype=np.int64)
155+
got = sess.run(None, feeds)
156+
self.assertEqual(len(got), 1)
157+
self.assertEqual((1, 1292, 16, 80), got[0].shape)
158+
expected = qwen_sdpa_attention(
159+
torch.from_numpy(feeds["unsqueeze_4"]),
160+
torch.from_numpy(feeds["unsqueeze_5"]),
161+
torch.from_numpy(feeds["unsqueeze_6"]),
162+
torch.from_numpy(feeds["cat_2"]),
163+
scaling=0.11180339753627777,
164+
num_heads=16,
165+
)
166+
self.assertEqualArray(expected, got[0], atol=1e-5)
118167

119-
tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()}
120-
ev = OnnxruntimeEvaluator(model)
121-
got2 = ev.run(None, tfeeds)
122-
self.assertEqual(len(got2), 1)
123-
self.assertEqualArray(got[0], got2[0], atol=1e-5)
168+
tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()}
169+
ev = OnnxruntimeEvaluator(model)
170+
got2 = ev.run(None, tfeeds)
171+
self.assertEqual(len(got2), 1)
172+
self.assertEqualArray(got[0], got2[0], atol=1e-5)
124173

125174

126175
if __name__ == "__main__":

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1341,7 +1341,6 @@ def _mkv_(name, itype, irank):
13411341
# there are hidden inputs
13421342
for att in node.attribute:
13431343
if att.type == onnx.AttributeProto.GRAPH:
1344-
print("++++", get_hidden_inputs(att.g))
13451344
not_known |= get_hidden_inputs(att.g)
13461345

13471346
model = oh.make_model(

0 commit comments

Comments
 (0)