 from onnx_diagnostic.helpers.cache_helper import (
     make_encoder_decoder_cache,
     make_dynamic_cache,
+    make_static_cache,
     make_sliding_window_cache,
     flatten_unflatten_for_dynamic_shapes,
 )
@@ -180,7 +181,7 @@ def test_base_sliding_window_cache_unflatten_flatten(self): |
         self.assertEqualAny([cache], cache2)
 
     @ignore_warnings(UserWarning)
-    @requires_torch("2.8")
+    @requires_torch("2.7.99")
     def test_sliding_window_cache_export(self):
         class Model(torch.nn.Module):
             def forward(self, cache):
@@ -274,6 +275,69 @@ def forward(self, cache): |
         with torch_export_patches():
             torch.export.export(model, (bo,), dynamic_shapes=(ds,))
 
+    @ignore_warnings(UserWarning)
+    @requires_torch("2.7.99")
+    def test_static_cache(self):
+        bo = make_static_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ],
+            max_cache_len=15,
+        )
+        self.assertEqual(bo.__class__.__name__, "StaticCache")
+        bo2 = torch_deepcopy([bo])
+        self.assertIsInstance(bo2, list)
+        self.assertEqual(
+            "StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
+            "value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
+            self.string_type(bo, with_shape=True),
+        )
+
+        with torch_export_patches():
+            # internal function
+            bo2 = torch_deepcopy([bo])
+            self.assertIsInstance(bo2, list)
+            self.assertEqual(bo2[0].__class__.__name__, "StaticCache")
+            self.assertEqualAny([bo], bo2)
+            self.assertEqual(
+                "StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
+                "value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
+                self.string_type(bo, with_shape=True),
+            )
+
+            # serialization
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            self.assertEqual(
+                "#6[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7]",
+                self.string_type(flat, with_shape=True),
+            )
+            bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(bo, with_shape=True, with_min_max=True),
+                self.string_type(bo2, with_shape=True, with_min_max=True),
+            )
+
+            # flatten_unflatten
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
+            self.assertIsInstance(unflat, dict)
+            self.assertEqual(list(unflat), ["key_cache", "value_cache"])
+
+        # export
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                return cache.key_cache[0]
+
+        model = Model()
+        model(bo)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]
+
+        with torch_export_patches(patch_transformers=True, stop_if_static=1):
+            torch.export.export(model, (bo,), dynamic_shapes=(ds,))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
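
For readers skimming the diff, here is a minimal standalone sketch of what the new test exercises, limited to the helpers whose imports the diff shows (`make_static_cache`, `flatten_unflatten_for_dynamic_shapes`) plus `torch_export_patches`, whose import path is not visible in this hunk and is assumed from the package layout. The shape comments restate the test's own assertions rather than anything new: the cached tensors are padded on dim 2 up to `max_cache_len`, and the dict view exposes `key_cache` / `value_cache` lists.

```python
import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_static_cache,
    flatten_unflatten_for_dynamic_shapes,
)
# Import path not shown in this diff; assumed from the package layout.
from onnx_diagnostic.torch_export_patches import torch_export_patches

# Three layers of (key, value) tensors with shape (batch, heads, seq_len, head_dim).
layers = [(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))) for _ in range(3)]

# Per the test's assertions, the helper returns a transformers StaticCache whose
# cached tensors are padded on dim 2 (the sequence axis) to max_cache_len.
cache = make_static_cache(layers, max_cache_len=15)
print(type(cache).__name__)      # StaticCache
print(cache.key_cache[0].shape)  # torch.Size([4, 5, 15, 7])

with torch_export_patches():
    # Dict view of the cache, useful for spelling out dynamic_shapes
    # for torch.export.export, mirroring the test above.
    unflat = flatten_unflatten_for_dynamic_shapes(cache, use_dict=True)
    print(list(unflat))          # ['key_cache', 'value_cache']
```

The nested `ds` structure in the test follows the same layout: one list of per-layer `{0: DYN}` specs for `key_cache` and one for `value_cache`.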