2 changes: 1 addition & 1 deletion .github/workflows/pr-test-xpu.yml
@@ -50,7 +50,7 @@ jobs:
timeout-minutes: 20
run: |
docker exec -w /root/sglang ci_sglang_xpu \
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py"
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py test_flash_attention.py"

- name: Run E2E Bfloat16 tests
timeout-minutes: 20
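For local debugging, a rough Python equivalent of this CI step might look like the sketch below. The paths and container layout are assumptions taken from the workflow above, not a documented interface.

```python
# Hypothetical local reproduction of the CI pytest step; assumes the
# sgl-kernel-xpu checkout lives at the same path as in the CI container.
import subprocess

subprocess.run(
    [
        "python3", "-m", "pytest", "-v", "-s",
        "test_awq_dequant.py", "test_topk_softmax.py", "test_flash_attention.py",
    ],
    cwd="/root/sglang/sgl-kernel-xpu/tests",
    check=True,
)
```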
31 changes: 20 additions & 11 deletions CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.19.2)
project(sgl_kernel)

set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)

# Torch
find_package(Python3 COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
@@ -27,20 +27,29 @@ include(${SGL_OPS_XPU_ROOT}/cmake/BuildFlags.cmake)

include(FetchContent)

# # cutlass
# FetchContent_Declare(
# repo-cutlass-sycl
# GIT_REPOSITORY https://github.com/codeplaysoftware/cutlass-sycl.git
# GIT_TAG ef9797f4327886ad231bfe853099ca022060c293
# GIT_SHALLOW OFF
# )
# FetchContent_Populate(repo-cutlass-sycl)
# SYCL support in cutlass
add_compile_definitions(CUTLASS_ENABLE_SYCL)
add_compile_definitions(SYCL_INTEL_TARGET)
set(CUTLASS_ENABLE_SYCL ON CACHE BOOL "Enable SYCL in the cutlass" FORCE)
set(CUTLASS_ENABLE_BENCHMARKS OFF CACHE BOOL "Remove benchmark to avoid cmake version issue in google benchmark" FORCE)
set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutlass" FORCE)

# cutlass
FetchContent_Declare(
repo-cutlass-sycl
GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git
GIT_TAG ab1f4b8ddfd5748e4c00317710cdbcecda58de28
GIT_SHALLOW OFF
)
FetchContent_MakeAvailable(repo-cutlass-sycl)


include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/src
# ${repo-cutlass-sycl_SOURCE_DIR}/include
# ${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include
${repo-cutlass-sycl_SOURCE_DIR}/include
${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include
${repo-cutlass-sycl_SOURCE_DIR}/applications
)

add_subdirectory(${SGL_OPS_XPU_ROOT}/src)
7 changes: 5 additions & 2 deletions cmake/BuildFlags.cmake
@@ -26,7 +26,7 @@ endfunction()
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# # -- Host flags (SYCL_CXX_FLAGS)
list(APPEND SYCL_HOST_FLAGS -fPIC)
list(APPEND SYCL_HOST_FLAGS -std=c++17)
list(APPEND SYCL_HOST_FLAGS -std=c++20)
# SYCL headers warnings
list(APPEND SYCL_HOST_FLAGS -Wno-deprecated-declarations)
list(APPEND SYCL_HOST_FLAGS -Wno-deprecated)
@@ -71,6 +71,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-approx-func)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Wno-absolute-value)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -no-ftz)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-sycl-instrument-device-code)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Xspirv-translator)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier)

if(CMAKE_BUILD_TYPE MATCHES Debug)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -g -O0 -Rno-debug-disables-optimization)
@@ -116,7 +119,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(AOT_TARGETS STREQUAL "none")
set(TORCH_XPU_ARCH_LIST "" PARENT_SCOPE)
else()
set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen,spir64)
set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION})
set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION})
set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}")
25 changes: 10 additions & 15 deletions include/sgl_flash_kernel_ops.h
@@ -53,26 +53,21 @@ std::vector<at::Tensor> mha_fwd(
std::optional<const at::Tensor>&
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor>&
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor>&
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor>& leftpad_k_, // b
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& seqlens_rotary_, // b
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)

Contributor: Why are we changing the function signature?

std::optional<const at::Tensor>& num_pages_per_seq_, // (b_k, )
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor>& leftpad_k_, // b
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& seqlens_rotary_, // b
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
float const softmax_scale,
bool is_causal,
int window_size_left,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,6 @@
[build-system]
requires = [
"scikit-build-core>=0.10",
"pytorch-triton-xpu @ https://download.pytorch.org/whl/test/pytorch_triton_xpu-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
"wheel",
]
build-backend = "scikit_build_core.build"
@@ -32,6 +31,7 @@ exclude = [

[tool.scikit-build]
cmake.build-type = "Release"
build-dir = "build"
minimum-version = "build-system.requires"

wheel.py-api = "cp39"
79 changes: 60 additions & 19 deletions python/sgl_kernel/flash_attn.py
@@ -3,11 +3,6 @@
import torch
import torch.nn as nn

try:
from sgl_kernel import flash_ops
except:
raise ImportError("Can not import sgl_kernel. Please check your installation.")


def is_fa3_supported(device=None) -> bool:
# There is some fa3 FYI
@@ -18,10 +13,16 @@ def is_fa3_supported(device=None) -> bool:
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
# And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
# That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
return (
torch.cuda.get_device_capability(device)[0] == 9
or torch.cuda.get_device_capability(device)[0] == 8
) and (torch.version.cuda >= "12.3")
if torch.cuda.is_available():
return (
torch.cuda.get_device_capability(device)[0] == 9
or torch.cuda.get_device_capability(device)[0] == 8
) and (torch.version.cuda >= "12.3")
elif torch.xpu.is_available():
device_name = torch.xpu.get_device_properties(0).name
return "B580" in device_name or "e211" in device_name
else:
return False
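
As a usage sketch, a caller might gate backend selection on this check. The import path is an assumption based on this repo's `python/sgl_kernel/flash_attn.py` layout, and `pick_attention_backend` is a hypothetical helper, not part of the PR.

```python
# Minimal sketch, assuming sgl_kernel is installed and exposes this module.
from sgl_kernel.flash_attn import is_fa3_supported

def pick_attention_backend() -> str:
    # FA3 on SM 8.x/9.x with CUDA >= 12.3, or on Intel B580/e211 XPUs;
    # anything else falls back to a non-FA3 path.
    return "fa3" if is_fa3_supported() else "fallback"

print(pick_attention_backend())
```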


def maybe_contiguous(x):
@@ -171,22 +172,51 @@ def flash_attn_with_kvcache(
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
rotary_seqlens = maybe_contiguous(rotary_seqlens)

if cu_seqlens_q is None:  # !is_varlen_q
cu_seqlens_q = torch.arange(
0, q.size(0) + 1, dtype=torch.int, device=q.device
) * q.size(1)
max_seqlen_q = q.size(1)
q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
if cu_seqlens_k_new is None and k is not None: # !is_varlen_k_new
cu_seqlens_k_new = torch.arange(
0, k.size(0) + 1, dtype=torch.int, device=k.device
)
elif k is None:
cu_seqlens_k_new = torch.zeros_like(
cu_seqlens_q, dtype=torch.int32, device=q.device
)
if cache_seqlens is not None:
max_seqlen_k = cache_seqlens.max().item()
assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
max_page_size_per_seq = page_table.size(1)
num_pages_per_seq = torch.arange(
0,
cache_seqlens.size(0) * max_page_size_per_seq,
max_page_size_per_seq,
device=cache_seqlens.device,
).to(torch.int32)
cu_seqlens_k = torch.concat(
(
torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
torch.cumsum(cache_seqlens, 0),
)
).to(torch.int32)

Comment on lines 174 to 189 —

Contributor: These ops are causing a perf degradation compared to Triton.

Collaborator: No worries, we are aware of this; this PR still needs a lot of changes.

Collaborator: You don't have to pay too much attention to it right now; it will be fixed later.
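
To make the bookkeeping in the block above concrete, here is a toy walkthrough of the helper tensors it constructs, with input values assumed purely for illustration:

```python
# Assumed toy inputs: 2 cached sequences of lengths 5 and 7, and a
# page_table with 4 page slots per sequence (page_table.size(1) above).
import torch

cache_seqlens = torch.tensor([5, 7], dtype=torch.int32)
max_page_size_per_seq = 4

# Starting offset of each sequence's row in a flattened page table.
num_pages_per_seq = torch.arange(
    0, cache_seqlens.size(0) * max_page_size_per_seq, max_page_size_per_seq
).to(torch.int32)
# num_pages_per_seq -> tensor([0, 4], dtype=torch.int32)

# Cumulative cached-KV lengths with a leading zero.
cu_seqlens_k = torch.concat(
    (torch.zeros(1, dtype=torch.int32), torch.cumsum(cache_seqlens, 0))
).to(torch.int32)
# cu_seqlens_k -> tensor([ 0,  5, 12], dtype=torch.int32)
```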

out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k_cache,
v_cache,
k,
v,
qv,
None, # out
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k,
cu_seqlens_k_new,
None, # seqused_q
cache_seqlens,
max_seqlen_q,
None, # max_seqlen_k
max_seqlen_k,
page_table,
num_pages_per_seq,
cache_batch_idx,
cache_leftpad,
rotary_cos,
@@ -235,13 +265,26 @@
):
if not is_fa3_supported():
raise NotImplementedError(
"flash_attn at sgl-kernel is only supported on sm90 and above"
"flash_attn at sgl-kernel-xpu is only supported on BMG and later"
)

if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)
if cu_seqlens_q is None:  # !is_varlen_q
cu_seqlens_q = torch.arange(
0, q.size(0) + 1, dtype=torch.int, device=q.device
) * q.size(1)
max_seqlen_q = q.size(1)
q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
batch_size = cu_seqlens_q.numel() - 1
page_table = (
torch.arange(0, batch_size, device=q.device)
.to(torch.int32)
.reshape([batch_size, 1])
.contiguous()
)
Comment on lines +255 to +267 —

Contributor: What extra functionality are we trying to provide?

Collaborator (author): The current kernel implementation is aligned between vLLM and SGLang requests, so there will be some changes on the SGLang side.
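
For reference, the fallback `page_table` built in this path is just an identity mapping with one row per sequence. A toy rendering, with the batch size assumed for illustration:

```python
# Assumed batch_size; in the code above it is cu_seqlens_q.numel() - 1
# and the tensor lives on q.device.
import torch

batch_size = 3
page_table = (
    torch.arange(0, batch_size)
    .to(torch.int32)
    .reshape([batch_size, 1])
    .contiguous()
)
# page_table -> tensor([[0], [1], [2]], dtype=torch.int32)
```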


out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
@@ -250,15 +293,13 @@
None, # k_new
None, # v_new
qv, # qv
None, # out
cu_seqlens_q,
cu_seqlens_k,
None, # cu_seqlens_k_new
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
None, # page_table,
page_table, # page_table,
page_table, # num_pages_per_seq
None, # kv_batch_idx
None, # leftpad_k
None, # rotary cos