@@ -31,16 +31,30 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
     def vector_mask_function(
         *args, mask_function=mask_function, dimensions=dimensions, indices=indices
     ):
-        assert len(args) == len(dimensions) == len(udimensions) + 1, (
+        assert len(args) == len(dimensions) == len(udimensions), (
             f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
-            f"and udimensions={udimensions} "
+            f"and udimensions={udimensions}. "
         )
+        assert len(indices) == len(args), (
+            f"Mismatch between args={string_type(args)} and indices={indices}, "
+            f"they should have the same length."
+        )
+        for a in args:
+            assert (
+                a.ndim == 1
+            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+            torch._check(a.shape[0] > 0)
+
         # new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
         new_args = [
-            a.reshape((-1,)).unsqueeze(shape[0]).unsqueeze(shape[1]).unsqueeze(shape[2])
-            for a, shape in zip(args, udimensions)
+            a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+            for a, dims in zip(args, udimensions)
         ]
         max_shape = tuple(args[i].shape[0] for i in indices)
+        if is_torchdynamo_exporting():
+            for a in args:
+                # The exporter should export with a dimension > 1 to make sure it is dynamic.
+                torch._check(a.shape[0] > 1)
         expanded_args = [a.expand(max_shape) for a in new_args]
         return mask_function(*expanded_args)

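For context, the unsqueeze/expand lines replace the earlier reshape-based broadcasting: each 1-D index tensor gains three singleton dimensions so that every argument occupies its own axis of the (batch, head, q, kv) grid, and expand then materializes the common shape before the scalar mask function is applied. The following is a minimal standalone sketch of that pattern; the causal_mask function, the sample sizes, and the unsqueeze offsets are illustrative assumptions, not the values computed by the patched helper.

import torch

def causal_mask(batch_idx, head_idx, q_idx, kv_idx):
    # Broadcastable mask: a query may attend to keys at or before its own position.
    return kv_idx <= q_idx

# One 1-D index tensor per axis of the (batch, head, q, kv) grid (sizes are arbitrary).
args = [torch.arange(2), torch.arange(4), torch.arange(3), torch.arange(5)]
# Dimensions to unsqueeze so each tensor keeps its values on a different axis;
# these offsets are an assumption standing in for the real `udimensions`.
udimensions = [(-1, -2, -3), (-1, -2, -4), (-1, -3, -4), (-2, -3, -4)]

new_args = [
    a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
    for a, dims in zip(args, udimensions)
]
max_shape = tuple(a.shape[0] for a in args)        # (2, 4, 3, 5)
expanded_args = [a.expand(max_shape) for a in new_args]
mask = causal_mask(*expanded_args)                 # bool tensor of shape (2, 4, 3, 5)
print(mask.shape)

The sketch uses all arguments to build max_shape for simplicity, whereas the patched code selects the sizes through indices.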