
Commit 39c7bd0

align the interface
1 parent e8dd142 commit 39c7bd0


4 files changed, +25 -26 lines changed


python/sgl_kernel/flash_attn.py

Lines changed: 12 additions & 6 deletions
@@ -3,10 +3,10 @@
 import torch
 import torch.nn as nn
 
-try:
-    from sgl_kernel import flash_ops
-except:
-    raise ImportError("Can not import sgl_kernel. Please check your installation.")
+# try:
+#     from sgl_kernel import flash_ops
+# except:
+#     raise ImportError("Can not import sgl_kernel. Please check your installation.")
 
 
 def is_fa3_supported(device=None) -> bool:
@@ -18,10 +18,16 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
    # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    if torch.cuda.is_available():
+        return (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+        ) and (torch.version.cuda >= "12.3")
+    elif torch.xpu.is_available():
+        device_name = torch.xpu.get_device_properties(0).name
+        return "B580" in device_name or "e211" in device_name
+    else:
+        return False
 
 
 def maybe_contiguous(x):
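
With this change, is_fa3_supported no longer assumes a CUDA build: it probes CUDA first, then Intel XPU (gating on the B580 / e211 device names), and otherwise returns False. A minimal usage sketch, assuming a pytest-style test module; the skip message and module layout are illustrative, not taken from the repository:

import pytest
import torch

from sgl_kernel.flash_attn import is_fa3_supported

# Skip the whole module on hardware the FA3 kernel was not built for,
# e.g. a CPU-only box or an unsupported XPU part.
if not is_fa3_supported():
    pytest.skip("FA3 kernels are unsupported on this device", allow_module_level=True)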

src/sycl/chunked_prefill.cpp

Lines changed: 8 additions & 17 deletions
@@ -133,7 +133,6 @@ struct Flash_fwd_params {
 
   // Local window size
   int window_size_left, window_size_right;
-  int attention_chunk;
 
   // Pointer to the RNG seed (idx 0) and offset (idx 1).
   uint64_t* rng_state;
@@ -541,14 +540,13 @@ std::vector<at::Tensor> mha_fwd(
     std::optional<const at::Tensor>& rotary_cos_,     // seqlen_ro x (rotary_dim / 2)
     std::optional<const at::Tensor>& rotary_sin_,     // seqlen_ro x (rotary_dim / 2)
     std::optional<const at::Tensor>& seqlens_rotary_, // b
-    // std::optional<at::Tensor> &q_descale_, // (b, h_k), not (b, h)
-    // std::optional<at::Tensor> &k_descale_, // (b, h_k)
-    // std::optional<at::Tensor> &v_descale_, // (b, h_k)
-    std::optional<double> softmax_scale_,
+    std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
+    std::optional<at::Tensor>& k_descale_, // (b, h_k)
+    std::optional<at::Tensor>& v_descale_, // (b, h_k)
+    const float softmax_scale_,
     bool is_causal,
     int window_size_left,
     int window_size_right,
-    int attention_chunk,
     float const softcap,
     bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
     std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
@@ -619,10 +617,8 @@ std::vector<at::Tensor> mha_fwd(
   int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
   int const num_heads_k = k.size(-2);
   int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
-  double softmax_scale = 1.0 / sqrt(double(head_size));
-  if (softmax_scale_.has_value()) {
-    softmax_scale = softmax_scale_.value();
-  }
+  float softmax_scale = softmax_scale_;
+
   if (!kv_batch_idx_.has_value()) {
     TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
   }
@@ -791,8 +787,8 @@ std::vector<at::Tensor> mha_fwd(
 
   // Causal is the special case where window_size_right == 0 and window_size_left < 0.
   // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
-  params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
-  params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
+  params.is_causal = window_size_left < 0 && window_size_right == 0;
+  params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
 
   // TODO: check this
   if (window_size_left < 0) {
@@ -801,13 +797,8 @@ std::vector<at::Tensor> mha_fwd(
   if (window_size_right < 0) {
     window_size_right = seqlen_q - 1;
   }
-  if (attention_chunk > 0) {
-    window_size_left = std::min(window_size_left, attention_chunk - 1);
-    window_size_right = std::min(window_size_right, attention_chunk - 1);
-  }
   params.window_size_left = window_size_left;
   params.window_size_right = window_size_right;
-  params.attention_chunk = attention_chunk;
 
   params.total_q = total_q;
   params.total_k = total_k;
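
Because mha_fwd now takes softmax_scale_ as a plain const float and no longer falls back to 1/sqrt(head_size) when the value is absent, the default has to be computed before crossing into C++. A hedged sketch of that defaulting on the Python side, assuming a wrapper that forwards softmax_scale to the op; the helper name is illustrative:

import math
from typing import Optional

import torch


def resolve_softmax_scale(q: torch.Tensor, softmax_scale: Optional[float]) -> float:
    # The SYCL kernel no longer derives 1/sqrt(head_dim) itself, so a
    # missing scale has to be filled in before calling into the op.
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    return float(softmax_scale)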

src/torch_extension_sycl.cc

Lines changed: 1 addition & 2 deletions
@@ -16,10 +16,9 @@ limitations under the License.
 #include <torch/all.h>
 #include <torch/library.h>
 
-#include "sgl_kernel_torch_shim.h"
-
 #include "sgl_flash_kernel_ops.h"
 #include "sgl_kernel_ops.h"
+#include "sgl_kernel_torch_shim.h"
 
 TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   /*

tests/test_flash_attention.py

Lines changed: 4 additions & 1 deletion
@@ -17,7 +17,10 @@
 
 def is_hopper():
     # Only Hopper supports different V headdim
-    return torch.cuda.get_device_properties(0).major >= 9
+    if torch.cuda.is_available():
+        return torch.cuda.get_device_properties(0).major >= 9
+    else:
+        return False
 
 
 def is_fa3_supported(device=None) -> bool:
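
The guarded is_hopper() now simply reports False on machines without CUDA instead of raising. A short sketch of how such a helper is typically consumed when parametrizing tests; the head-dim values and test name below are illustrative, not from the test file:

import pytest
import torch


def is_hopper():
    # Only Hopper supports a different V head dim.
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(0).major >= 9
    else:
        return False


# Exercise a distinct V head dim only where the hardware allows it.
@pytest.mark.parametrize("d_v", [64, 128] if is_hopper() else [128])
def test_v_headdim_choice(d_v):
    assert d_v in (64, 128)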
