Skip to content

Commit defb403

Browse files
committed
support more cases
1 parent c3847d9 commit defb403

File tree

2 files changed

+98
-2
lines changed

2 files changed

+98
-2
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,57 @@ def test_couple_input_ds_cache(self):
551551
).invalid_paths(),
552552
)
553553

554+
def test_couple_input_ds_args_kwargs_0(self):
555+
T3x1 = torch.rand((3, 1))
556+
T3x4 = torch.rand((3, 4))
557+
T5x6 = torch.rand((5, 6))
558+
ds_batch = {0: "batch"}
559+
ds_batch_seq = {0: "batch", 1: "seq"}
560+
args = (T5x6,)
561+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
562+
Cls = CoupleInputsDynamicShapes
563+
self.assertEqual(
564+
[], Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
565+
)
566+
self.assertEqual(
567+
[],
568+
Cls(
569+
args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"]
570+
).invalid_paths(),
571+
)
572+
self.assertEqual(
573+
[("B", 1, "[1]")],
574+
Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(),
575+
)
576+
577+
def test_couple_input_ds_args_kwargs_1(self):
578+
T3x1 = torch.rand((3, 1))
579+
T3x4 = torch.rand((3, 4))
580+
T5x1 = torch.rand((5, 1))
581+
ds_batch = {0: "batch"}
582+
ds_batch_seq = {0: "batch", 1: "seq"}
583+
args = (T5x1,)
584+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
585+
Cls = CoupleInputsDynamicShapes
586+
self.assertEqual(
587+
[],
588+
Cls(
589+
args,
590+
kwargs,
591+
{"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)},
592+
args_names=["X"],
593+
).invalid_paths(),
594+
)
595+
self.assertEqual(
596+
[("X", "[1]"), ("B", 1, "[1]")],
597+
Cls(
598+
args,
599+
kwargs,
600+
{"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)},
601+
args_names=["X"],
602+
).invalid_paths(),
603+
)
604+
554605

555606
if __name__ == "__main__":
556607
unittest.main(verbosity=2)

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,14 +455,27 @@ def validate_inputs_for_export(
455455
class CoupleInputsDynamicShapes:
456456
"""
457457
Pair inputs / dynamic shapes.
458+
459+
:param args: positional arguments
460+
:param kwargs: named arguments
461+
:param dynamic_shapes: dynamic shapes
462+
:param args_names: if both args and kwargs are not empty, then
463+
dynamic shapes must be a dictionary, and positional must be added
464+
to the named arguments. Arguments names or a module must be given
465+
in that case.
458466
"""
459467

460468
def __init__(
461-
self, args: Tuple[Any, ...], kwargs: Dict[str, Any], dynamic_shapes: DYNAMIC_SHAPES
469+
self,
470+
args: Tuple[Any, ...],
471+
kwargs: Dict[str, Any],
472+
dynamic_shapes: DYNAMIC_SHAPES,
473+
args_names: Optional[Union[torch.nn.Module, List[str]]] = None,
462474
):
463475
self.args = args
464476
self.kwargs = kwargs
465477
self.dynamic_shapes = dynamic_shapes
478+
self.args_names = args_names
466479

467480
def __str__(self) -> str:
468481
return "\n".join(
@@ -497,7 +510,39 @@ def invalid_paths(self) -> List[Union[str, int]]:
497510
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
498511
)
499512
return list(self._valid_shapes(self.args, self.dynamic_shapes))
500-
raise NotImplementedError("args and kwargs are filled, it is not implemented yet.")
513+
514+
assert isinstance(self.dynamic_shapes, dict), (
515+
f"Both positional and named arguments (args and kwargs) are filled. "
516+
f"dynamic shapes must a dictionary not {type(self.dynamic_shapes)}"
517+
)
518+
if not self.args_names and set(self.dynamic_shapes) & set(self.kwargs) == set(
519+
self.dynamic_shapes
520+
):
521+
# No dynamic shapes for the positional arguments.
522+
return list(self._valid_shapes(self.kwargs, self.dynamic_shapes))
523+
524+
if isinstance(self.args_names, list):
525+
if not set(self.args_names) & set(self.dynamic_shapes):
526+
# No dynamic shapes for the positional arguments.
527+
return list(self._valid_shapes(self.kwargs, self.dynamic_shapes))
528+
529+
assert self.args_names, (
530+
"args and kwargs are filled, then args_names must be specified in "
531+
"the constructor to move positional arguments to named arguments."
532+
)
533+
assert len(self.args) <= len(self.args_names), (
534+
f"There are {len(self.args)} positional arguments "
535+
f"but only {len(self.args_names)} names. "
536+
f"args={string_type(self.args, with_shape=True)}, args_name={self.args_names}"
537+
)
538+
kwargs = dict(zip(self.args_names, self.args))
539+
kwargs.update(self.kwargs)
540+
return list(self._valid_shapes(kwargs, self.dynamic_shapes))
541+
542+
raise NotImplementedError(
543+
f"Not yet implemented when args is filled, "
544+
f"kwargs as well but args_names is {type(self.args_names)}"
545+
)
501546

502547
@classmethod
503548
def _valid_shapes(

0 commit comments

Comments
 (0)