@@ -982,16 +982,21 @@ def _greater_than_reduce(acc, x):
982982 elif guard_or_false (a .shape [original_idx ] != 1 ):
983983 new_strides .append (a .stride ()[original_idx ])
984984 else :
985+ # This check generates the following issue:
986+ # non-broadcasting semantics require s3 == Max(s10, s3), False,
987+ # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
988+ # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
989+ # original_idx=1
985990 torch ._check (
986991 a .shape [original_idx ] == shape [idx ],
987992 lambda idx = idx , original_idx = original_idx : (
988993 f"non-broadcasting semantics require "
989994 f"{ a .shape [original_idx ]} == { shape [idx ]} , "
990995 f"{ guard_or_false (a .shape [idx ] != 1 )} , "
991- f"guard_or_false(a.shape[idx] == 1)="
996+ f"guard_or_false(a.shape[idx]== 1)="
992997 f"{ guard_or_false (a .shape [idx ] == 1 )} , "
993- f"a.stride()={ a .stride ()} , idx={ idx } , "
994- f"original_idx={ original_idx } "
998+ f"a.stride()={ a .stride ()} , idx={ idx } , a.shape= { a . shape } , "
999+ f"shape= { shape } , original_idx={ original_idx } "
9951000 ),
9961001 )
9971002 new_strides .append (a .stride ()[original_idx ])
@@ -1006,3 +1011,77 @@ def _greater_than_reduce(acc, x):
10061011 new_strides .append (a .stride ()[original_idx ] * a .size ()[original_idx ])
10071012
10081013 return a .as_strided (shape , new_strides , a .storage_offset ())
1014+
1015+
1016+ def patched__broadcast_in_dim_meta_level_2 (
1017+ a : torch ._prims_common .TensorLikeType ,
1018+ shape : torch ._prims_common .ShapeType ,
1019+ broadcast_dimensions : Sequence [int ],
1020+ ):
1021+ """Patches ``torch._prims._broadcast_in_dim_meta``."""
1022+ from torch .fx .experimental .symbolic_shapes import (
1023+ guard_or_false ,
1024+ guard_or_true ,
1025+ sym_or ,
1026+ )
1027+
1028+ # Type checks
1029+ assert isinstance (a , torch ._prims_common .TensorLike )
1030+ assert isinstance (shape , Sequence )
1031+ assert isinstance (broadcast_dimensions , Sequence )
1032+
1033+ # every dimension must be accounted for
1034+ assert a .ndim == len (broadcast_dimensions )
1035+
1036+ # broadcast shape must have weakly more dimensions
1037+ assert len (shape ) >= a .ndim
1038+
1039+ # broadcast_dimensions must be an ascending sequence
1040+ # (no relative reordering of dims) of integers and
1041+ # each dimension must be within the new shape
1042+ def _greater_than_reduce (acc , x ):
1043+ assert isinstance (x , (int , torch .export .Dim )), f"unexpected type { type (x )} for x"
1044+ assert x > acc
1045+ assert x < len (shape )
1046+
1047+ return x
1048+
1049+ reduce (_greater_than_reduce , broadcast_dimensions , - 1 )
1050+
1051+ # shape must be broadcastable to
1052+ for idx , new_idx in enumerate (broadcast_dimensions ):
1053+ torch ._check (
1054+ sym_or (a .shape [idx ] == 1 , shape [new_idx ] == a .shape [idx ]),
1055+ lambda idx = idx , new_idx = new_idx : (
1056+ f"{ a .shape [idx ]} must be broadcastable to { shape [new_idx ]} "
1057+ ),
1058+ )
1059+
1060+ new_strides = []
1061+ original_idx = 0
1062+ for idx in range (len (shape )):
1063+ if idx in broadcast_dimensions :
1064+ # Assigns a stride of zero to dimensions
1065+ # which were actually broadcast
1066+ if guard_or_false (a .shape [original_idx ] == 1 ):
1067+ if guard_or_false (a .shape [original_idx ] == shape [idx ]):
1068+ new_strides .append (a .stride ()[original_idx ])
1069+ else :
1070+ new_strides .append (0 )
1071+ # PATCHED: disabled this check
1072+ elif guard_or_false (a .shape [original_idx ] != 1 ):
1073+ new_strides .append (a .stride ()[original_idx ])
1074+ else :
1075+ # PATCHED: torch._check was removed
1076+ new_strides .append (a .stride ()[original_idx ])
1077+ original_idx = original_idx + 1
1078+ else :
1079+ if guard_or_true (shape [idx ] != 1 ):
1080+ # consistent with previous use of guard_size_oblivious
1081+ new_strides .append (0 )
1082+ elif original_idx == a .ndim :
1083+ new_strides .append (1 )
1084+ else :
1085+ new_strides .append (a .stride ()[original_idx ] * a .size ()[original_idx ])
1086+
1087+ return a .as_strided (shape , new_strides , a .storage_offset ())
0 commit comments