|
| 1 | +import unittest |
| 2 | +import torch |
| 3 | +from onnx_diagnostic.ext_test_case import ExtTestCase |
| 4 | +from onnx_diagnostic.helpers import string_type |
| 5 | +from onnx_diagnostic.helpers.cache_helper import ( |
| 6 | + make_dynamic_cache, |
| 7 | + flatten_unflatten_for_dynamic_shapes, |
| 8 | +) |
| 9 | +from onnx_diagnostic.export import ModelInputs |
| 10 | + |
| 11 | + |
| 12 | +class TestSerialization(ExtTestCase): |
| 13 | + def _get_cache(self, n_layers=2, bsize=2, nheads=4, slen=1, dim=7): |
| 14 | + return make_dynamic_cache( |
| 15 | + [ |
| 16 | + (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim)) |
| 17 | + for i in range(n_layers) |
| 18 | + ] |
| 19 | + ) |
| 20 | + |
| 21 | + def test_dynamic_cache(self): |
| 22 | + class Model(torch.nn.Module): |
| 23 | + def forward(self, cache): |
| 24 | + return cache.key_cache[0] |
| 25 | + |
| 26 | + cache = self._get_cache() |
| 27 | + flat_unflat = flatten_unflatten_for_dynamic_shapes(cache) |
| 28 | + s = string_type(flat_unflat, with_shape=True) |
| 29 | + self.assertEqual(s, "#2[#2[T1s2x4x1x7,T1s2x4x1x7],#2[T1s2x4x1x7,T1s2x4x1x7]]") |
| 30 | + DYN = torch.export.Dim.DYNAMIC |
| 31 | + ds = {0: DYN, 1: DYN, 3: DYN} |
| 32 | + dynamic_shapes = ([[ds, ds], [ds, ds]],) |
| 33 | + exp = torch.export.export(Model(), (cache,), dynamic_shapes=dynamic_shapes) |
| 34 | + self.assertNotEmpty(exp) |
| 35 | + |
| 36 | + def test_dynamic_cache_guess_static(self): |
| 37 | + class Model(torch.nn.Module): |
| 38 | + def forward(self, cache): |
| 39 | + return cache.key_cache[0] |
| 40 | + |
| 41 | + cache = self._get_cache() |
| 42 | + md = ModelInputs(Model(), [(cache,)]) |
| 43 | + guessed = md.guess_dynamic_shapes() |
| 44 | + self.assertEqual(guessed, (([[{}, {}], [{}, {}]],), {})) |
| 45 | + |
| 46 | + def test_dynamic_cache_guess_dynamic(self): |
| 47 | + class Model(torch.nn.Module): |
| 48 | + def forward(self, cache): |
| 49 | + return cache.key_cache[0] |
| 50 | + |
| 51 | + md = ModelInputs( |
| 52 | + Model(), [(self._get_cache(),), (self._get_cache(bsize=3, nheads=5),)] |
| 53 | + ) |
| 54 | + guessed = md.guess_dynamic_shapes() |
| 55 | + DYN = torch.export.Dim.DYNAMIC |
| 56 | + self.assertEqual( |
| 57 | + guessed, |
| 58 | + ( |
| 59 | + ( |
| 60 | + [ |
| 61 | + [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}], |
| 62 | + [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}], |
| 63 | + ], |
| 64 | + ), |
| 65 | + {}, |
| 66 | + ), |
| 67 | + ) |
| 68 | + |
| 69 | + |
| 70 | +if __name__ == "__main__": |
| 71 | + unittest.main(verbosity=2) |
0 commit comments