@@ -459,21 +459,31 @@ def test_couple_input_ds_0(self):
459459 T3x4 = torch .rand ((3 , 4 ))
460460 T3x1 = torch .rand ((3 , 1 ))
461461 Cls = CoupleInputsDynamicShapes
462- self .assertEmpty (Cls ((T3x4 ,), {}, ({0 : "batch" },)).invalid_paths ())
463- self .assertEmpty (Cls ((T3x1 ,), {}, ({0 : "batch" },)).invalid_paths ())
464- self .assertEmpty (Cls ((), {"A" : T3x1 }, {"A" : {0 : "batch" }}).invalid_paths ())
465- self .assertEmpty (Cls ((), {"A" : T3x4 }, {"A" : {0 : "batch" }}).invalid_paths ())
462+ self .assertEmpty (Cls ((T3x4 ,), {}, ({0 : "batch" },)).invalid_dimensions_for_export ())
463+ self .assertEmpty (Cls ((T3x1 ,), {}, ({0 : "batch" },)).invalid_dimensions_for_export ())
464+ self .assertEmpty (
465+ Cls ((), {"A" : T3x1 }, {"A" : {0 : "batch" }}).invalid_dimensions_for_export ()
466+ )
467+ self .assertEmpty (
468+ Cls ((), {"A" : T3x4 }, {"A" : {0 : "batch" }}).invalid_dimensions_for_export ()
469+ )
466470
467471 T1x4 = torch .rand ((1 , 4 ))
468472 T1x1 = torch .rand ((1 , 1 ))
469473 Cls = CoupleInputsDynamicShapes
470- self .assertEqual ([(0 , "[0]" )], Cls ((T1x4 ,), {}, ({0 : "batch" },)).invalid_paths ())
471- self .assertEqual ([(0 , "[0]" )], Cls ((T1x1 ,), {}, ({0 : "batch" },)).invalid_paths ())
472474 self .assertEqual (
473- [("A" , "[0]" )], Cls ((), {"A" : T1x1 }, {"A" : {0 : "batch" }}).invalid_paths ()
475+ ({0 : "d=[1]" },), Cls ((T1x4 ,), {}, ({0 : "batch" },)).invalid_dimensions_for_export ()
476+ )
477+ self .assertEqual (
478+ ({0 : "d=[1]" },), Cls ((T1x1 ,), {}, ({0 : "batch" },)).invalid_dimensions_for_export ()
479+ )
480+ self .assertEqual (
481+ {"A" : {0 : "d=[1]" }},
482+ Cls ((), {"A" : T1x1 }, {"A" : {0 : "batch" }}).invalid_dimensions_for_export (),
474483 )
475484 self .assertEqual (
476- [("A" , "[0]" )], Cls ((), {"A" : T1x4 }, {"A" : {0 : "batch" }}).invalid_paths ()
485+ {"A" : {0 : "d=[1]" }},
486+ Cls ((), {"A" : T1x4 }, {"A" : {0 : "batch" }}).invalid_dimensions_for_export (),
477487 )
478488
479489 def test_couple_input_ds_1 (self ):
@@ -483,8 +493,13 @@ def test_couple_input_ds_1(self):
483493 ds_batch_seq = {0 : "batch" , 1 : "seq" }
484494 args = (T3x4 , T3x1 )
485495 Cls = CoupleInputsDynamicShapes
486- self .assertEqual ([], Cls (args , {}, (ds_batch , ds_batch )).invalid_paths ())
487- self .assertEqual ([(1 , "[1]" )], Cls (args , {}, (ds_batch , ds_batch_seq )).invalid_paths ())
496+ self .assertEqual (
497+ None , Cls (args , {}, (ds_batch , ds_batch )).invalid_dimensions_for_export ()
498+ )
499+ self .assertEqual (
500+ (None , {1 : "d=[1]" }),
501+ Cls (args , {}, (ds_batch , ds_batch_seq )).invalid_dimensions_for_export (),
502+ )
488503
489504 def test_couple_input_ds_2 (self ):
490505 T3x1 = torch .rand ((3 , 1 ))
@@ -493,9 +508,15 @@ def test_couple_input_ds_2(self):
493508 ds_batch_seq = {0 : "batch" , 1 : "seq" }
494509 kwargs = {"A" : T3x4 , "B" : T3x1 }
495510 Cls = CoupleInputsDynamicShapes
496- self .assertEqual ([], Cls ((), kwargs , {"A" : ds_batch , "B" : ds_batch }).invalid_paths ())
497511 self .assertEqual (
498- [("B" , "[1]" )], Cls ((), kwargs , {"A" : ds_batch , "B" : ds_batch_seq }).invalid_paths ()
512+ None ,
513+ Cls ((), kwargs , {"A" : ds_batch , "B" : ds_batch }).invalid_dimensions_for_export (),
514+ )
515+ self .assertEqual (
516+ {"B" : {1 : "d=[1]" }},
517+ Cls (
518+ (), kwargs , {"A" : ds_batch , "B" : ds_batch_seq }
519+ ).invalid_dimensions_for_export (),
499520 )
500521
501522 def test_couple_input_ds_3 (self ):
@@ -506,11 +527,16 @@ def test_couple_input_ds_3(self):
506527 kwargs = {"A" : T3x4 , "B" : (T3x1 , T3x1 )}
507528 Cls = CoupleInputsDynamicShapes
508529 self .assertEqual (
509- [], Cls ((), kwargs , {"A" : ds_batch , "B" : (ds_batch , ds_batch )}).invalid_paths ()
530+ None ,
531+ Cls (
532+ (), kwargs , {"A" : ds_batch , "B" : (ds_batch , ds_batch )}
533+ ).invalid_dimensions_for_export (),
510534 )
511535 self .assertEqual (
512- [("B" , 1 , "[1]" )],
513- Cls ((), kwargs , {"A" : ds_batch , "B" : (ds_batch , ds_batch_seq )}).invalid_paths (),
536+ {"B" : (None , {1 : "d=[1]" })},
537+ Cls (
538+ (), kwargs , {"A" : ds_batch , "B" : (ds_batch , ds_batch_seq )}
539+ ).invalid_dimensions_for_export (),
514540 )
515541
516542 def test_couple_input_ds_cache (self ):
@@ -532,23 +558,23 @@ def test_couple_input_ds_cache(self):
532558 Cls = CoupleInputsDynamicShapes
533559 with bypass_export_some_errors (patch_transformers = True ):
534560 self .assertEqual (
535- [] ,
561+ None ,
536562 Cls (
537563 (),
538564 kwargs ,
539565 {"A" : ds_batch , "B" : (ds_batch , [ds_batch , ds_batch , ds_batch , ds_batch ])},
540- ).invalid_paths (),
566+ ).invalid_dimensions_for_export (),
541567 )
542568 self .assertEqual (
543- [( "B" , 1 , "DynamicCache" , 1 , "[2]" ), ( "B" , 1 , "DynamicCache" , 3 , "[2]" )] ,
569+ { "B" : ( None , [ None , { 2 : "d=[1]" }, None , { 2 : "d=[1]" }])} ,
544570 Cls (
545571 (),
546572 kwargs ,
547573 {
548574 "A" : ds_batch ,
549575 "B" : (ds_batch , [ds_batch , ds_batch_seq , ds_batch , ds_batch_seq ]),
550576 },
551- ).invalid_paths (),
577+ ).invalid_dimensions_for_export (),
552578 )
553579
554580 def test_couple_input_ds_args_kwargs_0 (self ):
@@ -561,17 +587,22 @@ def test_couple_input_ds_args_kwargs_0(self):
561587 kwargs = {"A" : T3x4 , "B" : (T3x1 , T3x1 )}
562588 Cls = CoupleInputsDynamicShapes
563589 self .assertEqual (
564- [], Cls (args , kwargs , {"A" : ds_batch , "B" : (ds_batch , ds_batch )}).invalid_paths ()
590+ None ,
591+ Cls (
592+ args , kwargs , {"A" : ds_batch , "B" : (ds_batch , ds_batch )}
593+ ).invalid_dimensions_for_export (),
565594 )
566595 self .assertEqual (
567- [] ,
596+ None ,
568597 Cls (
569598 args , kwargs , {"A" : ds_batch , "B" : (ds_batch , ds_batch )}, args_names = ["X" ]
570- ).invalid_paths (),
599+ ).invalid_dimensions_for_export (),
571600 )
572601 self .assertEqual (
573- [("B" , 1 , "[1]" )],
574- Cls (args , kwargs , {"A" : ds_batch , "B" : (ds_batch , ds_batch_seq )}).invalid_paths (),
602+ {"B" : (None , {1 : "d=[1]" })},
603+ Cls (
604+ args , kwargs , {"A" : ds_batch , "B" : (ds_batch , ds_batch_seq )}
605+ ).invalid_dimensions_for_export (),
575606 )
576607
577608 def test_couple_input_ds_args_kwargs_1 (self ):
@@ -584,23 +615,67 @@ def test_couple_input_ds_args_kwargs_1(self):
584615 kwargs = {"A" : T3x4 , "B" : (T3x1 , T3x1 )}
585616 Cls = CoupleInputsDynamicShapes
586617 self .assertEqual (
587- [],
618+ None ,
619+ Cls (
620+ args ,
621+ kwargs ,
622+ {"X" : ds_batch , "A" : ds_batch , "B" : (ds_batch , ds_batch )},
623+ args_names = ["X" ],
624+ ).invalid_dimensions_for_export (),
625+ )
626+ self .assertEqual (
627+ {"X" : {1 : "d=[1]" }, "B" : (None , {1 : "d=[1]" })},
628+ Cls (
629+ args ,
630+ kwargs ,
631+ {"X" : ds_batch_seq , "A" : ds_batch , "B" : (ds_batch , ds_batch_seq )},
632+ args_names = ["X" ],
633+ ).invalid_dimensions_for_export (),
634+ )
635+
636+ def test_couple_input_ds_replace_string (self ):
637+ T3x1 = torch .rand ((3 , 1 ))
638+ T3x4 = torch .rand ((3 , 4 ))
639+ T5x1 = torch .rand ((5 , 1 ))
640+ ds_batch = {0 : "batch" }
641+ ds_batch_seq = {0 : "batch" , 1 : "seq" }
642+ args = (T5x1 ,)
643+ kwargs = {"A" : T3x4 , "B" : (T3x1 , T3x1 )}
644+ Cls = CoupleInputsDynamicShapes
645+ self .assertEqual (
646+ {"X" : {0 : "DYN" }, "A" : {0 : "DYN" }, "B" : ({0 : "DYN" }, {0 : "DYN" })},
588647 Cls (
589648 args ,
590649 kwargs ,
591650 {"X" : ds_batch , "A" : ds_batch , "B" : (ds_batch , ds_batch )},
592651 args_names = ["X" ],
593- ).invalid_paths ( ),
652+ ).replace_string_by ( value = "DYN" ),
594653 )
595654 self .assertEqual (
596- [("X" , "[1]" ), ("B" , 1 , "[1]" )],
655+ {
656+ "A" : {0 : "DYN" },
657+ "B" : ({0 : "DYN" }, {0 : "DYN" , 1 : "DYN" }),
658+ "X" : {0 : "DYN" , 1 : "DYN" },
659+ },
597660 Cls (
598661 args ,
599662 kwargs ,
600663 {"X" : ds_batch_seq , "A" : ds_batch , "B" : (ds_batch , ds_batch_seq )},
601664 args_names = ["X" ],
602- ).invalid_paths (),
665+ ).replace_string_by (value = "DYN" ),
666+ )
667+
668+ def test_couple_input_ds_change_dynamic_dimensions (self ):
669+ T257 = torch .arange (2 * 5 * 7 ).reshape ((2 , 5 , 7 ))
670+ T29 = torch .arange (2 * 9 ).reshape ((2 , 9 ))
671+ inst = CoupleInputsDynamicShapes (
672+ (),
673+ {"A" : T257 , "B" : T29 },
674+ {"A" : {0 : "batch" , 2 : "last" }, "B" : {0 : "batch" , 1 : "seq" }},
603675 )
676+ new_input = inst .change_dynamic_dimensions ()
677+ self .assertEqual ((3 , 5 , 8 ), new_input ["A" ].shape )
678+ self .assertEqual ((3 , 10 ), new_input ["B" ].shape )
604679
605680
606681if __name__ == "__main__" :
0 commit comments