Skip to content

Commit a4a5c0e

Browse files
committed
refactor algorithm to validate shapes
1 parent abbcc6b commit a4a5c0e

File tree

2 files changed

+91
-72
lines changed

2 files changed

+91
-72
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -467,13 +467,13 @@ def test_couple_input_ds_0(self):
467467
T1x4 = torch.rand((1, 4))
468468
T1x1 = torch.rand((1, 1))
469469
Cls = CoupleInputsDynamicShapes
470-
self.assertEqual([(0, "[0]")], Cls((T1x4,), {}, ({0: "batch"},)).invalid_paths())
471-
self.assertEqual([(0, "[0]")], Cls((T1x1,), {}, ({0: "batch"},)).invalid_paths())
470+
self.assertEqual(({0: "d=[1]"},), Cls((T1x4,), {}, ({0: "batch"},)).invalid_paths())
471+
self.assertEqual(({0: "d=[1]"},), Cls((T1x1,), {}, ({0: "batch"},)).invalid_paths())
472472
self.assertEqual(
473-
[("A", "[0]")], Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_paths()
473+
{"A": {0: "d=[1]"}}, Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_paths()
474474
)
475475
self.assertEqual(
476-
[("A", "[0]")], Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_paths()
476+
{"A": {0: "d=[1]"}}, Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_paths()
477477
)
478478

479479
def test_couple_input_ds_1(self):
@@ -483,8 +483,10 @@ def test_couple_input_ds_1(self):
483483
ds_batch_seq = {0: "batch", 1: "seq"}
484484
args = (T3x4, T3x1)
485485
Cls = CoupleInputsDynamicShapes
486-
self.assertEqual([], Cls(args, {}, (ds_batch, ds_batch)).invalid_paths())
487-
self.assertEqual([(1, "[1]")], Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_paths())
486+
self.assertEqual(None, Cls(args, {}, (ds_batch, ds_batch)).invalid_paths())
487+
self.assertEqual(
488+
(None, {1: "d=[1]"}), Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_paths()
489+
)
488490

489491
def test_couple_input_ds_2(self):
490492
T3x1 = torch.rand((3, 1))
@@ -493,9 +495,10 @@ def test_couple_input_ds_2(self):
493495
ds_batch_seq = {0: "batch", 1: "seq"}
494496
kwargs = {"A": T3x4, "B": T3x1}
495497
Cls = CoupleInputsDynamicShapes
496-
self.assertEqual([], Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_paths())
498+
self.assertEqual(None, Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_paths())
497499
self.assertEqual(
498-
[("B", "[1]")], Cls((), kwargs, {"A": ds_batch, "B": ds_batch_seq}).invalid_paths()
500+
{"B": {1: "d=[1]"}},
501+
Cls((), kwargs, {"A": ds_batch, "B": ds_batch_seq}).invalid_paths(),
499502
)
500503

501504
def test_couple_input_ds_3(self):
@@ -506,10 +509,10 @@ def test_couple_input_ds_3(self):
506509
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
507510
Cls = CoupleInputsDynamicShapes
508511
self.assertEqual(
509-
[], Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
512+
None, Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
510513
)
511514
self.assertEqual(
512-
[("B", 1, "[1]")],
515+
{"B": (None, {1: "d=[1]"})},
513516
Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(),
514517
)
515518

