Commit cc61913

Initialize Cutlass-SYCL support (sgl-project#6)
* initialize Cutlass support
* Add chunked prefill op

Co-authored-by: Swift.Sun <[email protected]>
1 parent 63506b0 commit cc61913

17 files changed: 2924 additions & 100 deletions

.github/workflows/pr-test-xpu.yml

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ jobs:
         timeout-minutes: 20
         run: |
           docker exec -w /root/sglang ci_sglang_xpu \
-            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py"
+            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py test_flash_attention.py"

       - name: Run E2E Bfloat16 tests
         timeout-minutes: 20

CMakeLists.txt

Lines changed: 19 additions & 11 deletions
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.19.2)
 project(sgl_kernel)

 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
-set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD 20)

 # Torch
 find_package(Python3 COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
@@ -27,20 +27,28 @@ include(${SGL_OPS_XPU_ROOT}/cmake/BuildFlags.cmake)

 include(FetchContent)

-# # cutlass
-# FetchContent_Declare(
-#   repo-cutlass-sycl
-#   GIT_REPOSITORY https://github.com/codeplaysoftware/cutlass-sycl.git
-#   GIT_TAG ef9797f4327886ad231bfe853099ca022060c293
-#   GIT_SHALLOW OFF
-# )
-# FetchContent_Populate(repo-cutlass-sycl)
+# SYCL support in cutlass
+add_compile_definitions(CUTLASS_ENABLE_SYCL)
+add_compile_definitions(SYCL_INTEL_TARGET)
+set(CUTLASS_ENABLE_SYCL ON CACHE BOOL "Enable SYCL in the cutlass" FORCE)
+set(CUTLASS_ENABLE_BENCHMARKS OFF CACHE BOOL "Remove benchmark to avoid cmake version issue in google benchmark" FORCE)
+set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutlass" FORCE)
+
+# cutlass
+FetchContent_Declare(
+  repo-cutlass-sycl
+  GIT_REPOSITORY https://github.com/intel/sycl-tla.git
+  GIT_TAG 8cdf47660e5c64c0f2191b11525a87bc76d71d9a
+  GIT_SHALLOW OFF
+)
+FetchContent_MakeAvailable(repo-cutlass-sycl)
+

 include_directories(
   ${CMAKE_CURRENT_SOURCE_DIR}/include
   ${CMAKE_CURRENT_SOURCE_DIR}/src
-  # ${repo-cutlass-sycl_SOURCE_DIR}/include
-  # ${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include
+  ${repo-cutlass-sycl_SOURCE_DIR}/include
+  ${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include
 )

 add_subdirectory(${SGL_OPS_XPU_ROOT}/src)

cmake/BuildFlags.cmake

Lines changed: 8 additions & 13 deletions
@@ -26,7 +26,7 @@ endfunction()
 if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
   # # -- Host flags (SYCL_CXX_FLAGS)
   list(APPEND SYCL_HOST_FLAGS -fPIC)
-  list(APPEND SYCL_HOST_FLAGS -std=c++17)
+  list(APPEND SYCL_HOST_FLAGS -std=c++20)
   # SYCL headers warnings
   list(APPEND SYCL_HOST_FLAGS -Wno-deprecated-declarations)
   list(APPEND SYCL_HOST_FLAGS -Wno-deprecated)
@@ -71,6 +71,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
   set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-approx-func)
   set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Wno-absolute-value)
   set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -no-ftz)
+  set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-sycl-instrument-device-code)
+  set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Xspirv-translator)
+  set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier) #,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate)

   if(CMAKE_BUILD_TYPE MATCHES Debug)
     set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -g -O0 -Rno-debug-disables-optimization)
@@ -110,18 +113,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")


   set(AOT_TARGETS "bmg")
-  if(TORCH_XPU_ARCH_LIST)
-    set(AOT_TARGETS "${TORCH_XPU_ARCH_LIST}")
-  endif()
-  if(AOT_TARGETS STREQUAL "none")
-    set(TORCH_XPU_ARCH_LIST "" PARENT_SCOPE)
-  else()
-    set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen,spir64)
-    set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION})
-    set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION})
-    set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}")
-    set(TORCH_XPU_ARCH_LIST ${AOT_TARGETS} PARENT_SCOPE)
-  endif()
+  set(SYCL_TARGETS_OPTION -fsycl-targets=spir64_gen)
+  set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} ${SYCL_TARGETS_OPTION})
+  set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS} ${SYCL_TARGETS_OPTION})
+  set(SYCL_OFFLINE_COMPILER_AOT_OPTIONS "-device ${AOT_TARGETS}")
   message(STATUS "Compile Intel GPU AOT Targets for ${AOT_TARGETS}")

   set(SYCL_FLAGS ${SYCL_FLAGS} ${SYCL_KERNEL_OPTIONS})

include/sgl_flash_kernel_ops.h

Lines changed: 0 additions & 11 deletions
@@ -48,21 +48,10 @@ std::vector<at::Tensor> mha_fwd(
                            // h_k, d) if there is page_table.
     const at::Tensor& v,   // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                            // page_size, h_k, dv) if there is page_table.
-    std::optional<const at::Tensor>&
-        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
-    std::optional<const at::Tensor>&
-        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
     std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor>& out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
     std::optional<const at::Tensor>& cu_seqlens_q_,  // b+1
     std::optional<const at::Tensor>& cu_seqlens_k_,  // b+1
-    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
-    std::optional<const at::Tensor>&
-        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
-    std::optional<const at::Tensor>&
-        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
     std::optional<int> max_seqlen_q_,
-    // TODO: check if we need max_seqlen_k
     std::optional<int> max_seqlen_k_,
     std::optional<const at::Tensor>& page_table_,  // (b_k, max_num_pages_per_seq)
     std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,6 @@
 [build-system]
 requires = [
     "scikit-build-core>=0.10",
-    "pytorch-triton-xpu @ https://download.pytorch.org/whl/test/pytorch_triton_xpu-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
     "wheel",
 ]
 build-backend = "scikit_build_core.build"
@@ -32,6 +31,7 @@ exclude = [

 [tool.scikit-build]
 cmake.build-type = "Release"
+build-dir = "build"
 minimum-version = "build-system.requires"

 wheel.py-api = "cp39"

python/sgl_kernel/flash_attn.py

Lines changed: 43 additions & 22 deletions
@@ -3,11 +3,6 @@
 import torch
 import torch.nn as nn

-try:
-    from sgl_kernel import flash_ops
-except:
-    raise ImportError("Can not import sgl_kernel. Please check your installation.")
-

 def is_fa3_supported(device=None) -> bool:
     # There some fa3 FYI
@@ -18,10 +13,15 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
-        torch.cuda.get_device_capability(device)[0] == 9
-        or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    if torch.cuda.is_available():
+        return (
+            torch.cuda.get_device_capability(device)[0] == 9
+            or torch.cuda.get_device_capability(device)[0] == 8
+        ) and (torch.version.cuda >= "12.3")
+    elif torch.xpu.is_available():
+        return torch.xpu.get_device_properties().has_fp64
+    else:
+        return False


 def maybe_contiguous(x):
@@ -171,21 +171,31 @@ def flash_attn_with_kvcache(
     rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
     rotary_seqlens = maybe_contiguous(rotary_seqlens)

+    if cu_seqlens_q == None:  # !is_varlen_q
+        cu_seqlens_q = torch.arange(
+            0, q.size(0) + 1, dtype=torch.int, device=q.device
+        ) * q.size(1)
+        max_seqlen_q = q.size(1)
+        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
+    if cache_seqlens is not None:
+        max_seqlen_k = cache_seqlens.max().item()
+        assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
+        cu_seqlens_k = torch.concat(
+            (
+                torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
+                torch.cumsum(cache_seqlens, 0),
+            )
+        ).to(torch.int32)
+
     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
         k_cache,
         v_cache,
-        k,
-        v,
         qv,
-        None, # out
         cu_seqlens_q,
-        None, # cu_seqlens_k
-        cu_seqlens_k_new,
-        None, # seqused_q
-        cache_seqlens,
+        cu_seqlens_k,
         max_seqlen_q,
-        None, # max_seqlen_k
+        max_seqlen_k,
         page_table,
         cache_batch_idx,
         cache_leftpad,
@@ -235,13 +245,26 @@
 ):
     if not is_fa3_supported():
         raise NotImplementedError(
-            "flash_attn at sgl-kernel is only supported on sm90 and above"
+            "flash_attn at sgl-kernel-xpu is only supported on BMG and later"
        )

     if softmax_scale is None:
         softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
             -0.5
         )
+    if cu_seqlens_q == None:  # !is_varlen_q
+        cu_seqlens_q = torch.arange(
+            0, q.size(0) + 1, dtype=torch.int, device=q.device
+        ) * q.size(1)
+        max_seqlen_q = q.size(1)
+        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
+    batch_size = cu_seqlens_q.numel() - 1
+    page_table = (
+        torch.arange(0, batch_size, device=q.device)
+        .to(torch.int32)
+        .reshape([batch_size, 1])
+        .contiguous()
+    )

     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
@@ -250,15 +273,13 @@
         None, # k_new
         None, # v_new
         qv, # qv
-        None, # out
         cu_seqlens_q,
         cu_seqlens_k,
         None, # cu_seqlens_k_new
-        seqused_q,
-        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
-        None, # page_table,
+        page_table, # page_table,
+        page_table, # num_pages_per_seq
         None, # kv_batch_idx
         None, # leftpad_k
         None, # rotary cos
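
Note on the new branches above: when the caller passes dense (b, s_q, h, d) tensors without cu_seqlens_q, the wrappers now derive cumulative sequence offsets themselves, flatten q to (total_q, h, d), and (in the varlen path) build a one-page-per-sequence page_table before invoking torch.ops.sgl_kernel.fwd. The sketch below simply replays that arithmetic outside the wrapper on hypothetical shapes (batch of 2, 8 query tokens each, KV-cache lengths 5 and 7) so the resulting offset tensors are easy to inspect; it is an illustration of the commit's bookkeeping, not part of the commit.

import torch

# Hypothetical inputs, shapes chosen only for illustration.
q = torch.randn(2, 8, 4, 64)                               # (b, s_q, h, d)
cache_seqlens = torch.tensor([5, 7], dtype=torch.int32)    # per-sequence KV lengths

# Dense path: cumulative query offsets, then flatten q to (total_q, h, d),
# mirroring the `cu_seqlens_q == None` branch in flash_attn_with_kvcache.
cu_seqlens_q = torch.arange(0, q.size(0) + 1, dtype=torch.int) * q.size(1)   # [0, 8, 16]
max_seqlen_q = q.size(1)
q_packed = q.view(-1, q.size(-2), q.size(-1)).contiguous()                   # (16, 4, 64)

# Cumulative key offsets from the per-sequence cache lengths.
cu_seqlens_k = torch.concat(
    (torch.zeros(1, dtype=torch.int32), torch.cumsum(cache_seqlens, 0))
).to(torch.int32)                                                            # [0, 5, 12]
max_seqlen_k = int(cache_seqlens.max())

# Varlen path: one page per sequence, as in the page_table built by flash_attn_varlen_func.
batch_size = cu_seqlens_q.numel() - 1
page_table = torch.arange(0, batch_size).to(torch.int32).reshape([batch_size, 1]).contiguous()

print(cu_seqlens_q.tolist())   # [0, 8, 16]
print(cu_seqlens_k.tolist())   # [0, 5, 12]
print(page_table.tolist())     # [[0], [1]]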

src/sycl/TripleOps.cpp

Lines changed: 8 additions & 8 deletions
@@ -88,8 +88,8 @@ struct op_and_mul_functor {

 template <typename T = float>
 void get_config(
-    const Tensor& input,
-    const Tensor& out,
+    const at::Tensor& input,
+    const at::Tensor& out,
     int64_t& numel,
     int64_t& dim,
     int64_t& wg_size,
@@ -111,7 +111,7 @@ void get_config(
 }

 template <typename T_to = float, typename T_from = float>
-void silu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
+void silu_and_mul_sycl(sycl::queue& q, at::Tensor& input, at::Tensor& out) {
   auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
   auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

@@ -136,7 +136,7 @@ void silu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
   return;
 }

-void silu_and_mul(Tensor& out, Tensor& input) {
+void silu_and_mul(at::Tensor& out, at::Tensor& input) {
   input = input.contiguous();
   out = out.contiguous();

@@ -152,7 +152,7 @@ void silu_and_mul(Tensor& out, Tensor& input) {
 }

 template <typename T_to = float, typename T_from = float>
-void gelu_tanh_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
+void gelu_tanh_and_mul_sycl(sycl::queue& q, at::Tensor& input, at::Tensor& out) {
   auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
   auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

@@ -177,7 +177,7 @@ void gelu_tanh_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
   return;
 }

-void gelu_tanh_and_mul(Tensor& out, Tensor& input) {
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
   input = input.contiguous();
   out = out.contiguous();

@@ -193,7 +193,7 @@ void gelu_tanh_and_mul(Tensor& out, Tensor& input) {
 }

 template <typename T_to = float, typename T_from = float>
-void gelu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
+void gelu_and_mul_sycl(sycl::queue& q, at::Tensor& input, at::Tensor& out) {
   auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
   auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

@@ -218,7 +218,7 @@ void gelu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
   return;
 }

-void gelu_and_mul(Tensor& out, Tensor& input) {
+void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
   input = input.contiguous();
   out = out.contiguous();
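
These hunks only spell out at::Tensor now that `using namespace at;` is gone from Utils.h; the ops themselves are the usual fused "activation-and-multiply" kernels. As an orientation aid only, the snippet below restates the conventional contract these ops follow in sgl-kernel (input holds two halves along the last dimension; the first half is activated and multiplied by the second). This is an assumption for illustration, since the kernel bodies are not shown in this diff.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (..., 2 * d) -> (..., d); SiLU-gate the first half, scale by the second half.
    d = x.size(-1) // 2
    return F.silu(x[..., :d]) * x[..., d:]

def gelu_tanh_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Same split, with the tanh-approximated GELU as the gate.
    d = x.size(-1) // 2
    return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]

x = torch.randn(4, 256)           # hypothetical (tokens, 2 * d)
out = silu_and_mul_reference(x)   # (4, 128)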

src/sycl/Utils.h

Lines changed: 4 additions & 1 deletion
@@ -7,7 +7,10 @@

 #define SYCL_MAX_SUB_GROUP_SIZE dpcppMaxSubGroupSize()

-using namespace at;
+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU")
+#define CHECK_SHAPE(x, ...) \
+  TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

 using DeviceId = at::DeviceIndex;
