Skip to content

Commit ed9931e

Browse files
joydddd authored and pytorchmergebot committed
Add tests for non divisible inputs for flex decoding (pytorch#143214)
Pull Request resolved: pytorch#143214 Approved by: https://github.com/drisspg
1 parent 0e8013f commit ed9931e

File tree

1 file changed

+57
-2
lines changed

1 file changed

+57
-2
lines changed

test/inductor/test_flex_decoding.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def _check_out(
281281

282282
def run_test(
283283
self,
284-
score_mod: Optional[Callable],
284+
score_mod: Optional[Callable] = None,
285285
dtype: torch.dtype = torch.float16,
286286
Q_B: int = B,
287287
Q_H: int = Hq,
@@ -348,7 +348,7 @@ def run_test_with_call(
348348
if not golden_call:
349349
golden_call = sdpa_call
350350
q = torch.randn(
351-
(Q_B, KV_H, Q_S * (Q_H // KV_H), Q_D),
351+
(Q_B, KV_H, Q_S, Q_D),
352352
dtype=dtype,
353353
device="cuda",
354354
requires_grad=False,
@@ -850,6 +850,61 @@ def seq_mask_mod(score, b, h, q, kv):
850850
self.run_test(seq_mask_mod, dtype)
851851
self.run_test_with_paged_attention(seq_mask_mod, dtype)
852852

853+
@supported_platform
def test_non_divisible_offset_mask(self):
    """Flex decoding with a KV length that is not block-size aligned."""
    # Deliberately pick a KV sequence length that is not a multiple of S.
    kv_len = S - 3
    offset = torch.tensor(S // 2 - 3, device="cuda", dtype=torch.int32)

    def shifted_causal(b, h, q, kv):
        # Admit only keys at or beyond the captured offset past q.
        return kv >= q + offset

    self.run_test(
        KV_S=kv_len,
        block_mask=create_block_mask(shifted_causal, B, 1, 1, kv_len),
    )
863+
864+
@supported_platform
def test_non_divisible_offset_mask_with_captured_buffer(self):
    """Non-divisible KV length combined with a score_mod that captures a buffer."""
    kv_len = S - 3
    # Per-key bias buffer captured by the score_mod closure.
    kv_bias = torch.randn(kv_len, device="cuda", dtype=torch.bfloat16)
    offset = torch.tensor(S // 2 - 3, device="cuda", dtype=torch.int32)

    def biased_score(score, b, h, q, kv):
        return score + kv_bias[kv]

    def shifted_causal(b, h, q, kv):
        return kv >= q + offset

    mask = create_block_mask(shifted_causal, B, 1, 1, kv_len)
    self.run_test(KV_S=kv_len, block_mask=mask, score_mod=biased_score)
878+
879+
@supported_platform
def test_non_divisible_multi_token_offset_mask(self):
    """Non-divisible KV length with a multi-token (Q_S > 1) query."""
    kv_len = S - 3
    q_len = 3  # more than one query token per decode step
    offset = torch.tensor(S // 2 - 1, device="cuda", dtype=torch.int32)

    def shifted_causal(b, h, q, kv):
        return kv >= q + offset

    mask = create_block_mask(shifted_causal, B, 1, q_len, kv_len)
    self.run_test(Q_S=q_len, KV_S=kv_len, block_mask=mask)
890+
891+
@supported_platform
def test_non_divisible_multi_token_offset_mask_with_captured_buffer(self):
    """Non-divisible KV length, multi-token query, and captured bias buffers."""
    kv_len = S - 3
    q_len = 3
    # One bias buffer indexed by the key position, one by the query position;
    # both are captured by the score_mod closure.
    kv_bias = torch.randn(kv_len, device="cuda", dtype=torch.bfloat16)
    q_bias = torch.randn(q_len, device="cuda", dtype=torch.bfloat16)
    offset = torch.tensor(S // 2 - 3, device="cuda", dtype=torch.int32)

    def biased_score(score, b, h, q, kv):
        return score + kv_bias[kv] + q_bias[q]

    def shifted_causal(b, h, q, kv):
        return kv >= q + offset

    mask = create_block_mask(shifted_causal, B, 1, q_len, kv_len)
    self.run_test(
        Q_S=q_len, KV_S=kv_len, block_mask=mask, score_mod=biased_score
    )
907+
853908
@supported_platform
854909
@common_utils.parametrize("dtype", test_dtypes_fast)
855910
def test_load_from_bias_seq_only(self, dtype):

0 commit comments

Comments
 (0)