 import torch
 import torch.nn as nn
 
-try:
-    from sgl_kernel import flash_ops
-except:
-    raise ImportError("Can not import sgl_kernel. Please check your installation.")
-
 
 def is_fa3_supported(device=None) -> bool:
     # Some notes on fa3 support, FYI:
@@ -18,10 +13,15 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
-        torch.cuda.get_device_capability(device)[0] == 9
-        or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    if torch.cuda.is_available():
+        return (
+            torch.cuda.get_device_capability(device)[0] == 9
+            or torch.cuda.get_device_capability(device)[0] == 8
+        ) and (torch.version.cuda >= "12.3")
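+    # XPU path (assumption): fp64 support is used here as a coarse capability
+    # proxy for the Intel GPU generations (BMG and later) that this backend targets.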
+    elif torch.xpu.is_available():
+        return torch.xpu.get_device_properties().has_fp64
+    else:
+        return False
 
 
 def maybe_contiguous(x):
@@ -171,21 +171,31 @@ def flash_attn_with_kvcache(
     rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
     rotary_seqlens = maybe_contiguous(rotary_seqlens)
 
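+    # No cu_seqlens_q given (dense layout): build uniform cumulative sequence
+    # lengths and flatten q so the kernel always sees a varlen-style layout.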
+    if cu_seqlens_q is None:  # !is_varlen_q
+        cu_seqlens_q = torch.arange(
+            0, q.size(0) + 1, dtype=torch.int, device=q.device
+        ) * q.size(1)
+        max_seqlen_q = q.size(1)
+        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
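+    # Derive varlen KV metadata from the per-sequence cache lengths:
+    # cu_seqlens_k is the prefix sum of cache_seqlens with a leading zero.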
+    if cache_seqlens is not None:
+        max_seqlen_k = cache_seqlens.max().item()
+        assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
+        cu_seqlens_k = torch.concat(
+            (
+                torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
+                torch.cumsum(cache_seqlens, 0),
+            )
+        ).to(torch.int32)
+
     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
         k_cache,
         v_cache,
-        k,
-        v,
         qv,
-        None,  # out
         cu_seqlens_q,
-        None,  # cu_seqlens_k
-        cu_seqlens_k_new,
-        None,  # seqused_q
-        cache_seqlens,
+        cu_seqlens_k,
         max_seqlen_q,
-        None,  # max_seqlen_k
+        max_seqlen_k,
         page_table,
         cache_batch_idx,
         cache_leftpad,
@@ -235,13 +245,26 @@ def flash_attn_varlen_func(
 ):
     if not is_fa3_supported():
         raise NotImplementedError(
-            "flash_attn at sgl-kernel is only supported on sm90 and above "
+            "flash_attn in sgl-kernel-xpu is only supported on BMG and later"
         )
 
     if softmax_scale is None:
         softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
             -0.5
         )
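+    # Dense inputs (no cu_seqlens_q): synthesize uniform cu_seqlens_q and
+    # flatten q into a varlen (total_tokens, num_heads, head_dim) layout.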
+    if cu_seqlens_q is None:  # !is_varlen_q
+        cu_seqlens_q = torch.arange(
+            0, q.size(0) + 1, dtype=torch.int, device=q.device
+        ) * q.size(1)
+        max_seqlen_q = q.size(1)
+        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
+    batch_size = cu_seqlens_q.numel() - 1
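+    # Trivial page table: one entry per sequence, with sequence i mapped to
+    # page index i, i.e. each sequence occupies exactly one page.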
+    page_table = (
+        torch.arange(0, batch_size, device=q.device)
+        .to(torch.int32)
+        .reshape([batch_size, 1])
+        .contiguous()
+    )
 
     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
@@ -250,15 +273,13 @@ def flash_attn_varlen_func(
         None,  # k_new
         None,  # v_new
         qv,  # qv
-        None,  # out
         cu_seqlens_q,
         cu_seqlens_k,
         None,  # cu_seqlens_k_new
-        seqused_q,
-        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
-        None,  # page_table,
+        page_table,  # page_table,
+        page_table,  # num_pages_per_seq
         None,  # kv_batch_idx
         None,  # leftpad_k
         None,  # rotary cos