
Commit 711241c

Authored by Randall Smith (rasmith)
[CI/Build] Fix illegal memory access and unsupported test in kernels/attention/test_cache.py (#29118)
Signed-off-by: Randall Smith <[email protected]>
Co-authored-by: Randall Smith <[email protected]>
1 parent d7219bc commit 711241c

File tree

1 file changed: 9 additions, 0 deletions


tests/kernels/attention/test_cache.py

Lines changed: 9 additions & 0 deletions
@@ -68,6 +68,7 @@ def test_copy_blocks(
         pytest.skip()
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks
@@ -152,6 +153,7 @@ def test_reshape_and_cache(
         pytest.skip()
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
@@ -272,6 +274,7 @@ def test_reshape_and_cache_flash(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
     assert implementation in ["cuda", "triton"]
     if implementation == "triton" and kv_cache_layout == "HND":
         pytest.skip("Triton implementation only supports NHD layout.")
@@ -593,6 +596,7 @@ def test_concat_and_cache_mla(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)

     total_slots = num_blocks * block_size
     slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -662,11 +666,14 @@ def test_concat_and_cache_ds_mla(
     seed: int,
     device: str,
 ) -> None:
+    if current_platform.is_rocm():
+        pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
     if dtype.itemsize != 2:
         pytest.skip("ds_mla only supports 16-bit input")
     kv_cache_dtype = "fp8_ds_mla"
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)

     total_slots = num_blocks * block_size
     slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -779,6 +786,7 @@ def test_copy_blocks_mla(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)

     entry_size = kv_lora_rank + qk_rope_head_dim

@@ -843,6 +851,7 @@ def test_swap_blocks_mla(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)

     entry_size = kv_lora_rank + qk_rope_head_dim

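The recurring change pairs torch.set_default_device(device) with torch.cuda.set_device(device). Below is a minimal, self-contained sketch of that pattern, assuming a multi-GPU CUDA host; the helper name setup_test_device and the "cuda:1" device are illustrative and do not appear in the patch. The reading that the illegal memory access comes from kernels running on a different GPU than the one holding the test tensors is inferred from the commit title, not stated in the diff itself.

# Illustrative sketch only, not code from the patch.
import torch


def setup_test_device(device: str) -> None:
    # Newly created tensors default to `device` ...
    torch.set_default_device(device)
    # ... and the active CUDA device is pinned to match, so kernels that
    # launch on torch.cuda.current_device() use the same GPU as the tensors
    # they operate on.
    torch.cuda.set_device(device)


if __name__ == "__main__":
    if torch.cuda.device_count() > 1:
        setup_test_device("cuda:1")
        x = torch.empty(4, 4)
        print(x.device, torch.cuda.current_device())  # cuda:1 1

The second change in the patch is independent of device pinning: test_concat_and_cache_ds_mla is skipped on ROCm because the fp8_ds_mla cache dtype is not supported by concat_and_cache_mla there.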
0 commit comments
