Skip to content

Commit c3847d9

Browse files
committed
fix issues
1 parent d8eb27a commit c3847d9

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
66
from onnx_diagnostic.export import ModelInputs
77
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
8+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
89

910

1011
class TestDynamicShapes(ExtTestCase):
@@ -529,25 +530,26 @@ def test_couple_input_ds_cache(self):
529530

530531
kwargs = {"A": T3x4, "B": (T3x1, cache)}
531532
Cls = CoupleInputsDynamicShapes
532-
self.assertEqual(
533-
[],
534-
Cls(
535-
(),
536-
kwargs,
537-
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
538-
).invalid_paths(),
539-
)
540-
self.assertEqual(
541-
[("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")],
542-
Cls(
543-
(),
544-
kwargs,
545-
{
546-
"A": ds_batch,
547-
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
548-
},
549-
).invalid_paths(),
550-
)
533+
with bypass_export_some_errors(patch_transformers=True):
534+
self.assertEqual(
535+
[],
536+
Cls(
537+
(),
538+
kwargs,
539+
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
540+
).invalid_paths(),
541+
)
542+
self.assertEqual(
543+
[("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")],
544+
Cls(
545+
(),
546+
kwargs,
547+
{
548+
"A": ds_batch,
549+
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
550+
},
551+
).invalid_paths(),
552+
)
551553

552554

553555
if __name__ == "__main__":

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,11 @@ def _valid_shapes(
541541
yield path
542542
else:
543543
# A custom class.
544+
assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
545+
f"Class {inputs.__class__.__name__!r} was not registered using "
546+
f"torch.utils._pytree.register_pytree_node, it is not possible to "
547+
f"map this class with the given dynamic shapes."
548+
)
544549
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
545550
for path in cls._valid_shapes(
546551
flat, ds, prefix=(*prefix, inputs.__class__.__name__)

0 commit comments

Comments
 (0)