@@ -281,7 +281,7 @@ def _check_out(
 
     def run_test(
         self,
-        score_mod: Optional[Callable],
+        score_mod: Optional[Callable] = None,
         dtype: torch.dtype = torch.float16,
         Q_B: int = B,
         Q_H: int = Hq,
@@ -348,7 +348,7 @@ def run_test_with_call(
         if not golden_call:
             golden_call = sdpa_call
         q = torch.randn(
-            (Q_B, KV_H, Q_S * (Q_H // KV_H), Q_D),
+            (Q_B, KV_H, Q_S, Q_D),
             dtype=dtype,
             device="cuda",
             requires_grad=False,
@@ -850,6 +850,61 @@ def seq_mask_mod(score, b, h, q, kv):
         self.run_test(seq_mask_mod, dtype)
         self.run_test_with_paged_attention(seq_mask_mod, dtype)
 
+    @supported_platform
+    def test_non_divisible_offset_mask(self):
+        KV_S = S - 3
+        offset_tensor = torch.tensor(S // 2 - 3, device="cuda", dtype=torch.int32)
+
+        def mask_mod(b, h, q, kv):
+            return kv >= q + offset_tensor
+
+        block_mask = create_block_mask(mask_mod, B, 1, 1, KV_S)
+        self.run_test(KV_S=KV_S, block_mask=block_mask)
+
+    @supported_platform
+    def test_non_divisible_offset_mask_with_captured_buffer(self):
+        KV_S = S - 3
+        offset_kv = torch.randn(KV_S, device="cuda", dtype=torch.bfloat16)
+        offset_tensor = torch.tensor(S // 2 - 3, device="cuda", dtype=torch.int32)
+
+        def score_mod(score, b, h, q, kv):
+            return score + offset_kv[kv]
+
+        def mask_mod(b, h, q, kv):
+            return kv >= q + offset_tensor
+
+        block_mask = create_block_mask(mask_mod, B, 1, 1, KV_S)
+        self.run_test(KV_S=KV_S, block_mask=block_mask, score_mod=score_mod)
+
+    @supported_platform
+    def test_non_divisible_multi_token_offset_mask(self):
+        KV_S = S - 3
+        Q_S = 3
+        offset_tensor = torch.tensor(S // 2 - 1, device="cuda", dtype=torch.int32)
+
+        def mask_mod(b, h, q, kv):
+            return kv >= q + offset_tensor
+
+        block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S)
+        self.run_test(Q_S=Q_S, KV_S=KV_S, block_mask=block_mask)
+
+    @supported_platform
+    def test_non_divisible_multi_token_offset_mask_with_captured_buffer(self):
+        KV_S = S - 3
+        Q_S = 3
+        offset_kv = torch.randn(KV_S, device="cuda", dtype=torch.bfloat16)
+        offset_q = torch.randn(Q_S, device="cuda", dtype=torch.bfloat16)
+        offset_tensor = torch.tensor(S // 2 - 3, device="cuda", dtype=torch.int32)
+
+        def score_mod(score, b, h, q, kv):
+            return score + offset_kv[kv] + offset_q[q]
+
+        def mask_mod(b, h, q, kv):
+            return kv >= q + offset_tensor
+
+        block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S)
+        self.run_test(Q_S=Q_S, KV_S=KV_S, block_mask=block_mask, score_mod=score_mod)
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_load_from_bias_seq_only(self, dtype):
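For context on what the new tests exercise, below is a minimal standalone sketch (not part of this diff) of driving a runtime-offset mask with a non-block-aligned KV length through the public FlexAttention entry points. The shapes, the offset value, and the variable names here are illustrative assumptions, not the test suite's B/Hq/S/D constants or its run_test harness.

    # Minimal sketch, assuming PyTorch with torch.nn.attention.flex_attention
    # available and a CUDA device; shapes below are illustrative only.
    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    B, H, Q_LEN, D = 2, 4, 3, 64
    KV_LEN = 512 - 3  # deliberately not a multiple of the default 128 block size
    offset = torch.tensor(KV_LEN // 2, device="cuda", dtype=torch.int32)

    def mask_mod(b, h, q, kv):
        # Keep only keys at or beyond a runtime offset from the query position.
        return kv >= q + offset

    block_mask = create_block_mask(mask_mod, B, 1, Q_LEN, KV_LEN, device="cuda")

    q = torch.randn(B, H, Q_LEN, D, device="cuda", dtype=torch.float16)
    k = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)
    v = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)

    out = flex_attention(q, k, v, block_mask=block_mask)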