Skip to content

Commit e31b475

Browse files
committed
unittest
1 parent 9b18b46 commit e31b475

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import os
2+
import unittest
3+
import numpy as np
4+
import onnx
5+
from onnx_diagnostic.ext_test_case import ExtTestCase
6+
7+
8+
class TestDiscrepancies(ExtTestCase):
9+
def test_attention_opset15_in_a_loop(self):
10+
model = onnx.load(
11+
os.path.join(os.path.dirname(__file__), "data", "attention_loopa24.onnx")
12+
)
13+
sess = self.check_ort(model)
14+
feeds = dict(
15+
c_lift_tensor_0=np.array([0], dtype=np.int64),
16+
cat_2=np.array(
17+
[
18+
0,
19+
64,
20+
128,
21+
192,
22+
256,
23+
304,
24+
368,
25+
432,
26+
496,
27+
560,
28+
608,
29+
672,
30+
736,
31+
800,
32+
864,
33+
912,
34+
976,
35+
1040,
36+
1104,
37+
1168,
38+
1216,
39+
1232,
40+
1248,
41+
1264,
42+
1280,
43+
1292,
44+
],
45+
dtype=np.int64,
46+
),
47+
unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32),
48+
)
49+
got = sess.run(None, feeds)
50+
self.assertEqual(len(got), 1)
51+
52+
53+
if __name__ == "__main__":
54+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)