Skip to content

Commit d06d27e

Browse files
committed
merge
2 parents 5cd902b + abbcc6b commit d06d27e

File tree

4 files changed

+177
-17
lines changed

4 files changed

+177
-17
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
66
from onnx_diagnostic.export import ModelInputs
77
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
8+
<<<<<<< HEAD
9+
=======
10+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
11+
>>>>>>> abbcc6bca05bb6d921592fe814b92567b36da435
812

913

1014
class TestDynamicShapes(ExtTestCase):
@@ -529,23 +533,75 @@ def test_couple_input_ds_cache(self):
529533

530534
kwargs = {"A": T3x4, "B": (T3x1, cache)}
531535
Cls = CoupleInputsDynamicShapes
536+
with bypass_export_some_errors(patch_transformers=True):
537+
self.assertEqual(
538+
[],
539+
Cls(
540+
(),
541+
kwargs,
542+
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
543+
).invalid_paths(),
544+
)
545+
self.assertEqual(
546+
[("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")],
547+
Cls(
548+
(),
549+
kwargs,
550+
{
551+
"A": ds_batch,
552+
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
553+
},
554+
).invalid_paths(),
555+
)
556+
557+
def test_couple_input_ds_args_kwargs_0(self):
558+
T3x1 = torch.rand((3, 1))
559+
T3x4 = torch.rand((3, 4))
560+
T5x6 = torch.rand((5, 6))
561+
ds_batch = {0: "batch"}
562+
ds_batch_seq = {0: "batch", 1: "seq"}
563+
args = (T5x6,)
564+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
565+
Cls = CoupleInputsDynamicShapes
566+
self.assertEqual(
567+
[], Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
568+
)
532569
self.assertEqual(
533570
[],
534571
Cls(
535-
(),
572+
args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"]
573+
).invalid_paths(),
574+
)
575+
self.assertEqual(
576+
[("B", 1, "[1]")],
577+
Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(),
578+
)
579+
580+
def test_couple_input_ds_args_kwargs_1(self):
581+
T3x1 = torch.rand((3, 1))
582+
T3x4 = torch.rand((3, 4))
583+
T5x1 = torch.rand((5, 1))
584+
ds_batch = {0: "batch"}
585+
ds_batch_seq = {0: "batch", 1: "seq"}
586+
args = (T5x1,)
587+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
588+
Cls = CoupleInputsDynamicShapes
589+
self.assertEqual(
590+
[],
591+
Cls(
592+
args,
536593
kwargs,
537-
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
594+
{"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)},
595+
args_names=["X"],
538596
).invalid_paths(),
539597
)
540598
self.assertEqual(
541-
[("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")],
599+
[("X", "[1]"), ("B", 1, "[1]")],
542600
Cls(
543-
(),
601+
args,
544602
kwargs,
545-
{
546-
"A": ds_batch,
547-
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
548-
},
603+
{"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)},
604+
args_names=["X"],
549605
).invalid_paths(),
550606
)
551607

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ class TestSideBySide(ExtTestCase):
1818

1919
@hide_stdout()
2020
@unittest.skipIf(to_onnx is None, "to_onnx not installed")
21+
<<<<<<< HEAD
2122
@ignore_errors(OSError, "connectivity issues")
23+
=======
24+
@ignore_errors(OSError) # connectivity issues
25+
>>>>>>> abbcc6bca05bb6d921592fe814b92567b36da435
2226
@ignore_warnings((UserWarning,))
2327
def test_ep_onnx_sync_exp(self):
2428
import torch
@@ -52,7 +56,6 @@ def forward(self, x):
5256

5357
@hide_stdout()
5458
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
55-
@unittest.skipIf(to_onnx is None, "to_onnx not installed")
5659
def test_ep_onnx_sync_a(self):
5760
import torch
5861

@@ -69,9 +72,10 @@ def forward(self, x):
6972
ep = torch.export.export(
7073
Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
7174
)
72-
onx = torch.onnx.export(
73-
Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},), dynamo=True
74-
).model_proto
75+
epo = torch.onnx.export(
76+
ep, (x,), dynamic_shapes=({0: torch.export.Dim("batch")},), dynamo=True
77+
)
78+
onx = epo.model_proto
7579
results = list(
7680
run_aligned(
7781
ep,

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def move_to_kwargs(
437437

438438
def validate_inputs_for_export(
439439
self, dynamic_shapes: Optional[DYNAMIC_SHAPES] = None
440-
) -> List[List[str]]:
440+
) -> List[List[Union[int, str]]]:
441441
"""
442442
Validates the inputs the class contains for the given dynamic shapes.
443443
If not specified, the dynamic_shapes are guessed.
@@ -447,22 +447,35 @@ def validate_inputs_for_export(
447447
"""
448448
if dynamic_shapes is None:
449449
if len(self.inputs) == 1:
450-
return True
450+
return []
451451
dyn_shapes = self.guess_dynamic_shapes()
452452
return [CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_paths() for i in self.inputs]
453453

454454

455455
class CoupleInputsDynamicShapes:
456456
"""
457457
Pair inputs / dynamic shapes.
458+
459+
:param args: positional arguments
460+
:param kwargs: named arguments
461+
:param dynamic_shapes: dynamic shapes
462+
:param args_names: if both args and kwargs are not empty, then
463+
dynamic shapes must be a dictionary, and positional must be added
464+
to the named arguments. Arguments names or a module must be given
465+
in that case.
458466
"""
459467

460468
def __init__(
461-
self, args: Tuple[Any, ...], kwargs: Dict[str, Any], dynamic_shapes: DYNAMIC_SHAPES
469+
self,
470+
args: Tuple[Any, ...],
471+
kwargs: Dict[str, Any],
472+
dynamic_shapes: DYNAMIC_SHAPES,
473+
args_names: Optional[Union[torch.nn.Module, List[str]]] = None,
462474
):
463475
self.args = args
464476
self.kwargs = kwargs
465477
self.dynamic_shapes = dynamic_shapes
478+
self.args_names = args_names
466479

467480
def __str__(self) -> str:
468481
return "\n".join(
@@ -497,7 +510,39 @@ def invalid_paths(self) -> List[Union[str, int]]:
497510
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
498511
)
499512
return list(self._valid_shapes(self.args, self.dynamic_shapes))
500-
raise NotImplementedError("args and kwargs are filled, it is not implemented yet.")
513+
514+
assert isinstance(self.dynamic_shapes, dict), (
515+
f"Both positional and named arguments (args and kwargs) are filled. "
516+
f"dynamic shapes must a dictionary not {type(self.dynamic_shapes)}"
517+
)
518+
if not self.args_names and set(self.dynamic_shapes) & set(self.kwargs) == set(
519+
self.dynamic_shapes
520+
):
521+
# No dynamic shapes for the positional arguments.
522+
return list(self._valid_shapes(self.kwargs, self.dynamic_shapes))
523+
524+
if isinstance(self.args_names, list):
525+
if not set(self.args_names) & set(self.dynamic_shapes):
526+
# No dynamic shapes for the positional arguments.
527+
return list(self._valid_shapes(self.kwargs, self.dynamic_shapes))
528+
529+
assert self.args_names, (
530+
"args and kwargs are filled, then args_names must be specified in "
531+
"the constructor to move positional arguments to named arguments."
532+
)
533+
assert len(self.args) <= len(self.args_names), (
534+
f"There are {len(self.args)} positional arguments "
535+
f"but only {len(self.args_names)} names. "
536+
f"args={string_type(self.args, with_shape=True)}, args_name={self.args_names}"
537+
)
538+
kwargs = dict(zip(self.args_names, self.args))
539+
kwargs.update(self.kwargs)
540+
return list(self._valid_shapes(kwargs, self.dynamic_shapes))
541+
542+
raise NotImplementedError(
543+
f"Not yet implemented when args is filled, "
544+
f"kwargs as well but args_names is {type(self.args_names)}"
545+
)
501546

502547
@classmethod
503548
def _valid_shapes(
@@ -541,6 +586,11 @@ def _valid_shapes(
541586
yield path
542587
else:
543588
# A custom class.
589+
assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
590+
f"Class {inputs.__class__.__name__!r} was not registered using "
591+
f"torch.utils._pytree.register_pytree_node, it is not possible to "
592+
f"map this class with the given dynamic shapes."
593+
)
544594
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
545595
for path in cls._valid_shapes(
546596
flat, ds, prefix=(*prefix, inputs.__class__.__name__)

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
import inspect
23
import os
34
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -139,6 +140,55 @@ def _make_folder_name(
139140
return "-".join(els)
140141

141142

143+
def version_summary() -> Dict[str, Union[int, float, str]]:
144+
"""
145+
Example:
146+
147+
.. runpython::
148+
:showcode:
149+
150+
import pprint
151+
from onnx_diagnostic.torch_models.test_helper import version_summary
152+
153+
pprint.pprint(version_summary())
154+
"""
155+
import numpy
156+
157+
summary: Dict[str, Union[int, float, str]] = {
158+
"version_torch": torch.__version__,
159+
"version_numpy": numpy.__version__,
160+
}
161+
try:
162+
import transformers
163+
164+
summary["version_transformers"] = transformers.__version__
165+
except ImportError:
166+
pass
167+
try:
168+
import onnx
169+
170+
summary["version_onnx"] = onnx.__version__
171+
except ImportError:
172+
pass
173+
try:
174+
import onnxscript
175+
176+
summary["version_onnxscript"] = onnxscript.__version__
177+
except ImportError:
178+
pass
179+
try:
180+
import onnxruntime
181+
182+
summary["version_onnxruntime"] = onnxruntime.__version__
183+
except ImportError:
184+
pass
185+
import onnx_diagnostic
186+
187+
summary["version_onnx_diagnostic"] = onnx_diagnostic.__version__
188+
summary["version_date"] = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
189+
return summary
190+
191+
142192
def validate_model(
143193
model_id: str,
144194
task: Optional[str] = None,
@@ -180,7 +230,7 @@ def validate_model(
180230
another one with whatever the function produces
181231
"""
182232
assert not trained, f"trained={trained} not supported yet"
183-
summary: Dict[str, Union[int, float, str]] = {}
233+
summary = version_summary()
184234
if dump_folder:
185235
folder_name = _make_folder_name(
186236
model_id, exporter, optimization, dtype=dtype, device=device

0 commit comments

Comments
 (0)