2 changes: 1 addition & 1 deletion .github/workflows/pr-test-xpu.yml
@@ -48,7 +48,7 @@ jobs:
timeout-minutes: 20
run: |
docker exec -w /root/sglang ci_sglang_xpu \
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py"
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py && python3 -m pytest -v -s test_flash_attention.py"

- name: Run E2E Bfloat16 tests
timeout-minutes: 20
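For local debugging outside the CI container, the same test file can be driven through pytest's Python API. This is a minimal sketch, assuming a checkout that mirrors the container's sgl-kernel-xpu/tests layout; the path is illustrative.

# Local-run sketch; the tests path is an assumption, adjust it to your checkout.
import pytest

ret = pytest.main(["-v", "-s", "sgl-kernel-xpu/tests/test_flash_attention.py"])
raise SystemExit(ret)
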
31 changes: 20 additions & 11 deletions CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.19.2)
project(sgl_kernel)

set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)

# Torch
find_package(Python3 COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
@@ -27,20 +27,29 @@ include(${SGL_OPS_XPU_ROOT}/cmake/BuildFlags.cmake)

include(FetchContent)

# # cutlass
# FetchContent_Declare(
# repo-cutlass-sycl
# GIT_REPOSITORY https://github.com/codeplaysoftware/cutlass-sycl.git
# GIT_TAG ef9797f4327886ad231bfe853099ca022060c293
# GIT_SHALLOW OFF
# )
# FetchContent_Populate(repo-cutlass-sycl)
# SYCL support in cutlass
add_compile_definitions(CUTLASS_ENABLE_SYCL)
add_compile_definitions(SYCL_INTEL_TARGET)
set(CUTLASS_ENABLE_SYCL ON CACHE BOOL "Enable SYCL in the cutlass" FORCE)
set(CUTLASS_ENABLE_BENCHMARKS OFF CACHE BOOL "Remove benchmark to avoid cmake version issue in google benchmark" FORCE)
set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutlass" FORCE)

# cutlass
FetchContent_Declare(
repo-cutlass-sycl
GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git
GIT_TAG ab1f4b8ddfd5748e4c00317710cdbcecda58de28
GIT_SHALLOW OFF
)
FetchContent_MakeAvailable(repo-cutlass-sycl)


include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/src
# ${repo-cutlass-sycl_SOURCE_DIR}/include
# ${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include
${repo-cutlass-sycl_SOURCE_DIR}/include
${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include
${repo-cutlass-sycl_SOURCE_DIR}/applications
)

add_subdirectory(${SGL_OPS_XPU_ROOT}/src)
9 changes: 7 additions & 2 deletions cmake/BuildFlags.cmake
@@ -26,7 +26,7 @@ endfunction()
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# # -- Host flags (SYCL_CXX_FLAGS)
list(APPEND SYCL_HOST_FLAGS -fPIC)
list(APPEND SYCL_HOST_FLAGS -std=c++17)
list(APPEND SYCL_HOST_FLAGS -std=c++20)
# SYCL headers warnings
list(APPEND SYCL_HOST_FLAGS -Wno-deprecated-declarations)
list(APPEND SYCL_HOST_FLAGS -Wno-deprecated)
@@ -71,6 +71,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-approx-func)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Wno-absolute-value)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -no-ftz)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-sycl-instrument-device-code)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Xspirv-translator)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier)

if(CMAKE_BUILD_TYPE MATCHES Debug)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -g -O0 -Rno-debug-disables-optimization)
@@ -116,7 +119,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(AOT_TARGETS STREQUAL "none")
set(TORCH_XPU_ARCH_LIST "" PARENT_SCOPE)
else()
set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen,spir64)
set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION})
set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION})
set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}")
@@ -126,6 +129,8 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")

set(SYCL_FLAGS ${SYCL_FLAGS} ${SYCL_KERNEL_OPTIONS})

# set(SYCL_OFFLINE_COMPILER_CG_OPTIONS ${SYCL_OFFLINE_COMPILER_CG_OPTIONS} -fno-sycl-instrument-device-code)
# set(SYCL_OFFLINE_COMPILER_CG_OPTIONS ${SYCL_OFFLINE_COMPILER_CG_OPTIONS} ${SYCL_LINK_FLAGS})
set(SYCL_OFFLINE_COMPILER_FLAGS "${SYCL_OFFLINE_COMPILER_AOT_OPTIONS}${SYCL_OFFLINE_COMPILER_CG_OPTIONS}")
else()
message("Not compiling with XPU. Currently only support GCC compiler on Linux as CXX compiler.")
7 changes: 1 addition & 6 deletions include/sgl_flash_kernel_ops.h
@@ -53,18 +53,13 @@ std::vector<at::Tensor> mha_fwd(
std::optional<const at::Tensor>&
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor>&
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor>&
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor>& num_pages_, // (b_k, )
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor>& leftpad_k_, // b
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,6 @@
[build-system]
requires = [
"scikit-build-core>=0.10",
"pytorch-triton-xpu @ https://download.pytorch.org/whl/test/pytorch_triton_xpu-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
"wheel",
]
build-backend = "scikit_build_core.build"
@@ -32,6 +31,7 @@ exclude = [

[tool.scikit-build]
cmake.build-type = "Release"
build-dir = "build"
minimum-version = "build-system.requires"

wheel.py-api = "cp39"
55 changes: 41 additions & 14 deletions python/sgl_kernel/flash_attn.py
@@ -3,11 +3,6 @@
import torch
import torch.nn as nn

try:
    from sgl_kernel import flash_ops
except:
    raise ImportError("Can not import sgl_kernel. Please check your installation.")


def is_fa3_supported(device=None) -> bool:
    # Here is some FA3 background, FYI:
@@ -18,10 +13,16 @@ def is_fa3_supported(device=None) -> bool:
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
    # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
    # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
    return (
        torch.cuda.get_device_capability(device)[0] == 9
        or torch.cuda.get_device_capability(device)[0] == 8
    ) and (torch.version.cuda >= "12.3")
    if torch.cuda.is_available():
        return (
            torch.cuda.get_device_capability(device)[0] == 9
            or torch.cuda.get_device_capability(device)[0] == 8
        ) and (torch.version.cuda >= "12.3")
    elif torch.xpu.is_available():
        device_name = torch.xpu.get_device_properties(0).name
        return "B580" in device_name or "e211" in device_name
    else:
        return False


def maybe_contiguous(x):
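
The reworked is_fa3_supported check now covers both backends: on CUDA it keeps the SM8x/SM9x plus CUDA >= 12.3 condition, and on XPU it matches the device name reported by torch.xpu.get_device_properties. A hedged usage sketch follows; the import path is an assumption based on this file's location under python/sgl_kernel, and the fallback label is made up.

# Illustrative caller-side guard; import path and fallback label are assumptions.
import torch
from sgl_kernel.flash_attn import is_fa3_supported

backend = "fa3" if is_fa3_supported() else "reference"  # "reference" is a hypothetical fallback
print(f"attention backend: {backend}")
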
@@ -171,22 +172,48 @@ def flash_attn_with_kvcache(
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    rotary_seqlens = maybe_contiguous(rotary_seqlens)

    if cu_seqlens_q is None:  # !is_varlen_q
        cu_seqlens_q = torch.arange(
            0, q.size(0) + 1, dtype=torch.int, device=q.device
        ) * q.size(1)
        max_seqlen_q = q.size(1)
        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
    if cu_seqlens_k_new is None and k is not None:  # !is_varlen_k_new
        cu_seqlens_k_new = torch.arange(
            0, k.size(0) + 1, dtype=torch.int, device=k.device
        )
    elif k is None:
        cu_seqlens_k_new = torch.zeros_like(
            cu_seqlens_q, dtype=torch.int32, device=q.device
        )
    if cache_seqlens is not None:
        max_seqlen_k = cache_seqlens.max().item()
        assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
        max_page_size_per_seq = page_table.shape[1]
        num_pages_per_seq = torch.arange(
            0,
            cache_seqlens.size(0) * max_page_size_per_seq,
            max_page_size_per_seq,
            device=cache_seqlens.device,
        ).to(torch.int32)
        cu_seqlens_k = torch.concat(
            (
                torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
                torch.cumsum(cache_seqlens, 0),
            )
        ).to(torch.int32)

    out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k,
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        max_seqlen_k,
        page_table,
        num_pages_per_seq,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
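
The cumulative-sequence-length bookkeeping added to flash_attn_with_kvcache above can be illustrated with a small worked example. This is a sketch of the tensor preparation only, with made-up shapes; it does not call the fwd op itself.

# Worked example of the non-varlen path; shapes are illustrative.
import torch

q = torch.zeros(2, 5, 8, 64)  # (batch, seqlen_q, heads, head_dim)
cu_seqlens_q = torch.arange(0, q.size(0) + 1, dtype=torch.int) * q.size(1)
print(cu_seqlens_q.tolist())  # [0, 5, 10]: per-sequence offsets into the flattened q
q_flat = q.view(-1, q.size(-2), q.size(-1)).contiguous()
print(tuple(q_flat.shape))  # (10, 8, 64): (total_q, heads, head_dim)

# cu_seqlens_k is the analogous prefix sum over the cached KV lengths.
cache_seqlens = torch.tensor([3, 7], dtype=torch.int32)
cu_seqlens_k = torch.concat(
    (torch.zeros(1, dtype=torch.int32), torch.cumsum(cache_seqlens, 0).to(torch.int32))
)
print(cu_seqlens_k.tolist())  # [0, 3, 10]; max_seqlen_k would be 7 here
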