 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
 from onnx_diagnostic.helpers import flatten_object
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
-from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, fake_reshape
+from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, FakeTensorContext


 class TestMakeTensorHelper(ExtTestCase):

+    @requires_transformers("4.55")
+    def test_fake_inputs(self):
+        inputs, _ = make_fake(
+            dict(
+                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                past_key_values=make_dynamic_cache(
+                    [
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                    ]
+                ),
+            )
+        )
+        flat = flatten_object(inputs, drop_keys=True)
+        for t in flat:
+            self.assertIsInstance(t, torch.Tensor)
+            assert all(
+                isinstance(s, torch.SymInt) for s in t.shape
+            ), f"Wrong type {[type(s) for s in t.shape]} in {t.shape}"
+
     def test_fake_reshape_generic(self):
         t = torch.zeros((2, 3, 4, 5), dtype=torch.float32)
-        reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
+        reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
         self.assertIsInstance(reshaped.shape[0], torch.SymInt)
         self.assertIsInstance(reshaped.shape[2], torch.SymInt)
         self.assertEqual(reshaped.shape[1], 3)
         self.assertEqual(reshaped.shape[3], 5)

     def test_fake_reshape_dim_1(self):
         t = torch.zeros((1, 3, 4, 5), dtype=torch.float32)
-        reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
+        reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
         self.assertIsInstance(reshaped.shape[0], torch.SymInt)
         self.assertIsInstance(reshaped.shape[2], torch.SymInt)
         self.assertEqual(reshaped.shape[1], 3)
         self.assertEqual(reshaped.shape[3], 5)

     def test_fake_reshape_dim_0(self):
         t = torch.zeros((0, 3, 4, 5), dtype=torch.float32)
-        reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
+        reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
         self.assertIsInstance(reshaped.shape[0], torch.SymInt)
         self.assertIsInstance(reshaped.shape[2], torch.SymInt)
         self.assertEqual(reshaped.shape[1], 3)
         self.assertEqual(reshaped.shape[3], 5)

     def test_fake_reshape_different(self):
         t = torch.zeros((2, 3, 2, 5), dtype=torch.float32)
-        reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
+        reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
         self.assertIsInstance(reshaped.shape[0], torch.SymInt)
         self.assertIsInstance(reshaped.shape[2], torch.SymInt)
         self.assertEqual(reshaped.shape[1], 3)
         self.assertEqual(reshaped.shape[3], 5)
         self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])

-    @requires_transformers("4.55")
-    def test_fake_inputs(self):
-        inputs, _ = make_fake(
-            dict(
-                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
-                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
-                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
-                past_key_values=make_dynamic_cache(
-                    [
-                        (
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                        ),
-                        (
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                        ),
-                    ]
-                ),
-            )
-        )
-        flat = flatten_object(inputs, drop_keys=True)
-        for t in flat:
-            self.assertIsInstance(t, torch.Tensor)
-            assert all(
-                isinstance(s, torch.SymInt) for s in t.shape
-            ), f"Wrong type {[type(s) for s in t.shape]} in {t.shape}"
-

 if __name__ == "__main__":
     unittest.main(verbosity=2)
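
For context, a minimal sketch of the call pattern these tests now exercise, based only on what the diff shows: fake_reshape, previously a module-level helper, is invoked as a method of FakeTensorContext with a tensor and a mapping from dimension index to a dynamic-dimension name. The print statement and variable names below are illustrative, not part of the commit.

    import torch
    from onnx_diagnostic.helpers.fake_tensor_helper import FakeTensorContext

    t = torch.zeros((2, 3, 4, 5), dtype=torch.float32)
    # Dims 0 and 2 are marked dynamic ("batch", "seq_length"); per the assertions
    # above, those dimensions come back as torch.SymInt while dims 1 and 3 stay
    # concrete ints (3 and 5).
    reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
    print([type(s) for s in reshaped.shape])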