Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library")

# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "9baca2cff3a28590fcd03e55515e2d91ff2cbc8b" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "2eeb05da5b801b34114b6b394dcef836fc9a7cc9" CACHE STRING "CUTLASS revision to use")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit is from a new branch, vllm_xpu_cutlass, of https://github.com/intel/cutlass-sycl.


# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
FetchContent_Declare(
Expand Down
11 changes: 4 additions & 7 deletions csrc/xpu/cutlass_kernels/grouped_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,18 @@ namespace gpu::cutlass_kernel {

namespace grouped_gemm {
void kernel_functor(sycl::queue& stream, void* ptr_A, void* ptr_B, void* ptr_D,
void* ptr_alpha, void* ptr_beta, void* offset, int64_t N,
int64_t K, int64_t groups);
void* offset, int32_t N, int32_t K, int32_t groups);
}

/* gemm2(group_A, w2, output, offset) */

at::Tensor grouped_gemm_func(at::Tensor& ptr_A, at::Tensor& ptr_B,
at::Tensor& ptr_D, at::Tensor& ptr_alpha,
at::Tensor& ptr_beta, at::Tensor& offset,
at::Tensor& ptr_D, at::Tensor& tokens_per_expert,
int64_t N, int64_t K, int64_t groups) {
auto& dpcpp_queue = vllm::xpu::vllmGetQueue();
grouped_gemm::kernel_functor(dpcpp_queue, ptr_A.data_ptr(), ptr_B.data_ptr(),
ptr_D.data_ptr(), ptr_alpha.data_ptr(),
ptr_beta.data_ptr(), offset.data_ptr(), N, K,
groups);
ptr_D.data_ptr(), tokens_per_expert.data_ptr(), (int32_t)N, (int32_t)K,
(int32_t)groups);
return ptr_D;
}

Expand Down
Loading