Skip to content

Commit 00ad63e

Browse files
authored
refactor algorithm to validate shapes (#37)
* refactor algorithm to validate shapes * issues * add change dynamic * fix documentation * rename
1 parent db7abe6 commit 00ad63e

File tree

5 files changed

+359
-93
lines changed

5 files changed

+359
-93
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 103 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -459,21 +459,31 @@ def test_couple_input_ds_0(self):
459459
T3x4 = torch.rand((3, 4))
460460
T3x1 = torch.rand((3, 1))
461461
Cls = CoupleInputsDynamicShapes
462-
self.assertEmpty(Cls((T3x4,), {}, ({0: "batch"},)).invalid_paths())
463-
self.assertEmpty(Cls((T3x1,), {}, ({0: "batch"},)).invalid_paths())
464-
self.assertEmpty(Cls((), {"A": T3x1}, {"A": {0: "batch"}}).invalid_paths())
465-
self.assertEmpty(Cls((), {"A": T3x4}, {"A": {0: "batch"}}).invalid_paths())
462+
self.assertEmpty(Cls((T3x4,), {}, ({0: "batch"},)).invalid_dimensions_for_export())
463+
self.assertEmpty(Cls((T3x1,), {}, ({0: "batch"},)).invalid_dimensions_for_export())
464+
self.assertEmpty(
465+
Cls((), {"A": T3x1}, {"A": {0: "batch"}}).invalid_dimensions_for_export()
466+
)
467+
self.assertEmpty(
468+
Cls((), {"A": T3x4}, {"A": {0: "batch"}}).invalid_dimensions_for_export()
469+
)
466470

467471
T1x4 = torch.rand((1, 4))
468472
T1x1 = torch.rand((1, 1))
469473
Cls = CoupleInputsDynamicShapes
470-
self.assertEqual([(0, "[0]")], Cls((T1x4,), {}, ({0: "batch"},)).invalid_paths())
471-
self.assertEqual([(0, "[0]")], Cls((T1x1,), {}, ({0: "batch"},)).invalid_paths())
472474
self.assertEqual(
473-
[("A", "[0]")], Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_paths()
475+
({0: "d=[1]"},), Cls((T1x4,), {}, ({0: "batch"},)).invalid_dimensions_for_export()
476+
)
477+
self.assertEqual(
478+
({0: "d=[1]"},), Cls((T1x1,), {}, ({0: "batch"},)).invalid_dimensions_for_export()
479+
)
480+
self.assertEqual(
481+
{"A": {0: "d=[1]"}},
482+
Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_dimensions_for_export(),
474483
)
475484
self.assertEqual(
476-
[("A", "[0]")], Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_paths()
485+
{"A": {0: "d=[1]"}},
486+
Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_dimensions_for_export(),
477487
)
478488

479489
def test_couple_input_ds_1(self):
@@ -483,8 +493,13 @@ def test_couple_input_ds_1(self):
483493
ds_batch_seq = {0: "batch", 1: "seq"}
484494
args = (T3x4, T3x1)
485495
Cls = CoupleInputsDynamicShapes
486-
self.assertEqual([], Cls(args, {}, (ds_batch, ds_batch)).invalid_paths())
487-
self.assertEqual([(1, "[1]")], Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_paths())
496+
self.assertEqual(
497+
None, Cls(args, {}, (ds_batch, ds_batch)).invalid_dimensions_for_export()
498+
)
499+
self.assertEqual(
500+
(None, {1: "d=[1]"}),
501+
Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_dimensions_for_export(),
502+
)
488503

489504
def test_couple_input_ds_2(self):
490505
T3x1 = torch.rand((3, 1))
@@ -493,9 +508,15 @@ def test_couple_input_ds_2(self):
493508
ds_batch_seq = {0: "batch", 1: "seq"}
494509
kwargs = {"A": T3x4, "B": T3x1}
495510
Cls = CoupleInputsDynamicShapes
496-
self.assertEqual([], Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_paths())
497511
self.assertEqual(
498-
[("B", "[1]")], Cls((), kwargs, {"A": ds_batch, "B": ds_batch_seq}).invalid_paths()
512+
None,
513+
Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_dimensions_for_export(),
514+
)
515+
self.assertEqual(
516+
{"B": {1: "d=[1]"}},
517+
Cls(
518+
(), kwargs, {"A": ds_batch, "B": ds_batch_seq}
519+
).invalid_dimensions_for_export(),
499520
)
500521

501522
def test_couple_input_ds_3(self):
@@ -506,11 +527,16 @@ def test_couple_input_ds_3(self):
506527
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
507528
Cls = CoupleInputsDynamicShapes
508529
self.assertEqual(
509-
[], Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
530+
None,
531+
Cls(
532+
(), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}
533+
).invalid_dimensions_for_export(),
510534
)
511535
self.assertEqual(
512-
[("B", 1, "[1]")],
513-
Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(),
536+
{"B": (None, {1: "d=[1]"})},
537+
Cls(
538+
(), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
539+
).invalid_dimensions_for_export(),
514540
)
515541