@@ -532,15 +535,15 @@ def test_couple_input_ds_cache(self):
532535
Cls = CoupleInputsDynamicShapes
533536
with bypass_export_some_errors(patch_transformers=True):
534537
self.assertEqual(
535-
[],
538+
None,
536539
Cls(
537540
(),
538541
kwargs,
539542
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
540543
).invalid_paths(),
541544
)
542545
self.assertEqual(
543-
[("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")],
546+
{"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])},
544547
Cls(
545548
(),
546549
kwargs,
@@ -561,16 +564,16 @@ def test_couple_input_ds_args_kwargs_0(self):
561564
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
562565
Cls = CoupleInputsDynamicShapes
563566
self.assertEqual(
564-
[], Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
567+
None, Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
565568
)
566569
self.assertEqual(
567-
[],
570+
None,
568571
Cls(
569572
args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"]
570573
).invalid_paths(),
571574
)
572575
self.assertEqual(
573-
[("B", 1, "[1]")],
576+
{"B": (None, {1: "d=[1]"})},
574577
Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(),
575578
)
576579

@@ -584,7 +587,7 @@ def test_couple_input_ds_args_kwargs_1(self):
584587
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
585588
Cls = CoupleInputsDynamicShapes
586589
self.assertEqual(
587-
[],
590+
None,
588591
Cls(
589592
args,
590593
kwargs,
@@ -593,7 +596,7 @@ def test_couple_input_ds_args_kwargs_1(self):
593596
).invalid_paths(),
594597
)
595598
self.assertEqual(
596-
[("X", "[1]"), ("B", 1, "[1]")],
599+
{"X": {1: "d=[1]"}, "B": (None, {1: "d=[1]"})},
597600
Cls(
598601
args,
599602
kwargs,

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 71 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def __str__(self) -> str:
488488
]
489489
)
490490

491-
def invalid_paths(self) -> List[Union[str, int]]:
491+
def invalid_paths(self) -> Any:
492492
"""
493493
Tells the inputs are valid based on the dynamic shapes definition.
494494
The method assumes that all custom classes can be serialized.
@@ -498,18 +498,42 @@ def invalid_paths(self) -> List[Union[str, int]]:
498498
The function checks that a dynamic dimension does not receive a value
499499
of 0 or 1. It returns a list of invalid path.
500500
"""
501+
return self._generic_walker(self._valid_shapes_tensor)
502+
503+
@classmethod
504+
def _valid_shapes_tensor(cls, inputs: Any, ds: Any) -> Iterable:
505+
assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
506+
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
507+
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
508+
f"a dictionary is expected to specify a dimension dimension"
509+
)
510+
issues = {}
511+
for i, d in enumerate(inputs.shape):
512+
if i in ds and not isinstance(ds[i], int):
513+
# dynamic then
514+
if d in {0, 1}:
515+
# export issues for sure
516+
issues[i] = f"d=[{d}]"
517+
return issues if issues else None
518+
519+
def _generic_walker(self, method_to_call: Callable) -> Any:
520+
"""
521+
Generic deserializator walking through inputs and dynamic_shapes all along.
522+
The function returns a result with the same structure as the dynamic shapes.
523+
"""
501524
if not self.args:
502525
assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
503526
f"Type mismatch, args={string_type(self.args)} and "
504527
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
505528
)
506-
return list(self._valid_shapes(self.kwargs, self.dynamic_shapes))
529+
return self._generic_walker_step(method_to_call, self.kwargs, self.dynamic_shapes)
530+
507531
if not self.kwargs:
508532
assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
509533
f"Type mismatch, args={string_type(self.args)} and "
510534
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
511535
)
512-
return list(self._valid_shapes(self.args, self.dynamic_shapes))
536+
return self._generic_walker_step(method_to_call, self.args, self.dynamic_shapes)
513537

514538
assert isinstance(self.dynamic_shapes, dict), (
515539
f"Both positional and named arguments (args and kwargs) are filled. "
@@ -519,12 +543,14 @@ def invalid_paths(self) -> List[Union[str, int]]:
519543
self.dynamic_shapes
520544
):
521545
# No dynamic shapes for the positional arguments.
522-
return list(self._valid_shapes(self.kwargs, self.dynamic_shapes))
546+
return self._generic_walker_step(method_to_call, self.kwargs, self.dynamic_shapes)
523547

524548
if isinstance(self.args_names, list):
525549
if not set(self.args_names) & set(self.dynamic_shapes):
526550
# No dynamic shapes for the positional arguments.
527-
return list(self._valid_shapes(self.kwargs, self.dynamic_shapes))
551+
return self._generic_walker_step(
552+
method_to_call, self.kwargs, self.dynamic_shapes
553+
)
528554

529555
assert self.args_names, (
530556
"args and kwargs are filled, then args_names must be specified in "
@@ -537,62 +563,52 @@ def invalid_paths(self) -> List[Union[str, int]]:
537563
)
538564
kwargs = dict(zip(self.args_names, self.args))
539565
kwargs.update(self.kwargs)
540-
return list(self._valid_shapes(kwargs, self.dynamic_shapes))
566+
return self._generic_walker_step(method_to_call, kwargs, self.dynamic_shapes)
541567

