@@ -68,6 +68,7 @@ def test_copy_blocks(
         pytest.skip()
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks
@@ -152,6 +153,7 @@ def test_reshape_and_cache(
         pytest.skip()
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
@@ -272,6 +274,7 @@ def test_reshape_and_cache_flash(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
     assert implementation in ["cuda", "triton"]
     if implementation == "triton" and kv_cache_layout == "HND":
         pytest.skip("Triton implementation only supports NHD layout.")
@@ -593,6 +596,7 @@ def test_concat_and_cache_mla(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)

     total_slots = num_blocks * block_size
     slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -662,11 +666,14 @@ def test_concat_and_cache_ds_mla(
     seed: int,
     device: str,
 ) -> None:
+    if current_platform.is_rocm():
+        pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
     if dtype.itemsize != 2:
         pytest.skip("ds_mla only supports 16-bit input")
     kv_cache_dtype = "fp8_ds_mla"
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)

     total_slots = num_blocks * block_size
     slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -779,6 +786,7 @@ def test_copy_blocks_mla(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)

     entry_size = kv_lora_rank + qk_rope_head_dim

@@ -843,6 +851,7 @@ def test_swap_blocks_mla(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)

     entry_size = kv_lora_rank + qk_rope_head_dim

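For context on the repeated pattern above: `torch.set_default_device(device)` only redirects tensor factory calls (e.g. `torch.empty`), while the *current* CUDA device, which `torch.cuda.current_device()` and any code that consults it still see, stays at device 0 until `torch.cuda.set_device(device)` is also called. A minimal sketch of that difference, assuming a CUDA host with at least two GPUs (not part of the patch):

```python
import torch

# Illustrative only: requires a CUDA host with >= 2 GPUs.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = "cuda:1"

    torch.set_default_device(device)
    x = torch.empty(4)
    print(x.device)                     # cuda:1 -- factory calls follow the default device
    print(torch.cuda.current_device())  # 0 -- the current CUDA device is unchanged

    torch.cuda.set_device(device)
    print(torch.cuda.current_device())  # 1 -- current device now matches the test device
```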