2 changes: 1 addition & 1 deletion .github/workflows/pr-test-xpu.yml
@@ -50,7 +50,7 @@ jobs:
timeout-minutes: 20
run: |
docker exec -w /root/sglang ci_sglang_xpu \
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py"
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py test_flash_attention.py"

- name: Run E2E Bfloat16 tests
timeout-minutes: 20
30 changes: 19 additions & 11 deletions CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.19.2)
project(sgl_kernel)

set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)

# Torch
find_package(Python3 COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
@@ -27,20 +27,28 @@ include(${SGL_OPS_XPU_ROOT}/cmake/BuildFlags.cmake)

include(FetchContent)

# # cutlass
# FetchContent_Declare(
# repo-cutlass-sycl
# GIT_REPOSITORY https://github.com/codeplaysoftware/cutlass-sycl.git
# GIT_TAG ef9797f4327886ad231bfe853099ca022060c293
# GIT_SHALLOW OFF
# )
# FetchContent_Populate(repo-cutlass-sycl)
# SYCL support in cutlass
add_compile_definitions(CUTLASS_ENABLE_SYCL)
add_compile_definitions(SYCL_INTEL_TARGET)
set(CUTLASS_ENABLE_SYCL ON CACHE BOOL "Enable SYCL in the cutlass" FORCE)
set(CUTLASS_ENABLE_BENCHMARKS OFF CACHE BOOL "Remove benchmark to avoid cmake version issue in google benchmark" FORCE)
set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutlass" FORCE)

# cutlass
FetchContent_Declare(
repo-cutlass-sycl
GIT_REPOSITORY https://github.com/intel/sycl-tla.git
GIT_TAG 8cdf47660e5c64c0f2191b11525a87bc76d71d9a
GIT_SHALLOW OFF
)
FetchContent_MakeAvailable(repo-cutlass-sycl)


include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/src
# ${repo-cutlass-sycl_SOURCE_DIR}/include
# ${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include
${repo-cutlass-sycl_SOURCE_DIR}/include
${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include
)

add_subdirectory(${SGL_OPS_XPU_ROOT}/src)
21 changes: 8 additions & 13 deletions cmake/BuildFlags.cmake
@@ -26,7 +26,7 @@ endfunction()
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# # -- Host flags (SYCL_CXX_FLAGS)
list(APPEND SYCL_HOST_FLAGS -fPIC)
list(APPEND SYCL_HOST_FLAGS -std=c++17)
list(APPEND SYCL_HOST_FLAGS -std=c++20)
# SYCL headers warnings
list(APPEND SYCL_HOST_FLAGS -Wno-deprecated-declarations)
list(APPEND SYCL_HOST_FLAGS -Wno-deprecated)
@@ -71,6 +71,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-approx-func)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Wno-absolute-value)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -no-ftz)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-sycl-instrument-device-code)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Xspirv-translator)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier) #,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate)

if(CMAKE_BUILD_TYPE MATCHES Debug)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -g -O0 -Rno-debug-disables-optimization)
@@ -110,18 +113,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")


set(AOT_TARGETS "bmg")
if(TORCH_XPU_ARCH_LIST)
set(AOT_TARGETS "${TORCH_XPU_ARCH_LIST}")
endif()
if(AOT_TARGETS STREQUAL "none")
set(TORCH_XPU_ARCH_LIST "" PARENT_SCOPE)
else()
set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen,spir64)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION})
set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION})
set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}")
set(TORCH_XPU_ARCH_LIST ${AOT_TARGETS} PARENT_SCOPE)
endif()
set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION})
set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION})
set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}")
message(STATUS "Compile Intel GPU AOT Targets for ${AOT_TARGETS}")

set(SYCL_FLAGS ${SYCL_FLAGS} ${SYCL_KERNEL_OPTIONS})
11 changes: 0 additions & 11 deletions include/sgl_flash_kernel_ops.h
@@ -48,21 +48,10 @@ std::vector<at::Tensor> mha_fwd(
// h_k, d) if there is page_table.
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
// page_size, h_k, dv) if there is page_table.
std::optional<const at::Tensor>&
k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<const at::Tensor>&
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor>&
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor>&
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,6 @@
[build-system]
requires = [
"scikit-build-core>=0.10",
"pytorch-triton-xpu @ https://download.pytorch.org/whl/test/pytorch_triton_xpu-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
"wheel",
]
build-backend = "scikit_build_core.build"
@@ -32,6 +31,7 @@ exclude = [

[tool.scikit-build]
cmake.build-type = "Release"
build-dir = "build"
minimum-version = "build-system.requires"

wheel.py-api = "cp39"
65 changes: 43 additions & 22 deletions python/sgl_kernel/flash_attn.py
@@ -3,11 +3,6 @@
import torch
import torch.nn as nn

try:
from sgl_kernel import flash_ops
except:
raise ImportError("Can not import sgl_kernel. Please check your installation.")


def is_fa3_supported(device=None) -> bool:
# There some fa3 FYI
@@ -18,10 +13,15 @@ def is_fa3_supported(device=None) -> bool:
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
# And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
# That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
return (
torch.cuda.get_device_capability(device)[0] == 9
or torch.cuda.get_device_capability(device)[0] == 8
) and (torch.version.cuda >= "12.3")
if torch.cuda.is_available():
return (
torch.cuda.get_device_capability(device)[0] == 9
or torch.cuda.get_device_capability(device)[0] == 8
) and (torch.version.cuda >= "12.3")
elif torch.xpu.is_available():
return torch.xpu.get_device_properties().has_fp64
else:
return False
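
A minimal usage sketch for the updated check (not part of this diff): the `sgl_kernel.flash_attn` import path comes from this module, but the backend labels and the fallback choice are illustrative assumptions, and a recent PyTorch build that ships `torch.xpu` is assumed.

import torch
from sgl_kernel.flash_attn import is_fa3_supported  # assumes sgl-kernel(-xpu) is installed

def select_attention_backend(device=None) -> str:
    # Prefer the fused flash-attention kernel when the device qualifies
    # (CUDA sm8x/sm9x with CUDA >= 12.3, or an Intel XPU with fp64 support);
    # otherwise fall back to PyTorch's built-in SDPA.
    if is_fa3_supported(device):
        return "sgl_kernel.flash_attn"
    return "torch.nn.functional.scaled_dot_product_attention"

if __name__ == "__main__":
    dev = "xpu" if torch.xpu.is_available() else ("cuda" if torch.cuda.is_available() else None)
    print(dev, "->", select_attention_backend(dev))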


def maybe_contiguous(x):
@@ -171,21 +171,31 @@ def flash_attn_with_kvcache(
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
rotary_seqlens = maybe_contiguous(rotary_seqlens)

if cu_seqlens_q == None: # !is_varlen_q
cu_seqlens_q = torch.arange(
0, q.size(0) + 1, dtype=torch.int, device=q.device
) * q.size(1)
max_seqlen_q = q.size(1)
q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
if cache_seqlens is not None:
max_seqlen_k = cache_seqlens.max().item()
assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
cu_seqlens_k = torch.concat(
(
torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
torch.cumsum(cache_seqlens, 0),
)
).to(torch.int32)

out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k_cache,
v_cache,
k,
v,
qv,
None, # out
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_seqlens,
cu_seqlens_k,
max_seqlen_q,
None, # max_seqlen_k
max_seqlen_k,
page_table,
cache_batch_idx,
cache_leftpad,
@@ -235,13 +245,26 @@ def flash_attn_varlen_func(
):
if not is_fa3_supported():
raise NotImplementedError(
"flash_attn at sgl-kernel is only supported on sm90 and above"
"flash_attn at sgl-kernel-xpu is only supported on BMG and later"
)

if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)
if cu_seqlens_q == None: # !is_varlen_q
cu_seqlens_q = torch.arange(
0, q.size(0) + 1, dtype=torch.int, device=q.device
) * q.size(1)
max_seqlen_q = q.size(1)
q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
batch_size = cu_seqlens_q.numel() - 1
page_table = (
torch.arange(0, batch_size, device=q.device)
.to(torch.int32)
.reshape([batch_size, 1])
.contiguous()
)
Comment on lines +255 to +267

Contributor: What extra functionality are we trying to provide?

Collaborator (author): The current kernel implementation is aligned between vLLM and SGLang request formats, so there will be some changes on the SGLang side.
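
To make the exchange above concrete: the new branch converts a fixed-shape batched query plus per-sequence KV-cache lengths into the varlen metadata (cumulative lengths and a trivial one-page-per-sequence page table) that the kernel consumes. A self-contained sketch of that transformation follows; the helper name and example shapes are assumptions, not code from this PR.

import torch

def to_varlen_layout(q: torch.Tensor, cache_seqlens: torch.Tensor):
    # q: (batch, seq_q, heads, head_dim); cache_seqlens: (batch,) int32 KV lengths.
    b, s_q = q.size(0), q.size(1)
    # Every sequence contributes exactly s_q query tokens.
    cu_seqlens_q = torch.arange(0, b + 1, dtype=torch.int32, device=q.device) * s_q
    # Cumulative key lengths derived from the per-sequence cache lengths.
    cu_seqlens_k = torch.cat(
        (
            torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
            torch.cumsum(cache_seqlens, dim=0).to(torch.int32),
        )
    )
    # Trivial page table: one "page" per batch element, i.e. no real paging.
    page_table = torch.arange(b, dtype=torch.int32, device=q.device).reshape(b, 1)
    q_flat = q.reshape(-1, q.size(-2), q.size(-1)).contiguous()
    return q_flat, cu_seqlens_q, cu_seqlens_k, int(cache_seqlens.max()), page_table

# Batch of 2 decode requests (1 new token each) with KV caches of length 5 and 9.
q = torch.randn(2, 1, 8, 64)
cache_seqlens = torch.tensor([5, 9], dtype=torch.int32)
q_flat, cu_q, cu_k, max_k, pt = to_varlen_layout(q, cache_seqlens)
print(cu_q.tolist(), cu_k.tolist(), max_k, pt.shape)  # [0, 1, 2] [0, 5, 14] 9 torch.Size([2, 1])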


out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
@@ -250,15 +273,13 @@
None, # k_new
None, # v_new
qv, # qv
None, # out
cu_seqlens_q,
cu_seqlens_k,
None, # cu_seqlens_k_new
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
None, # page_table,
page_table, # page_table,
page_table, # num_pages_per_seq
None, # kv_batch_idx
None, # leftpad_k
None, # rotary cos
16 changes: 8 additions & 8 deletions src/sycl/TripleOps.cpp
@@ -88,8 +88,8 @@ struct op_and_mul_functor {

template <typename T = float>
void get_config(
const Tensor& input,
const Tensor& out,
const at::Tensor& input,
const at::Tensor& out,
int64_t& numel,
int64_t& dim,
int64_t& wg_size,
@@ -111,7 +111,7 @@
}

template <typename T_to = float, typename T_from = float>
void silu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
void silu_and_mul_sycl(sycl::queue& q, at::Tensor& input, at::Tensor& out) {
auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

@@ -136,7 +136,7 @@ void silu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
return;
}

void silu_and_mul(Tensor& out, Tensor& input) {
void silu_and_mul(at::Tensor& out, at::Tensor& input) {
input = input.contiguous();
out = out.contiguous();

@@ -152,7 +152,7 @@ void silu_and_mul(Tensor& out, Tensor& input) {
}

template <typename T_to = float, typename T_from = float>
void gelu_tanh_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
void gelu_tanh_and_mul_sycl(sycl::queue& q, at::Tensor& input, at::Tensor& out) {
auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

@@ -177,7 +177,7 @@ void gelu_tanh_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
return;
}

void gelu_tanh_and_mul(Tensor& out, Tensor& input) {
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
input = input.contiguous();
out = out.contiguous();

@@ -193,7 +193,7 @@ void gelu_tanh_and_mul(Tensor& out, Tensor& input) {
}

template <typename T_to = float, typename T_from = float>
void gelu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
void gelu_and_mul_sycl(sycl::queue& q, at::Tensor& input, at::Tensor& out) {
auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

@@ -218,7 +218,7 @@ void gelu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
return;
}

void gelu_and_mul(Tensor& out, Tensor& input) {
void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
input = input.contiguous();
out = out.contiguous();

5 changes: 4 additions & 1 deletion src/sycl/Utils.h
@@ -7,7 +7,10 @@

#define SYCL_MAX_SUB_GROUP_SIZE dpcppMaxSubGroupSize()

using namespace at;
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU")
#define CHECK_SHAPE(x, ...) \
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

using DeviceId = at::DeviceIndex;
