|
| 1 | +import unittest |
| 2 | +import torch |
| 3 | +from transformers.modeling_outputs import BaseModelOutput |
| 4 | +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings |
| 5 | +from onnx_diagnostic.helpers.cache_helper import make_encoder_decoder_cache, make_dynamic_cache |
| 6 | +from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( |
| 7 | + bypass_export_some_errors, |
| 8 | +) |
| 9 | +from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy |
| 10 | + |
| 11 | + |
| 12 | +class TestPatchSerialization(ExtTestCase): |
| 13 | + @ignore_warnings(UserWarning) |
| 14 | + def test_encoder_decoder_cache_flatten(self): |
| 15 | + cache = make_encoder_decoder_cache( |
| 16 | + make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), |
| 17 | + make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), |
| 18 | + ) |
| 19 | + with bypass_export_some_errors(): |
| 20 | + flat, _spec = torch.utils._pytree.tree_flatten(cache) |
| 21 | + self.assertEqual( |
| 22 | + "#4[T1s4x4x4,T1s4x4x4,T1s5x5x5,T1s5x5x5]", |
| 23 | + self.string_type(flat, with_shape=True), |
| 24 | + ) |
| 25 | + cache2 = torch.utils._pytree.tree_unflatten(flat, _spec) |
| 26 | + self.assertEqual( |
| 27 | + self.string_type(cache, with_shape=True, with_min_max=True), |
| 28 | + self.string_type(cache2, with_shape=True, with_min_max=True), |
| 29 | + ) |
| 30 | + |
| 31 | + @ignore_warnings(UserWarning) |
| 32 | + def test_encoder_decoder_cache_deepcopy(self): |
| 33 | + cache = make_encoder_decoder_cache( |
| 34 | + make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), |
| 35 | + make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), |
| 36 | + ) |
| 37 | + with bypass_export_some_errors(): |
| 38 | + cache2 = torch_deepcopy([cache]) |
| 39 | + self.assertEqualAny([cache], cache2) |
| 40 | + |
| 41 | + @ignore_warnings(UserWarning) |
| 42 | + def test_encoder_decoder_cache_export(self): |
| 43 | + class Model(torch.nn.Module): |
| 44 | + def forward(self, cache): |
| 45 | + return cache.self_attention_cache.key_cache[0] |
| 46 | + |
| 47 | + cache1 = make_dynamic_cache( |
| 48 | + [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)] |
| 49 | + ) |
| 50 | + cache2 = make_dynamic_cache( |
| 51 | + [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)] |
| 52 | + ) |
| 53 | + |
| 54 | + cache = make_encoder_decoder_cache(cache1, cache2) |
| 55 | + model = Model() |
| 56 | + model(cache) |
| 57 | + DYN = torch.export.Dim.DYNAMIC |
| 58 | + ds = [ |
| 59 | + [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]], |
| 60 | + [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]], |
| 61 | + ] |
| 62 | + |
| 63 | + with bypass_export_some_errors(patch_transformers=True): |
| 64 | + torch.export.export(model, (cache,), dynamic_shapes=(ds,)) |
| 65 | + |
| 66 | + @ignore_warnings(UserWarning) |
| 67 | + def test_dynamic_cache_flatten(self): |
| 68 | + cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) |
| 69 | + with bypass_export_some_errors(): |
| 70 | + flat, _spec = torch.utils._pytree.tree_flatten(cache) |
| 71 | + self.assertEqual( |
| 72 | + "#2[T1s4x4x4,T1s4x4x4]", |
| 73 | + self.string_type(flat, with_shape=True), |
| 74 | + ) |
| 75 | + cache2 = torch.utils._pytree.tree_unflatten(flat, _spec) |
| 76 | + self.assertEqual( |
| 77 | + self.string_type(cache, with_shape=True, with_min_max=True), |
| 78 | + self.string_type(cache2, with_shape=True, with_min_max=True), |
| 79 | + ) |
| 80 | + |
| 81 | + @ignore_warnings(UserWarning) |
| 82 | + def test_dynamic_cache_export(self): |
| 83 | + class Model(torch.nn.Module): |
| 84 | + def forward(self, cache): |
| 85 | + return cache.key_cache[0] |
| 86 | + |
| 87 | + cache = make_dynamic_cache( |
| 88 | + [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)] |
| 89 | + ) |
| 90 | + model = Model() |
| 91 | + model(cache) |
| 92 | + DYN = torch.export.Dim.DYNAMIC |
| 93 | + ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]] |
| 94 | + |
| 95 | + with bypass_export_some_errors(): |
| 96 | + torch.export.export(model, (cache,), dynamic_shapes=(ds,)) |
| 97 | + |
| 98 | + @ignore_warnings(UserWarning) |
| 99 | + def test_dynamic_cache_deepcopy(self): |
| 100 | + cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) |
| 101 | + with bypass_export_some_errors(): |
| 102 | + cache2 = torch_deepcopy([cache]) |
| 103 | + self.assertEqualAny([cache], cache2) |
| 104 | + |
| 105 | + @ignore_warnings(UserWarning) |
| 106 | + def test_base_model_output_deepcopy(self): |
| 107 | + bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) |
| 108 | + self.assertEqual(bo.__class__.__name__, "BaseModelOutput") |
| 109 | + with bypass_export_some_errors(): |
| 110 | + bo2 = torch_deepcopy([bo]) |
| 111 | + self.assertIsInstance(bo2, list) |
| 112 | + self.assertEqual(bo2[0].__class__.__name__, "BaseModelOutput") |
| 113 | + self.assertEqualAny([bo], bo2) |
| 114 | + |
| 115 | + @ignore_warnings(UserWarning) |
| 116 | + def test_base_model_output_string_type(self): |
| 117 | + bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) |
| 118 | + with bypass_export_some_errors(): |
| 119 | + self.assertEqual( |
| 120 | + "BaseModelOutput(last_hidden_state:T1s4x4x4)", |
| 121 | + self.string_type(bo, with_shape=True), |
| 122 | + ) |
| 123 | + |
| 124 | + @ignore_warnings(UserWarning) |
| 125 | + def test_base_model_output_flatten(self): |
| 126 | + bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) |
| 127 | + with bypass_export_some_errors(): |
| 128 | + flat, _spec = torch.utils._pytree.tree_flatten(bo) |
| 129 | + self.assertEqual( |
| 130 | + "#1[T1s4x4x4]", |
| 131 | + self.string_type(flat, with_shape=True), |
| 132 | + ) |
| 133 | + bo2 = torch.utils._pytree.tree_unflatten(flat, _spec) |
| 134 | + self.assertEqual( |
| 135 | + self.string_type(bo, with_shape=True, with_min_max=True), |
| 136 | + self.string_type(bo2, with_shape=True, with_min_max=True), |
| 137 | + ) |
| 138 | + |
| 139 | + @ignore_warnings(UserWarning) |
| 140 | + def test_base_model_output_export(self): |
| 141 | + class Model(torch.nn.Module): |
| 142 | + def forward(self, cache): |
| 143 | + return cache.last_hidden_state[0] |
| 144 | + |
| 145 | + bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) |
| 146 | + model = Model() |
| 147 | + model(bo) |
| 148 | + DYN = torch.export.Dim.DYNAMIC |
| 149 | + ds = [{0: DYN}] |
| 150 | + |
| 151 | + with bypass_export_some_errors(): |
| 152 | + torch.export.export(model, (bo,), dynamic_shapes=(ds,)) |
| 153 | + |
| 154 | + |
| 155 | +if __name__ == "__main__": |
| 156 | + unittest.main(verbosity=2) |
0 commit comments