Skip to content

Commit c058fd1

Browse files
committed
rename
1 parent 6e1c7e6 commit c058fd1

File tree

3 files changed

+75
-29
lines changed

3 files changed

+75
-29
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -459,21 +459,31 @@ def test_couple_input_ds_0(self):
459459
T3x4 = torch.rand((3, 4))
460460
T3x1 = torch.rand((3, 1))
461461
Cls = CoupleInputsDynamicShapes
462-
self.assertEmpty(Cls((T3x4,), {}, ({0: "batch"},)).invalid_paths())
463-
self.assertEmpty(Cls((T3x1,), {}, ({0: "batch"},)).invalid_paths())
464-
self.assertEmpty(Cls((), {"A": T3x1}, {"A": {0: "batch"}}).invalid_paths())
465-
self.assertEmpty(Cls((), {"A": T3x4}, {"A": {0: "batch"}}).invalid_paths())
462+
self.assertEmpty(Cls((T3x4,), {}, ({0: "batch"},)).invalid_dimensions_for_export())
463+
self.assertEmpty(Cls((T3x1,), {}, ({0: "batch"},)).invalid_dimensions_for_export())
464+
self.assertEmpty(
465+
Cls((), {"A": T3x1}, {"A": {0: "batch"}}).invalid_dimensions_for_export()
466+
)
467+
self.assertEmpty(
468+
Cls((), {"A": T3x4}, {"A": {0: "batch"}}).invalid_dimensions_for_export()
469+
)
466470

467471
T1x4 = torch.rand((1, 4))
468472
T1x1 = torch.rand((1, 1))
469473
Cls = CoupleInputsDynamicShapes
470-
self.assertEqual(({0: "d=[1]"},), Cls((T1x4,), {}, ({0: "batch"},)).invalid_paths())
471-
self.assertEqual(({0: "d=[1]"},), Cls((T1x1,), {}, ({0: "batch"},)).invalid_paths())
472474
self.assertEqual(
473-
{"A": {0: "d=[1]"}}, Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_paths()
475+
({0: "d=[1]"},), Cls((T1x4,), {}, ({0: "batch"},)).invalid_dimensions_for_export()
474476
)
475477
self.assertEqual(
476-
{"A": {0: "d=[1]"}}, Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_paths()
478+
({0: "d=[1]"},), Cls((T1x1,), {}, ({0: "batch"},)).invalid_dimensions_for_export()
479+
)
480+
self.assertEqual(
481+
{"A": {0: "d=[1]"}},
482+
Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_dimensions_for_export(),
483+
)
484+
self.assertEqual(
485+
{"A": {0: "d=[1]"}},
486+
Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_dimensions_for_export(),
477487
)
478488

479489
def test_couple_input_ds_1(self):
@@ -483,9 +493,12 @@ def test_couple_input_ds_1(self):
483493
ds_batch_seq = {0: "batch", 1: "seq"}
484494
args = (T3x4, T3x1)
485495
Cls = CoupleInputsDynamicShapes
486-
self.assertEqual(None, Cls(args, {}, (ds_batch, ds_batch)).invalid_paths())
487496
self.assertEqual(
488-
(None, {1: "d=[1]"}), Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_paths()
497+
None, Cls(args, {}, (ds_batch, ds_batch)).invalid_dimensions_for_export()
498+
)
499+
self.assertEqual(
500+
(None, {1: "d=[1]"}),
501+
Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_dimensions_for_export(),
489502
)
490503

491504
def test_couple_input_ds_2(self):
@@ -495,10 +508,15 @@ def test_couple_input_ds_2(self):
495508
ds_batch_seq = {0: "batch", 1: "seq"}
496509
kwargs = {"A": T3x4, "B": T3x1}
497510
Cls = CoupleInputsDynamicShapes
498-
self.assertEqual(None, Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_paths())
511+
self.assertEqual(
512+
None,
513+
Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_dimensions_for_export(),
514+
)
499515
self.assertEqual(
500516
{"B": {1: "d=[1]"}},
501-
Cls((), kwargs, {"A": ds_batch, "B": ds_batch_seq}).invalid_paths(),
517+
Cls(
518+
(), kwargs, {"A": ds_batch, "B": ds_batch_seq}
519+
).invalid_dimensions_for_export(),
502520
)
503521

504522
def test_couple_input_ds_3(self):
@@ -509,11 +527,16 @@ def test_couple_input_ds_3(self):
509527
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
510528
Cls = CoupleInputsDynamicShapes
511529
self.assertEqual(
512-
None, Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
530+
None,
531+
Cls(
532+
(), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}
533+
).invalid_dimensions_for_export(),
513534
)
514535
self.assertEqual(
515536
{"B": (None, {1: "d=[1]"})},
516-
Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(),
537+
Cls(
538+
(), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
539+
).invalid_dimensions_for_export(),
517540
)
518541

