@@ -618,6 +618,94 @@ def loop_body_1(z, iv, x, y):
618618 )
619619 """
620620
621+ def test_broadcast_in_dim_1 (self ):
622+ class BadBroadcast (torch .nn .Module ):
623+ def forward (self , x ):
624+ shape = [x .shape [0 ], x .shape [1 ], 1 ]
625+ dims = [0 , 1 ]
626+ return torch .ops .prims .broadcast_in_dim .default (x , shape , dims )
627+
628+ x = torch .rand ((3 , 4 ), dtype = torch .float32 )
629+ expected = BadBroadcast ()(x )
630+ DYN = torch .export .Dim .DYNAMIC
631+ ds = ({0 : DYN , 1 : DYN },)
632+ for strict in [False , True ]:
633+ with self .subTest (strict = strict ):
634+ ep = torch .export .export (
635+ BadBroadcast (), (x ,), dynamic_shapes = ds , strict = strict
636+ )
637+ got = ep .module ()(x )
638+ self .assertEqualArray (expected , got )
639+ with torch_export_patches (patch_torch = True ):
640+ ep = torch .export .export (
641+ BadBroadcast (), (x ,), dynamic_shapes = ds , strict = strict
642+ )
643+ got = ep .module ()(x )
644+ self .assertEqualArray (expected , got )
645+
646+ def test_broadcast_in_dim_2 (self ):
647+ class BadBroadcast (torch .nn .Module ):
648+ def forward (self , x ):
649+ shape = [x .shape [0 ], 3 , 1 ]
650+ dims = [0 , 1 ]
651+ return torch .ops .prims .broadcast_in_dim .default (x , shape , dims )
652+
653+ x = torch .rand ((3 , 1 ), dtype = torch .float32 )
654+ expected = BadBroadcast ()(x )
655+ print (expected .shape , expected )
656+ DYN = torch .export .Dim .DYNAMIC
657+ ds = ({0 : DYN , 1 : DYN },)
658+ for strict in [False , True ]:
659+ with self .subTest (strict = strict ):
660+ with torch_export_patches (patch_torch = True ):
661+ ep = torch .export .export (
662+ BadBroadcast (), (x ,), dynamic_shapes = ds , strict = strict
663+ )
664+ got = ep .module ()(x )
665+ self .assertEqualArray (expected , got )
666+
667+ def test_broadcast_in_dim_3 (self ):
668+ class BadBroadcast (torch .nn .Module ):
669+ def forward (self , x ):
670+ shape = [3 , x .shape [1 ], 1 ]
671+ dims = [0 , 1 ]
672+ return torch .ops .prims .broadcast_in_dim .default (x , shape , dims )
673+
674+ x = torch .rand ((1 , 3 ), dtype = torch .float32 )
675+ expected = BadBroadcast ()(x )
676+ print (expected .shape , expected )
677+ DYN = torch .export .Dim .DYNAMIC
678+ ds = ({0 : DYN , 1 : DYN },)
679+ for strict in [False , True ]:
680+ with self .subTest (strict = strict ):
681+ with torch_export_patches (patch_torch = True ):
682+ ep = torch .export .export (
683+ BadBroadcast (), (x ,), dynamic_shapes = ds , strict = strict
684+ )
685+ got = ep .module ()(x )
686+ self .assertEqualArray (expected , got )
687+
688+ def test_broadcast_in_dim_5 (self ):
689+ class BadBroadcast (torch .nn .Module ):
690+ def forward (self , x ):
691+ shape = [1 , x .shape [1 ], 1 ]
692+ dims = [0 , 1 ]
693+ return torch .ops .prims .broadcast_in_dim .default (x , shape , dims )
694+
695+ x = torch .rand ((1 , 3 ), dtype = torch .float32 )
696+ expected = BadBroadcast ()(x )
697+ print (expected .shape , expected )
698+ DYN = torch .export .Dim .DYNAMIC
699+ ds = ({0 : DYN , 1 : DYN },)
700+ for strict in [False , True ]:
701+ with self .subTest (strict = strict ):
702+ with torch_export_patches (patch_torch = True ):
703+ ep = torch .export .export (
704+ BadBroadcast (), (x ,), dynamic_shapes = ds , strict = strict
705+ )
706+ got = ep .module ()(x )
707+ self .assertEqualArray (expected , got )
708+
621709
622710if __name__ == "__main__" :
623711 unittest .main (verbosity = 2 )
0 commit comments