|
2 | 2 | import ml_dtypes |
3 | 3 | import onnx |
4 | 4 | import torch |
| 5 | +import transformers |
5 | 6 | from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout |
6 | | -from onnx_diagnostic.helpers import string_type |
| 7 | +from onnx_diagnostic.helpers import max_diff, string_type |
7 | 8 | from onnx_diagnostic.helpers.torch_test_helper import ( |
8 | 9 | dummy_llm, |
9 | 10 | to_numpy, |
|
13 | 14 | to_any, |
14 | 15 | torch_deepcopy, |
15 | 16 | ) |
16 | | -from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache |
| 17 | +from onnx_diagnostic.helpers.cache_helper import ( |
| 18 | + make_dynamic_cache, |
| 19 | + make_encoder_decoder_cache, |
| 20 | + make_mamba_cache, |
| 21 | + make_sliding_window_cache, |
| 22 | +) |
17 | 23 |
|
18 | 24 | TFLOAT = onnx.TensorProto.FLOAT |
19 | 25 |
|
@@ -85,19 +91,66 @@ def test_to_any(self): |
85 | 91 | at = to_any(a, torch.float16) |
86 | 92 | self.assertIn("T10r", string_type(at)) |
87 | 93 |
|
88 | | - def test_torch_deepcopy(self): |
def test_torch_deepcopy_cache_dce(self):
    """Check ``torch_deepcopy`` on dynamic and encoder-decoder caches.

    Two things are verified:

    * copying the cache returned by ``make_encoder_decoder_cache`` yields
      an object of the same type with a zero absolute difference
      according to ``max_diff``;
    * copying a nested container holding both caches gives a copy whose
      ``string_type`` description is unchanged after the first key tensor
      of the original dynamic cache is modified in place.
    """
    c1 = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
    c2 = make_encoder_decoder_cache(
        make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
        make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
    )
    cc = torch_deepcopy(c2)
    # The copy (cc) must have the same type as its source (c2); comparing
    # type(c2) with itself would be vacuous and could never fail.
    self.assertEqual(type(c2), type(cc))
    self.assertEqual(max_diff(c2, cc)["abs"], 0)
    a = {"t": [(torch.tensor([1, 2]), c1, c2), {4, 5}]}
    at = torch_deepcopy(a)
    hash1 = string_type(at, with_shape=True, with_min_max=True)
    # Mutate the original in place: the description of the copy, which
    # includes min/max values, must not move.
    c1.key_cache[0] += 1000
    hash2 = string_type(at, with_shape=True, with_min_max=True)
    self.assertEqual(hash1, hash2)
100 | 109 |
|
def test_torch_deepcopy_mamba_cache(self):
    """Check ``torch_deepcopy`` on the object built by ``make_mamba_cache``.

    The copy must have the same type as the source, a zero absolute
    difference according to ``max_diff``, and a ``string_type``
    description that does not change when the source is modified in place.
    """
    # Three pairs of random (4, 4, 4) tensors, generated pair by pair.
    pairs = [(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))) for _ in range(3)]
    source = make_mamba_cache(pairs)
    duplicate = torch_deepcopy(source)

    def describe():
        # Description of the copy including shapes and min/max values.
        return string_type(duplicate, with_shape=True, with_min_max=True)

    self.assertEqual(type(source), type(duplicate))
    self.assertEqual(max_diff(source, duplicate)["abs"], 0)
    before = describe()
    # In-place change of the source only; the copy must be unaffected.
    source.conv_states[0] += 1000
    after = describe()
    self.assertEqual(before, after)
| 125 | + |
def test_torch_deepcopy_base_model_outputs(self):
    """Check ``torch_deepcopy`` on a ``transformers`` ``BaseModelOutput``.

    The copy must show a zero absolute difference according to
    ``max_diff``, have the same type as the source, and keep the same
    ``string_type`` description after the source tensor is modified
    in place.
    """
    hidden = torch.rand((4, 4, 4))
    source = transformers.modeling_outputs.BaseModelOutput(last_hidden_state=hidden)
    duplicate = torch_deepcopy(source)

    def describe():
        # Description of the copy including shapes and min/max values.
        return string_type(duplicate, with_shape=True, with_min_max=True)

    self.assertEqual(max_diff(source, duplicate)["abs"], 0)
    self.assertEqual(type(source), type(duplicate))
    before = describe()
    # In-place change of the source only; the copy must be unaffected.
    source.last_hidden_state[0] += 1000
    after = describe()
    self.assertEqual(before, after)
| 137 | + |
def test_torch_deepcopy_sliding_windon_cache(self):
    """Check ``torch_deepcopy`` on the object built by ``make_sliding_window_cache``.

    The copy must have the same type as the source, a zero absolute
    difference according to ``max_diff``, and a ``string_type``
    description that does not change when the source is modified in place.
    """
    # NOTE(review): "windon" in the method name looks like a typo for
    # "window"; the name is kept so the test id does not change.
    # Three pairs of random (4, 5, 6, 7) tensors, generated pair by pair.
    pairs = [(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))) for _ in range(3)]
    source = make_sliding_window_cache(pairs)
    duplicate = torch_deepcopy(source)

    def describe():
        # Description of the copy including shapes and min/max values.
        return string_type(duplicate, with_shape=True, with_min_max=True)

    self.assertEqual(type(source), type(duplicate))
    self.assertEqual(max_diff(source, duplicate)["abs"], 0)
    before = describe()
    # In-place change of the source only; the copy must be unaffected.
    source.key_cache[0] += 1000
    after = describe()
    self.assertEqual(before, after)
| 153 | + |
101 | 154 |
|
if __name__ == "__main__":
    # Executed as a script: run the tests of this module with verbose output.
    unittest.main(verbosity=2)
0 commit comments