516542
def test_couple_input_ds_cache(self):
@@ -532,23 +558,23 @@ def test_couple_input_ds_cache(self):
532558
Cls = CoupleInputsDynamicShapes
533559
with bypass_export_some_errors(patch_transformers=True):
534560
self.assertEqual(
535-
[],
561+
None,
536562
Cls(
537563
(),
538564
kwargs,
539565
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
540-
).invalid_paths(),
566+
).invalid_dimensions_for_export(),
541567
)
542568
self.assertEqual(
543-
[("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")],
569+
{"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])},
544570
Cls(
545571
(),
546572
kwargs,
547573
{
548574
"A": ds_batch,
549575
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
550576
},
551-
).invalid_paths(),
577+
).invalid_dimensions_for_export(),
552578
)
553579

554580
def test_couple_input_ds_args_kwargs_0(self):
@@ -561,17 +587,22 @@ def test_couple_input_ds_args_kwargs_0(self):
561587
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
562588
Cls = CoupleInputsDynamicShapes
563589
self.assertEqual(
564-
[], Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths()
590+
None,
591+
Cls(
592+
args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}
593+
).invalid_dimensions_for_export(),
565594
)
566595
self.assertEqual(
567-
[],
596+
None,
568597
Cls(
569598
args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"]
570-
).invalid_paths(),
599+
).invalid_dimensions_for_export(),
571600
)
572601
self.assertEqual(
573-
[("B", 1, "[1]")],
574-
Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(),
602+
{"B": (None, {1: "d=[1]"})},
603+
Cls(
604+
args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
605+
).invalid_dimensions_for_export(),
575606
)
576607

577608
def test_couple_input_ds_args_kwargs_1(self):
@@ -584,23 +615,67 @@ def test_couple_input_ds_args_kwargs_1(self):
584615
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
585616
Cls = CoupleInputsDynamicShapes
586617
self.assertEqual(
587-
[],
618+
None,
619+
Cls(
620+
args,
621+
kwargs,
622+
{"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)},
623+
args_names=["X"],
624+
).invalid_dimensions_for_export(),
625+
)
626+
self.assertEqual(
627+
{"X": {1: "d=[1]"}, "B": (None, {1: "d=[1]"})},
628+
Cls(
629+
args,
630+
kwargs,
631+
{"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)},
632+
args_names=["X"],
633+
).invalid_dimensions_for_export(),
634+
)
635+
636+
def test_couple_input_ds_replace_string(self):
637+
T3x1 = torch.rand((3, 1))
638+
T3x4 = torch.rand((3, 4))
639+
T5x1 = torch.rand((5, 1))
640+
ds_batch = {0: "batch"}
641+
ds_batch_seq = {0: "batch", 1: "seq"}
642+
args = (T5x1,)
643+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
644+
Cls = CoupleInputsDynamicShapes
645+
self.assertEqual(
646+
{"X": {0: "DYN"}, "A": {0: "DYN"}, "B": ({0: "DYN"}, {0: "DYN"})},
588647
Cls(
589648
args,
590649
kwargs,
591650
{"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)},
592651
args_names=["X"],
593-
).invalid_paths(),
652+
).replace_string_by(value="DYN"),
594653
)
595654
self.assertEqual(
596-
[("X", "[1]"), ("B", 1, "[1]")],
655+
{
656+
"A": {0: "DYN"},
657+
"B": ({0: "DYN"}, {0: "DYN", 1: "DYN"}),
658+
"X": {0: "DYN", 1: "DYN"},
659+
},
597660
Cls(
598661
args,
599662
kwargs,
600663
{"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)},
601664
args_names=["X"],
602-
).invalid_paths(),
665+
).replace_string_by(value="DYN"),
666+
)
667+
668+
def test_couple_input_ds_change_dynamic_dimensions(self):
669+
T257 = torch.arange(2 * 5 * 7).reshape((2, 5, 7))
670+
T29 = torch.arange(2 * 9).reshape((2, 9))
671+
inst = CoupleInputsDynamicShapes(
672+
(),
673+
{"A": T257, "B": T29},
674+
{"A": {0: "batch", 2: "last"}, "B": {0: "batch", 1: "seq"}},
603675
)
676+
new_input = inst.change_dynamic_dimensions()
677+
self.assertEqual((3, 5, 8), new_input["A"].shape)
678+
self.assertEqual((3, 10), new_input["B"].shape)
604679

605680

606681
if __name__ == "__main__":

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def test_get_untrained_model_with_inputs_tiny_gpt_neo(self):
6262
self.assertEqual((316712, 79178), (data["size"], data["n_weights"]))
6363

6464
@hide_stdout()
65+
@ignore_errors(OSError)
6566
def test_get_untrained_model_with_inputs_phi_2(self):
6667
mid = "microsoft/phi-2"
6768
data = get_untrained_model_with_inputs(mid, verbose=1)
@@ -84,6 +85,7 @@ def test_get_untrained_model_with_inputs_beit(self):
8485
self.assertIn((data["size"], data["n_weights"]), [(111448, 27862), (56880, 14220)])
8586

8687
@hide_stdout()
88+
@ignore_errors(OSError)
8789
def test_get_untrained_model_with_inputs_codellama(self):
8890
mid = "codellama/CodeLlama-7b-Python-hf"
8991
data = get_untrained_model_with_inputs(mid, verbose=1)

0 commit comments

Comments
 (0)