|
5 | 5 | from onnx_diagnostic.helpers.cache_helper import ( |
6 | 6 | make_encoder_decoder_cache, |
7 | 7 | make_dynamic_cache, |
| 8 | + make_sliding_window_cache, |
8 | 9 | flatten_unflatten_for_dynamic_shapes, |
9 | 10 | ) |
10 | 11 | from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( |
@@ -164,6 +165,52 @@ def test_base_model_output_unflatten_flatten(self): |
164 | 165 | self.assertIsInstance(unflat, dict) |
165 | 166 | self.assertEqual(list(unflat), ["last_hidden_state"]) |
166 | 167 |
|
@ignore_warnings(UserWarning)
def test_base_sliding_window_cache_unflatten_flatten(self):
    """A SlidingWindowCache survives torch_deepcopy under the patched export context."""
    # One layer of random (4, 4, 4, 4) key/value tensors.
    key = torch.rand((4, 4, 4, 4))
    value = torch.rand((4, 4, 4, 4))
    original = make_sliding_window_cache([(key, value)])
    with bypass_export_some_errors():
        copied = torch_deepcopy([original])
        self.assertEqualAny([original], copied)
| 176 | + |
@ignore_warnings(UserWarning)
def test_sliding_window_cache_export(self):
    """torch.export.export succeeds on a model taking a SlidingWindowCache input."""

    class FirstKey(torch.nn.Module):
        # Minimal module: just read the first layer's key tensor from the cache.
        def forward(self, cache):
            return cache.key_cache[0]

    # Two layers of random (4, 4, 4, 4) key/value tensors.
    layers = [
        (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))) for _ in range(2)
    ]
    cache = make_sliding_window_cache(layers)
    model = FirstKey()
    model(cache)  # sanity check: the model runs eagerly before exporting
    dyn = torch.export.Dim.DYNAMIC
    # Dynamic batch dimension for every key/value tensor of both layers.
    spec = [[{0: dyn}, {0: dyn}] for _ in range(2)]

    with bypass_export_some_errors(patch_transformers=True):
        torch.export.export(model, (cache,), dynamic_shapes=(spec,))
| 196 | + |
@ignore_warnings(UserWarning)
def test_sliding_window_cache_flatten(self):
    """pytree flatten/unflatten round-trips a SlidingWindowCache."""
    # One layer of random (4, 4, 4, 4) key/value tensors.
    cache = make_sliding_window_cache(
        [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
    )
    with bypass_export_some_errors():
        leaves, spec = torch.utils._pytree.tree_flatten(cache)
        # Flattening yields exactly the two tensors (key, value) of the layer.
        self.assertEqual(
            "#2[T1s4x4x4x4,T1s4x4x4x4]",
            self.string_type(leaves, with_shape=True),
        )
        restored = torch.utils._pytree.tree_unflatten(leaves, spec)
        # The rebuilt cache must describe the same shapes and value ranges.
        self.assertEqual(
            self.string_type(cache, with_shape=True, with_min_max=True),
            self.string_type(restored, with_shape=True, with_min_max=True),
        )
| 213 | + |
167 | 214 |
|
if __name__ == "__main__":
    # Run this module's test cases with verbose per-test output.
    unittest.main(verbosity=2)
0 commit comments