519542
def test_couple_input_ds_cache(self):
@@ -540,7 +563,7 @@ def test_couple_input_ds_cache(self):
540563
(),
541564
kwargs,
542565
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
543-
).invalid_paths(),
566+
).invalid_dimensions_for_export(),
544567
)
545568
self.assertEqual(
546569
{"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])},
@@ -551,7 +574,7 @@ def test_couple_input_ds_cache(self):
551574
"A": ds_batch,
552575
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
553576
},
554-
).invalid_paths(),
577+
).invalid_dimensions_for_export(),
555578
)
556579

557580
def test_couple_input_ds_args_kwargs_0(self):
@@ -564,17 +587,22 @@ def test_couple_input_ds_args_kwargs_0(self):
564587
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
565588
Cls = CoupleInputsDynamicShapes
566589
self.assertEqual(
567-
None, Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
590+
None,
591+
Cls(
592+
args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}
593+
).invalid_dimensions_for_export(),
568594
)
569595
self.assertEqual(
570596
None,
571597
Cls(
572598
args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"]
573-
).invalid_paths(),
599+
).invalid_dimensions_for_export(),
574600
)
575601
self.assertEqual(
576602
{"B": (None, {1: "d=[1]"})},
577-
Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(),
603+
Cls(
604+
args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
605+
).invalid_dimensions_for_export(),
578606
)
579607

580608
def test_couple_input_ds_args_kwargs_1(self):
@@ -593,7 +621,7 @@ def test_couple_input_ds_args_kwargs_1(self):
593621
kwargs,
594622
{"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)},
595623
args_names=["X"],
596-
).invalid_paths(),
624+
).invalid_dimensions_for_export(),
597625
)
598626
self.assertEqual(
599627
{"X": {1: "d=[1]"}, "B": (None, {1: "d=[1]"})},
@@ -602,7 +630,7 @@ def test_couple_input_ds_args_kwargs_1(self):
602630
kwargs,
603631
{"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)},
604632
args_names=["X"],
605-
).invalid_paths(),
633+
).invalid_dimensions_for_export(),
606634
)
607635

608636
def test_couple_input_ds_replace_string(self):

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def forward(self, x, y):
7171
ds = mi.guess_dynamic_shapes()
7272
pprint.pprint(ds)
7373
74-
**and and kwargs**
74+
**args and kwargs**
7575
7676
.. runpython::
7777
:showcode:
@@ -449,7 +449,10 @@ def validate_inputs_for_export(
449449
if len(self.inputs) == 1:
450450
return []
451451
dyn_shapes = self.guess_dynamic_shapes()
452-
return [CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_paths() for i in self.inputs]
452+
return [
453+
CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_dimensions_for_export()
454+
for i in self.inputs
455+
]
453456

454457

455458
class CoupleInputsDynamicShapes:
@@ -530,15 +533,16 @@ def _replace_string_dim_tensor(cls, inputs, ds, value=None):
530533
new_ds[i] = value
531534
return new_ds
532535

533-
def invalid_paths(self):
536+
def invalid_dimensions_for_export(self):
534537
"""
535538
Tells if the inputs are valid based on the dynamic shapes definition.
536539
The method assumes that all custom classes can be serialized.
537540
If some patches were applied to export, they should enabled while
538541
calling this method if the inputs contains such classes.
539542
540543
The function checks that a dynamic dimension does not receive a value
541-
of 0 or 1. It returns a list of invalid path.
544+
of 0 or 1. It returns the unexpected values in the same structure as
545+
the given dynamic shapes.
542546
543547
Example:
544548
@@ -554,7 +558,23 @@ def invalid_paths(self):
554558
ds_batch_seq = {0: "batch", 1: "seq"}
555559
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
556560
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
557-
print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_paths())
561+
print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
562+
563+
In case it works, it shows:
564+
565+
.. runpython::
566+
:showcode:
567+
568+
import torch
569+
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
570+
571+
T3x2 = torch.rand((3, 2))
572+
T3x4 = torch.rand((3, 4))
573+
ds_batch = {0: "batch"}
574+
ds_batch_seq = {0: "batch", 1: "seq"}
575+
kwargs = {"A": T3x4, "B": (T3x2, T3x2)}
576+
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
577+
print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
558578
"""
559579
return self._generic_walker(self._valid_shapes_tensor)
560580

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,7 @@ def validate_model(
276276
)
277277
if verbose:
278278
print(f"[validate_model] new inputs: {string_type(data['inputs'])}")
279-
print(
280-
f"[validate_model] new dynamic_hapes: {_ds_clean(data['dynamic_shapes'])}"
281-
)
279+
print(f"[validate_model] new dynamic_hapes: {_ds_clean(data['dynamic_shapes'])}")
282280

283281
if not empty(dtype):
284282
if isinstance(dtype, str):

0 commit comments

Comments
 (0)