542568
raise NotImplementedError(
543569
f"Not yet implemented when args is filled, "
544570
f"kwargs as well but args_names is {type(self.args_names)}"
545571
)
546572

547573
@classmethod
548-
def _valid_shapes(
549-
cls, inputs: Any, ds: Any, prefix: Tuple[Union[int, str], ...] = ()
550-
) -> Iterable:
551-
assert all(isinstance(i, (int, str)) for i in prefix), f"Unexpected prefix {prefix}"
574+
def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) -> Iterable:
552575
if isinstance(inputs, torch.Tensor):
553-
assert isinstance(ds, dict) and all(
554-
isinstance(s, int) for s in ds
555-
), f"Unexpected types, inputs is a Tensor but ds={ds}, prefix={prefix}"
556-
for i, d in enumerate(inputs.shape):
557-
if i in ds and not isinstance(ds[i], int):
558-
# dynamic then
559-
if d in {0, 1}:
560-
# export issues for sure
561-
yield (*prefix, f"[{i}]")
562-
else:
563-
if isinstance(inputs, (int, float, str)):
564-
pass
565-
elif isinstance(inputs, (tuple, list, dict)):
566-
assert type(ds) is type(inputs), (
567-
f"Type mismatch between inputs {type(inputs)} "
568-
f"and ds={type(ds)}, prefix={prefix!r}"
569-
)
570-
assert len(ds) == len(inputs), (
571-
f"Length mismatch between inputs {len(inputs)} "
572-
f"and ds={len(ds)}, prefix={prefix!r}\n"
573-
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
574-
)
575-
if isinstance(inputs, (tuple, list)):
576-
for ind, (i, d) in enumerate(zip(inputs, ds)):
577-
for path in cls._valid_shapes(i, d, prefix=(*prefix, ind)):
578-
yield path
579-
else:
580-
assert set(inputs) == set(ds), (
581-
f"Keys mismatch between inputs {set(inputs)} "
582-
f"and ds={set(ds)}, prefix={prefix!r}"
583-
)
584-
for k, v in inputs.items():
585-
for path in cls._valid_shapes(v, ds[k], prefix=(*prefix, k)):
586-
yield path
587-
else:
588-
# A custom class.
589-
assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
590-
f"Class {inputs.__class__.__name__!r} was not registered using "
591-
f"torch.utils._pytree.register_pytree_node, it is not possible to "
592-
f"map this class with the given dynamic shapes."
576+
return method_to_call(inputs, ds)
577+
if isinstance(inputs, (int, float, str)):
578+
return None
579+
if isinstance(inputs, (tuple, list, dict)):
580+
assert type(ds) is type(
581+
inputs
582+
), f"Type mismatch between inputs {type(inputs)} and ds={type(ds)}"
583+
assert len(ds) == len(inputs), (
584+
f"Length mismatch between inputs {len(inputs)} "
585+
f"and ds={len(ds)}\n"
586+
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
587+
)
588+
if isinstance(inputs, (tuple, list)):
589+
value = []
590+
for i, d in zip(inputs, ds):
591+
value.append(cls._generic_walker_step(method_to_call, i, d))
592+
return (
593+
(value if isinstance(ds, list) else tuple(value))
594+
if any(v is not None for v in value)
595+
else None
593596
)
594-
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
595-
for path in cls._valid_shapes(
596-
flat, ds, prefix=(*prefix, inputs.__class__.__name__)
597-
):
598-
yield path
597+
assert set(inputs) == set(
598+
ds
599+
), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}"
600+
dvalue = {}
601+
for k, v in inputs.items():
602+
t = cls._generic_walker_step(method_to_call, v, ds[k])
603+
if t is not None:
604+
dvalue[k] = t
605+
return dvalue if dvalue else None
606+
607+
# A custom class.
608+
assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
609+
f"Class {inputs.__class__.__name__!r} was not registered using "
610+
f"torch.utils._pytree.register_pytree_node, it is not possible to "
611+
f"map this class with the given dynamic shapes."
612+
)
613+
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
614+
return cls._generic_walker_step(method_to_call, flat, ds)

0 commit comments

Comments
 (0)