 import torch.nn.functional as F
 from einops import rearrange, repeat
 
+import utils
+
+device = utils.get_device()
+
 apply_rotary_emb = None
 
 
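The added lines replace the per-test hard-coded `device = "cuda"` (removed in the hunks below) with a module-level `device = utils.get_device()`. The `utils` module itself is not part of this diff, so the following is only a minimal sketch of what such a helper could look like; the function name matches the call above, but the fallback order and CPU default are assumptions.

```python
# Hypothetical sketch of utils.get_device(); the real module is not shown in
# this diff, so the fallback order and CPU default are assumptions.
import torch


def get_device(device_id: int = 0) -> str:
    """Prefer CUDA, then Intel XPU, then CPU."""
    if torch.cuda.is_available():
        return f"cuda:{device_id}"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return f"xpu:{device_id}"
    return "cpu"
```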
@@ -25,11 +29,14 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    if torch.cuda.is_available():
+        return (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
-
+        ) and (torch.version.cuda >= "12.3")
+    elif torch.xpu.is_available():
+        device_name = torch.xpu.get_device_properties(0).name
+        return "B580" in device_name or "e211" in device_name
 
 DISABLE_BACKWARD = True
 # For CI test, we close them to True.
@@ -551,7 +558,6 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if rotary_fraction == 0.0 and has_rotary_seqlens:
         pytest.skip()
-    device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 5
@@ -1077,7 +1083,6 @@ def test_flash_attn_varlen_output(
 ):
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
-    device = "cuda"
     # set seed
     torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
     # batch_size = 40
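Both `test_flash_attn_kvcache` and `test_flash_attn_varlen_output` drop their local `device = "cuda"` assignment, so tensor construction inside the tests presumably resolves against the module-level `device` set from `utils.get_device()` at import time, i.e. something of the form below (shapes and dtype are illustrative, not the tests' actual parametrization):

```python
# Illustrative only: shapes/dtype come from the test's parametrization; the
# point is that `device` is now the module-level utils.get_device() result.
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.float16)
```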