|
5 | 5 | from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache |
6 | 6 | from onnx_diagnostic.export import ModelInputs |
7 | 7 | from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes |
| 8 | +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors |
8 | 9 |
|
9 | 10 |
|
10 | 11 | class TestDynamicShapes(ExtTestCase): |
@@ -529,25 +530,26 @@ def test_couple_input_ds_cache(self): |
529 | 530 |
|
530 | 531 | kwargs = {"A": T3x4, "B": (T3x1, cache)} |
531 | 532 | Cls = CoupleInputsDynamicShapes |
532 | | - self.assertEqual( |
533 | | - [], |
534 | | - Cls( |
535 | | - (), |
536 | | - kwargs, |
537 | | - {"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])}, |
538 | | - ).invalid_paths(), |
539 | | - ) |
540 | | - self.assertEqual( |
541 | | - [("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")], |
542 | | - Cls( |
543 | | - (), |
544 | | - kwargs, |
545 | | - { |
546 | | - "A": ds_batch, |
547 | | - "B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]), |
548 | | - }, |
549 | | - ).invalid_paths(), |
550 | | - ) |
| 533 | + with bypass_export_some_errors(patch_transformers=True): |
| 534 | + self.assertEqual( |
| 535 | + [], |
| 536 | + Cls( |
| 537 | + (), |
| 538 | + kwargs, |
| 539 | + {"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])}, |
| 540 | + ).invalid_paths(), |
| 541 | + ) |
| 542 | + self.assertEqual( |
| 543 | + [("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")], |
| 544 | + Cls( |
| 545 | + (), |
| 546 | + kwargs, |
| 547 | + { |
| 548 | + "A": ds_batch, |
| 549 | + "B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]), |
| 550 | + }, |
| 551 | + ).invalid_paths(), |
| 552 | + ) |
551 | 553 |
|
552 | 554 |
|
553 | 555 | if __name__ == "__main__": |
|
0 commit comments