Kunshang/flash attn interface #11

Open · wants to merge 9 commits into main
95 changes: 92 additions & 3 deletions CMakeLists.txt
@@ -46,6 +46,7 @@ set(SYCL_SUPPORTED_ARCHS "intel_gpu_pvc;intel_gpu_bmg_g21")
set(TORCH_SUPPORTED_VERSION_XPU "2.8.0")

set(ENABLE_MOE_KERNEL OFF)
set(FA2_ENABLED ON)

#
# Try to find python package with an executable that exactly matches
@@ -146,16 +147,65 @@ endif()

if(VLLM_GPU_LANG STREQUAL "SYCL")
set(VLLM_EXT_SRC
"csrc/xpu/cutlass_sycl_demo.cpp"
"csrc/xpu/layernorm.cpp"
"csrc/xpu/torch_bindings.cpp"
)
include_directories("/usr/include")
set(CMPLR_ROOT $ENV{CMPLR_ROOT})
set(CMAKE_CXX_COMPILER icpx)
list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/)
list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/)
list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/syclcompat/)
message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}")
set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl)
list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" )
list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64")
list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier")
list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" )


# add cutlass dependency
set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library")

# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "chunk_prefill_BMG" CACHE STRING "CUTLASS revision to use")

# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
FetchContent_Declare(
cutlass-sycl
GIT_REPOSITORY https://github.com/sunjiweiswift/cutlass-sycl.git
# Please keep this in sync with CUTLASS_REVISION line above.
GIT_TAG ${CUTLASS_REVISION}
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
)

# cutlass compilation flags
set(CUTLASS_ENABLE_SYCL "ON")
# set(DPCPP_SYCL_TARGET "intel_gpu_pvc;intel_gpu_bmg_g21" CACHE STRING "DPC++ SYCL target architectures")
set(CMAKE_EXPORT_COMPILE_COMMANDS "ON")
set(CUTLASS_ENABLE_BENCHMARKS "OFF")
# disable cuda
set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA")
# list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " )
# list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " )


FetchContent_MakeAvailable(cutlass-sycl)
set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library")
set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/tools/util/include CACHE INTERNAL "")
set(CUTLASS_APP_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/applications CACHE INTERNAL "")
message(STATUS "cutlass dir: ${CUTLASS_INCLUDE_DIR} and ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} and ${CUTLASS_APP_INCLUDE_DIR}")

# header only library
list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_ENABLE_SYCL")
list(APPEND VLLM_GPU_FLAGS "-DSYCL_INTEL_TARGET")
list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_VERSIONS_GENERATED")
list(APPEND VLLM_GPU_FLAGS "-ftemplate-backtrace-limit=0")
list(APPEND VLLM_GPU_FLAGS "-fdiagnostics-color=always")

endif()

message(STATUS "Enabling C extension.")
@@ -169,9 +219,47 @@ define_gpu_extension_target(
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

#
# flash attention _C extension
#

if (FA2_ENABLED)
message(STATUS "Enabling fa2 extension.")
file(GLOB FA2_GEN_SRCS "csrc/flash_attn/*.cpp")

# list(APPEND VLLM_GPU_FLAGS "-ze-opt-large-register-file")
list(APPEND VLLM_GPU_FLAGS "-gline-tables-only")
list(APPEND VLLM_GPU_FLAGS "-O3")
list(APPEND VLLM_GPU_FLAGS "-DNDEBUG")

define_gpu_extension_target(
_vllm_fa2_C
DESTINATION vllm_xpu_kernels
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
csrc/flash_attn/flash_api.cpp
${FA2_GEN_SRCS}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
LINK_FLAGS ${VLLM_GPU_LINK_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

# target_include_directories(_vllm_fa2_C PRIVATE
# csrc/flash_attn
# csrc/flash_attn/src)
endif ()


#
# _moe_C extension
#
@@ -192,6 +280,7 @@ if (ENABLE_MOE_KERNEL)
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
endif()
6 changes: 6 additions & 0 deletions cmake/toolchain.cmake
@@ -0,0 +1,6 @@
# use this file to set the compiler and flags for SYCL

set(CMPLR_ROOT $ENV{CMPLR_ROOT})
message(STATUS "CMPLR_ROOT: ${CMPLR_ROOT}")
set(CMAKE_CXX_COMPILER ${CMPLR_ROOT}/bin/icpx)
set(CMAKE_C_COMPILER ${CMPLR_ROOT}/bin/icx)
3 changes: 2 additions & 1 deletion csrc/core/registration.h
@@ -22,6 +22,7 @@
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
STRINGIFY(NAME), nullptr, 0, nullptr}; \
STRINGIFY(NAME), nullptr, 0, nullptr, \
nullptr, nullptr, nullptr, nullptr}; \
return PyModule_Create(&module); \
}
82 changes: 82 additions & 0 deletions csrc/flash_attn/flash_api.cpp
@@ -0,0 +1,82 @@
#include "pytorch_shim.h"

#include "core/registration.h"
#include "xpu/cutlass_kernels/chunk_prefill.hpp"
#include <torch/all.h>

namespace FLASH_NAMESPACE {

std::vector<at::Tensor> mha_varlen_fwd(
const at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor& k, // total_k x num_heads_k x head_size, total_k :=
// \sum_{i=0}^{b} s_i or num_blocks x page_block_size
// x num_heads_k x head_size if there's a block_table.
const at::Tensor& v, // total_k x num_heads_k x head_size, total_k :=
// \sum_{i=0}^{b} s_i or num_blocks x page_block_size
// x num_heads_k x head_size if there's a block_table.
std::optional<at::Tensor>& out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor& cu_seqlens_q, // b+1
const at::Tensor& cu_seqlens_k, // b+1
std::optional<at::Tensor>& seqused_k, // b. If given, only this many elements of each batch
// element's keys are used.
std::optional<const at::Tensor>& leftpad_k_, // batch_size
at::Tensor& block_table_, // batch_size x max_num_blocks_per_seq
std::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
int max_seqlen_k,
float p_dropout,
float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax, std::optional<at::Generator> gen_) {
at::Tensor out;
if(out_.has_value()) {
out = *out_;
}
else {
out = torch::zeros_like(q);
}

cutlass_chunk_prefill_impl(
q,
k,
v,
out,
block_table_,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
is_causal);

if(return_softmax) {
auto softmax_lse = torch::empty_like(out);
return {out, softmax_lse};
}
else {
at::Tensor softmax_lse;
return {out, softmax_lse};
}
}
} // namespace FLASH_NAMESPACE

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"varlen_fwd(Tensor q, Tensor k, Tensor v, Tensor!? out, Tensor "
"cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor "
"block_table, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float "
"softmax_scale, bool zero_tensors, "
"bool is_causal, int window_size_left, int window_size_right, float "
"softcap, bool return_softmax, "
"Generator? gen) -> Tensor[]");
ops.impl("varlen_fwd", torch::kXPU,
make_pytorch_shim(&FLASH_NAMESPACE::mha_varlen_fwd));
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
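
As a point of reference, the packed ("varlen") layout that the shape comments on mha_varlen_fwd describe can be sketched as follows. The sizes are assumed toy values (two sequences of lengths 3 and 5); the snippet only builds host-side tensors to illustrate the layout and does not invoke the XPU kernel.

// Illustrative only: packed varlen layout for b = 2 sequences of lengths {3, 5}.
#include <torch/all.h>

void build_varlen_inputs_example() {
  const int64_t num_heads = 8, num_heads_k = 8, head_size = 64;
  // q: total_q x num_heads x head_size, with total_q = 3 + 5 = 8
  // (all sequences concatenated along dim 0; real inputs would live on XPU).
  auto q = torch::randn({8, num_heads, head_size});
  // k, v: total_k x num_heads_k x head_size in the non-paged case.
  auto k = torch::randn({8, num_heads_k, head_size});
  auto v = torch::randn({8, num_heads_k, head_size});
  // cu_seqlens_q / cu_seqlens_k: b + 1 cumulative sequence lengths, {0, 3, 8}.
  auto cu_seqlens_q = torch::tensor({0, 3, 8}, torch::kInt32);
  auto cu_seqlens_k = torch::tensor({0, 3, 8}, torch::kInt32);
  // max_seqlen_q = max_seqlen_k = 5 for these lengths.
}
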
110 changes: 110 additions & 0 deletions csrc/flash_attn/pytorch_shim.h
@@ -0,0 +1,110 @@
#pragma once

#include <torch/library.h>

/**
* Unfortunately, the type signatures of the flash_attn ops are not compatible
* with the PyTorch library bindings. To get around that we use
* `make_pytorch_shim` which creates a lambda that exposes the API using
* PyTorch-compatible types, then converts them to the types
* expected by the flash_attn ops. This shim allows us to make minimal changes
* to `flash_api.cpp`, making it easier to synchronize with upstream changes.
*
* The `pytorch_library_compatible_type` struct is used to map from the
* flash_attn ops types to a PyTorch library compatible one. The main issue is
* that the following types are not supported by PyTorch library bindings:
* - `int`
* - `float`
* - `std::optional<T> &`
* - `std::optional<const at::Tensor> &`
* So we convert them to (respectively):
* - `int64_t`
* - `double`
* - `const std::optional<T>&`
* - `const std::optional<at::Tensor>&`
*/

template <typename T>
struct pytorch_library_compatible_type {
using type = T;
static T convert_from_type(T arg) { return arg; }
};

template <typename T>
using pytorch_library_compatible_type_t =
typename pytorch_library_compatible_type<T>::type;

template <typename T>
T convert_from_pytorch_compatible_type(
pytorch_library_compatible_type_t<T> arg) {
return pytorch_library_compatible_type<T>::convert_from_type(arg);
}

// Map `std::optional<T> &` -> `const std::optional<T>&`
// (NOTE: this is a bit unsafe, but none of the ops in flash_attn mutate
// the optional container)
template <typename T>
struct pytorch_library_compatible_type<std::optional<T>&> {
using type = const std::optional<T>&;
static std::optional<T>& convert_from_type(const std::optional<T>& arg) {
return const_cast<std::optional<T>&>(arg);
}
};

// Map `std::optional<T>` ->
// `std::optional<pytorch_library_compatible_type_t<T>>`
// (NOTE: tested for `std::optional<int>` -> `std::optional<int64_t>`)
template <typename T>
struct pytorch_library_compatible_type<std::optional<T>> {
using type = std::optional<pytorch_library_compatible_type_t<T>>;
static std::optional<pytorch_library_compatible_type_t<T>> convert_from_type(
std::optional<T> arg) {
return arg;
}
};

// Map `std::optional<const at::Tensor>&` -> `const std::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<std::optional<const at::Tensor>&> {
using type = const std::optional<at::Tensor>&;
static std::optional<const at::Tensor>& convert_from_type(
const std::optional<at::Tensor>& arg) {
return const_cast<std::optional<const at::Tensor>&>(
reinterpret_cast<const std::optional<const at::Tensor>&>(arg));
}
};

// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(),
"int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(),
"int64_t value is too small to be converted to int");
return arg;
}
};

// Map `float` -> `double`
template <>
struct pytorch_library_compatible_type<float> {
using type = double;
static float convert_from_type(double arg) {
TORCH_CHECK(std::abs(arg) <= std::numeric_limits<float>::max(),
"double value is too large to be converted to float");
return arg;
}
};

//
// Shim Utils
//

template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
return [fun](pytorch_library_compatible_type_t<Args>... args) {
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
};
}
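
To show how the shim is intended to be used, here is a minimal sketch with a hypothetical toy op (toy_scale is not part of this PR). Its C++ signature uses `int`, `float`, and a mutable `std::optional<at::Tensor>&`, and the registration mirrors the varlen_fwd binding in flash_api.cpp.

// Hypothetical example: registering a toy op with flash_attn-style types.
#include <torch/all.h>
#include <torch/library.h>
#include "pytorch_shim.h"

at::Tensor toy_scale(const at::Tensor& q, std::optional<at::Tensor>& out_,
                     int repeat, float scale) {
  at::Tensor out = out_.has_value() ? *out_ : torch::zeros_like(q);
  for (int i = 0; i < repeat; ++i) {
    out = out + q * scale;
  }
  return out;
}

TORCH_LIBRARY(toy_ext, ops) {
  // The schema uses the wider, binding-compatible types (int -> int64_t,
  // float -> double, Tensor!? for the mutable optional); make_pytorch_shim
  // produces the lambda that narrows them back to the C++ types above.
  ops.def("toy_scale(Tensor q, Tensor!? out, int repeat, float scale) -> Tensor");
  ops.impl("toy_scale", torch::kCPU, make_pytorch_shim(&toy_scale));
}
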