Skip to content

Commit 5cd902b

Browse files
committed
valid inputs
1 parent 45183ba commit 5cd902b

File tree

3 files changed

+222
-8
lines changed

3 files changed

+222
-8
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from onnx_diagnostic.helpers import string_type
55
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
66
from onnx_diagnostic.export import ModelInputs
7+
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
78

89

910
class TestDynamicShapes(ExtTestCase):
@@ -453,6 +454,101 @@ def forward(self, cache, z):
453454
),
454455
)
455456

457+
def test_couple_input_ds_0(self):
458+
T3x4 = torch.rand((3, 4))
459+
T3x1 = torch.rand((3, 1))
460+
Cls = CoupleInputsDynamicShapes
461+
self.assertEmpty(Cls((T3x4,), {}, ({0: "batch"},)).invalid_paths())
462+
self.assertEmpty(Cls((T3x1,), {}, ({0: "batch"},)).invalid_paths())
463+
self.assertEmpty(Cls((), {"A": T3x1}, {"A": {0: "batch"}}).invalid_paths())
464+
self.assertEmpty(Cls((), {"A": T3x4}, {"A": {0: "batch"}}).invalid_paths())
465+
466+
T1x4 = torch.rand((1, 4))
467+
T1x1 = torch.rand((1, 1))
468+
Cls = CoupleInputsDynamicShapes
469+
self.assertEqual([(0, "[0]")], Cls((T1x4,), {}, ({0: "batch"},)).invalid_paths())
470+
self.assertEqual([(0, "[0]")], Cls((T1x1,), {}, ({0: "batch"},)).invalid_paths())
471+
self.assertEqual(
472+
[("A", "[0]")], Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_paths()
473+
)
474+
self.assertEqual(
475+
[("A", "[0]")], Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_paths()
476+
)
477+
478+
def test_couple_input_ds_1(self):
479+
T3x1 = torch.rand((3, 1))
480+
T3x4 = torch.rand((3, 4))
481+
ds_batch = {0: "batch"}
482+
ds_batch_seq = {0: "batch", 1: "seq"}
483+
args = (T3x4, T3x1)
484+
Cls = CoupleInputsDynamicShapes
485+
self.assertEqual([], Cls(args, {}, (ds_batch, ds_batch)).invalid_paths())
486+
self.assertEqual([(1, "[1]")], Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_paths())
487+
488+
def test_couple_input_ds_2(self):
489+
T3x1 = torch.rand((3, 1))
490+
T3x4 = torch.rand((3, 4))
491+
ds_batch = {0: "batch"}
492+
ds_batch_seq = {0: "batch", 1: "seq"}
493+
kwargs = {"A": T3x4, "B": T3x1}
494+
Cls = CoupleInputsDynamicShapes
495+
self.assertEqual([], Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_paths())
496+
self.assertEqual(
497+
[("B", "[1]")], Cls((), kwargs, {"A": ds_batch, "B": ds_batch_seq}).invalid_paths()
498+
)
499+
500+
def test_couple_input_ds_3(self):
501+
T3x1 = torch.rand((3, 1))
502+
T3x4 = torch.rand((3, 4))
503+
ds_batch = {0: "batch"}
504+
ds_batch_seq = {0: "batch", 1: "seq"}
505+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
506+
Cls = CoupleInputsDynamicShapes
507+
self.assertEqual(
508+
[], Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
509+
)
510+
self.assertEqual(
511+
[("B", 1, "[1]")],
512+
Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(),
513+
)
514+
515+
def test_couple_input_ds_cache(self):
516+
T3x1 = torch.rand((3, 1))
517+
T3x4 = torch.rand((3, 4))
518+
ds_batch = {0: "batch"}
519+
ds_batch_seq = {0: "batch", 2: "seq"}
520+
521+
n_layers = 2
522+
bsize, nheads, slen, dim = 2, 4, 1, 7
523+
cache = make_dynamic_cache(
524+
[
525+
(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
526+
for i in range(n_layers)
527+
]
528+
)
529+
530+
kwargs = {"A": T3x4, "B": (T3x1, cache)}
531+
Cls = CoupleInputsDynamicShapes
532+
self.assertEqual(
533+
[],
534+
Cls(
535+
(),
536+
kwargs,
537+
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
538+
).invalid_paths(),
539+
)
540+
self.assertEqual(
541+
[("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")],
542+
Cls(
543+
(),
544+
kwargs,
545+
{
546+
"A": ds_batch,
547+
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
548+
},
549+
).invalid_paths(),
550+
)
551+
456552

457553
if __name__ == "__main__":
458554
unittest.main(verbosity=2)

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import unittest
2-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
2+
from onnx_diagnostic.ext_test_case import (
3+
ExtTestCase,
4+
hide_stdout,
5+
ignore_warnings,
6+
ignore_errors,
7+
)
38
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
49
from onnx_diagnostic.torch_onnx.sbs import run_aligned
510

@@ -13,6 +18,7 @@ class TestSideBySide(ExtTestCase):
1318

1419
@hide_stdout()
1520
@unittest.skipIf(to_onnx is None, "to_onnx not installed")
21+
@ignore_errors(OSError, "connectivity issues")
1622
@ignore_warnings((UserWarning,))
1723
def test_ep_onnx_sync_exp(self):
1824
import torch

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 119 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import inspect
2-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
33
import numpy as np
44
import torch
55
from ..helpers import string_type
66

7+
DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
8+
79

810
class ModelInputs:
911
"""
@@ -218,7 +220,7 @@ def process_inputs(
218220
return new_inputs
219221

220222
@property
221-
def true_model_name(self):
223+
def true_model_name(self) -> str:
222224
"Returns class name or module name."
223225
return (
224226
self.model.__class__.__name__
@@ -227,7 +229,7 @@ def true_model_name(self):
227229
)
228230

229231
@property
230-
def full_name(self):
232+
def full_name(self) -> str:
231233
"Returns a name and class name."
232234
if self.method_name == "forward":
233235
return f"{self.name}:{self.true_model_name}"
@@ -337,9 +339,7 @@ def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None)
337339
f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}"
338340
)
339341

340-
def guess_dynamic_shapes(
341-
self,
342-
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
342+
def guess_dynamic_shapes(self) -> DYNAMIC_SHAPES:
343343
"""
344344
Guesses the dynamic shapes for that module from two execution.
345345
If there is only one execution, then that would be static dimensions.
@@ -386,7 +386,7 @@ def move_to_kwargs(
386386
args: Tuple[Any, ...],
387387
kwargs: Dict[str, Any],
388388
dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
389-
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Tuple[Tuple[Any, ...], Dict[str, Any]]]:
389+
) -> Tuple[Tuple[Any, ...], Dict[str, Any], DYNAMIC_SHAPES]:
390390
"""
391391
Uses the signatures to move positional arguments (args) to named arguments (kwargs)
392392
with the corresponding dynamic shapes.
@@ -434,3 +434,115 @@ def move_to_kwargs(
434434
f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
435435
)
436436
return args, kwargs, (tuple(), kw_dyn)
437+
438+
def validate_inputs_for_export(
439+
self, dynamic_shapes: Optional[DYNAMIC_SHAPES] = None
440+
) -> List[List[str]]:
441+
"""
442+
Validates the inputs the class contains for the given dynamic shapes.
443+
If not specified, the dynamic_shapes are guessed.
444+
445+
:param dynamic_shapes: dynamic shapes to validate
446+
:return: a list of lists, every list contains the path the invalid dimension
447+
"""
448+
if dynamic_shapes is None:
449+
if len(self.inputs) == 1:
450+
return True
451+
dyn_shapes = self.guess_dynamic_shapes()
452+
return [CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_paths() for i in self.inputs]
453+
454+
455+
class CoupleInputsDynamicShapes:
456+
"""
457+
Pair inputs / dynamic shapes.
458+
"""
459+
460+
def __init__(
461+
self, args: Tuple[Any, ...], kwargs: Dict[str, Any], dynamic_shapes: DYNAMIC_SHAPES
462+
):
463+
self.args = args
464+
self.kwargs = kwargs
465+
self.dynamic_shapes = dynamic_shapes
466+
467+
def __str__(self) -> str:
468+
return "\n".join(
469+
[
470+
f"{self.__class__.__name__}(",
471+
f" args={string_type(self.args, with_shape=True)},"
472+
f" kwargs={string_type(self.kwargs, with_shape=True)},"
473+
f" dynamic_shapes={string_type(self.dynamic_shapes, with_shape=True)},"
474+
f")",
475+
]
476+
)
477+
478+
def invalid_paths(self) -> List[Union[str, int]]:
479+
"""
480+
Tells the inputs are valid based on the dynamic shapes definition.
481+
The method assumes that all custom classes can be serialized.
482+
If some patches were applied to export, they should enabled while
483+
calling this method if the inputs contains such classes.
484+
485+
The function checks that a dynamic dimension does not receive a value
486+
of 0 or 1. It returns a list of invalid path.
487+
"""
488+
if not self.args:
489+
assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
490+
f"Type mismatch, args={string_type(self.args)} and "
491+
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
492+
)
493+
return list(self._valid_shapes(self.kwargs, self.dynamic_shapes))
494+
if not self.kwargs:
495+
assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
496+
f"Type mismatch, args={string_type(self.args)} and "
497+
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
498+
)
499+
return list(self._valid_shapes(self.args, self.dynamic_shapes))
500+
raise NotImplementedError("args and kwargs are filled, it is not implemented yet.")
501+
502+
@classmethod
503+
def _valid_shapes(
504+
cls, inputs: Any, ds: Any, prefix: Tuple[Union[int, str], ...] = ()
505+
) -> Iterable:
506+
assert all(isinstance(i, (int, str)) for i in prefix), f"Unexpected prefix {prefix}"
507+
if isinstance(inputs, torch.Tensor):
508+
assert isinstance(ds, dict) and all(
509+
isinstance(s, int) for s in ds
510+
), f"Unexpected types, inputs is a Tensor but ds={ds}, prefix={prefix}"
511+
for i, d in enumerate(inputs.shape):
512+
if i in ds and not isinstance(ds[i], int):
513+
# dynamic then
514+
if d in {0, 1}:
515+
# export issues for sure
516+
yield (*prefix, f"[{i}]")
517+
else:
518+
if isinstance(inputs, (int, float, str)):
519+
pass
520+
elif isinstance(inputs, (tuple, list, dict)):
521+
assert type(ds) is type(inputs), (
522+
f"Type mismatch between inputs {type(inputs)} "
523+
f"and ds={type(ds)}, prefix={prefix!r}"
524+
)
525+
assert len(ds) == len(inputs), (
526+
f"Length mismatch between inputs {len(inputs)} "
527+
f"and ds={len(ds)}, prefix={prefix!r}\n"
528+
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
529+
)
530+
if isinstance(inputs, (tuple, list)):
531+
for ind, (i, d) in enumerate(zip(inputs, ds)):
532+
for path in cls._valid_shapes(i, d, prefix=(*prefix, ind)):
533+
yield path
534+
else:
535+
assert set(inputs) == set(ds), (
536+
f"Keys mismatch between inputs {set(inputs)} "
537+
f"and ds={set(ds)}, prefix={prefix!r}"
538+
)
539+
for k, v in inputs.items():
540+
for path in cls._valid_shapes(v, ds[k], prefix=(*prefix, k)):
541+
yield path
542+
else:
543+
# A custom class.
544+
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
545+
for path in cls._valid_shapes(
546+
flat, ds, prefix=(*prefix, inputs.__class__.__name__)
547+
):
548+
yield path

0 commit comments

Comments
 (0)