5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -265,6 +265,11 @@ void get_cutlass_moe_mm_problem_sizes(
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt);

void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab);

void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
108 changes: 96 additions & 12 deletions csrc/quantization/w8a8/cutlass/moe/moe_data.cu
@@ -3,6 +3,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "dispatch_utils.h"

#include <iostream>

constexpr uint64_t THREADS_PER_EXPERT = 512;
@@ -114,22 +116,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
const bool swap_ab) {
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());

const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();

if (swap_ab) {
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
static_cast<int>(k));
} else {
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
static_cast<int>(k));
}
});
}
} // namespace

@@ -153,6 +150,93 @@ void get_cutlass_moe_mm_problem_sizes_caller(
may_swap_ab);
}

template <bool SWAP_AB>
__global__ void compute_problem_sizes_from_expert_offsets(
const int64_t* __restrict__ expert_first_token_offset,
int32_t* __restrict__ problem_sizes1, int32_t* __restrict__ problem_sizes2,
const int num_experts, const int n, const int k) {
int const expert_id = blockIdx.x * blockDim.x + threadIdx.x;
if (expert_id >= num_experts) {
return;
}

int64_t const m64 = expert_first_token_offset[expert_id + 1] -
expert_first_token_offset[expert_id];
int32_t const m = static_cast<int32_t>(m64);

int32_t* ps1 = problem_sizes1 + expert_id * 3;
int32_t* ps2 = problem_sizes2 + expert_id * 3;

if constexpr (!SWAP_AB) {
// [M, 2*N, K]
ps1[0] = m;
ps1[1] = 2 * n;
ps1[2] = k;
// [M, K, N]
ps2[0] = m;
ps2[1] = k;
ps2[2] = n;
} else {
// swap logical M/N in the problem shape
// [2*N, M, K]
ps1[0] = 2 * n;
ps1[1] = m;
ps1[2] = k;
// [K, M, N]
ps2[0] = k;
ps2[1] = m;
ps2[2] = n;
}
}

void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab) {
TORCH_CHECK(expert_first_token_offset.is_cuda(),
"expert_first_token_offset must be a CUDA tensor");
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
"expert_first_token_offset must be int64");

TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
"problem_sizes must be CUDA tensors");
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
problem_sizes2.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
"problem_sizes must be contiguous");
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
"problem_sizes must be 2D tensors");
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
"problem_sizes second dim must be 3");
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
"problem_sizes1 and problem_sizes2 must have same shape");

int64_t const num_experts64 = problem_sizes1.size(0);
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
"expert_first_token_offset must have num_experts + 1 elements");
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");

int const num_experts = static_cast<int>(num_experts64);
auto stream = at::cuda::getCurrentCUDAStream(
expert_first_token_offset.device().index());

int const threads = (num_experts < 256) ? num_experts : 256;
int const blocks = (num_experts + threads - 1) / threads;

auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();

VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
compute_problem_sizes_from_expert_offsets<SwapAB>
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
num_experts, static_cast<int>(n),
static_cast<int>(k));
});
}

void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
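For readers tracing the new kernel, here is a minimal PyTorch sketch of what compute_problem_sizes_from_expert_offsets writes for each expert; the helper name and the example offsets are illustrative, not part of this PR.

```python
import torch

def reference_problem_sizes(expert_first_token_offset: torch.Tensor,
                            n: int, k: int, swap_ab: bool):
    # Per-expert token count: m_e = offset[e + 1] - offset[e].
    m = (expert_first_token_offset[1:] - expert_first_token_offset[:-1]).to(torch.int32)
    num_experts = m.numel()
    ps1 = torch.empty((num_experts, 3), dtype=torch.int32)
    ps2 = torch.empty((num_experts, 3), dtype=torch.int32)
    if not swap_ab:
        ps1[:, 0], ps1[:, 1], ps1[:, 2] = m, 2 * n, k  # [M, 2N, K]
        ps2[:, 0], ps2[:, 1], ps2[:, 2] = m, k, n      # [M, K, N]
    else:
        ps1[:, 0], ps1[:, 1], ps1[:, 2] = 2 * n, m, k  # [2N, M, K]
        ps2[:, 0], ps2[:, 1], ps2[:, 2] = k, m, n      # [K, M, N]
    return ps1, ps2

# Three experts receiving 4, 0, and 2 permuted tokens, with n=128, k=256.
offsets = torch.tensor([0, 4, 4, 6], dtype=torch.int64)
ps1, ps2 = reference_problem_sizes(offsets, n=128, k=256, swap_ab=False)
print(ps1.tolist())  # [[4, 256, 256], [0, 256, 256], [2, 256, 256]]
print(ps2.tolist())  # [[4, 256, 128], [0, 256, 128], [2, 256, 128]]
```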
24 changes: 24 additions & 0 deletions csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
@@ -83,6 +83,11 @@ void get_cutlass_moe_mm_problem_sizes_caller(
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt);

void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab);

void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
@@ -322,6 +327,25 @@ void get_cutlass_moe_mm_problem_sizes(
version_num, ". Required capability: 90, 100, or 120");
}

void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab) {
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
"no cutlass_scaled_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}

void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
11 changes: 11 additions & 0 deletions csrc/torch_bindings.cpp
@@ -487,6 +487,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
&get_cutlass_moe_mm_problem_sizes);

// compute per-expert problem sizes from expert_first_token_offset
// produced by vLLM's moe_permute kernel
ops.def(
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
" Tensor expert_first_token_offset, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int n, int k, bool swap_ab) -> ()");
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets", torch::kCUDA,
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);

// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// as an input, and computes expert_offsets (token start indices of each
19 changes: 19 additions & 0 deletions vllm/_custom_ops.py
@@ -1075,6 +1075,25 @@ def get_cutlass_moe_mm_problem_sizes(
)


def get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
expert_first_token_offset: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
n: int,
k: int,
swap_ab: bool,
):
"""Compute per-expert (M, N, K) problem sizes from expert_first_token_offset"""
return torch.ops._C.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
expert_first_token_offset,
problem_sizes1,
problem_sizes2,
n,
k,
swap_ab,
)


def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
"""
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
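A hedged usage sketch for the new Python wrapper; the sizes and the offsets tensor are made up, and in the real path expert_first_token_offset comes from the moe_permute kernel.

```python
import torch
from vllm import _custom_ops as ops

local_E, N, K = 8, 2048, 4096  # illustrative sizes
device = torch.device("cuda")

# Output buffers must be contiguous (num_local_experts, 3) int32 tensors.
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)

# Made-up offsets: local_E + 1 monotonically increasing int64 entries,
# here giving every expert exactly 16 permuted tokens.
expert_first_token_offset = 16 * torch.arange(local_E + 1, dtype=torch.int64, device=device)

ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
    expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, False
)
# Every row of problem_sizes1 is now [16, 2 * N, K] and every row of
# problem_sizes2 is [16, K, N]; with swap_ab=True the first two entries
# of each row are exchanged.
```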
68 changes: 17 additions & 51 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -108,15 +108,7 @@ def run_cutlass_moe_fp8(
assert global_num_experts != -1
assert a1q_scale is not None

if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
)
else:
local_topk_ids = topk_ids

Comment on lines -111 to -115

Member: Please verify correctness of removing the expert_map logic here. I assume this works because moe_permute already handles the mapping, but I'm not sure. I think you should test accuracy with EP and EPLB to properly exercise this case.

Member (author): Tested with EPLB; results are also in the PR description.

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|   |0.9462|±  |0.0062|
|     |       |strict-match    |     5|exact_match|   |0.9454|±  |0.0063|


topk = local_topk_ids.size(1)
topk = topk_ids.size(1)
local_E = w1.size(0)

if use_batched_format:
@@ -164,12 +156,8 @@ def run_cutlass_moe_fp8(
# during offset calculations
expert_offsets = expert_offsets.to(torch.int64)
else:
problem_sizes1 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)

num_expert = global_num_experts if expert_map is None else expert_map.size(0)
# permuted a1q reuses workspace2
@@ -182,11 +170,12 @@ def run_cutlass_moe_fp8(
expert_map,
permuted_hidden_states=a1q_perm,
)
expert_offsets = expert_first_token_offset[:-1]

ops.get_cutlass_moe_mm_problem_sizes(
local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K
# swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding).
swap_ab = a1q.size(0) <= 64
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, swap_ab
)
expert_offsets = expert_first_token_offset[:-1]

Comment on lines +175 to +176

Collaborator: Could we use it for cutlass moe fp4 too?

Member (author): That seems out of this PR's scope. I can test it, and if it can be used, I will open a follow-up PR.

if not per_act_token and (expert_map is not None or use_batched_format):
# this is necessary to avoid imprecise scale calculation caused by
@@ -240,9 +229,7 @@ def run_cutlass_moe_fp8(
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm,
expert_first_token_offset=(
expert_first_token_offset if expert_map is not None else None
),
expert_first_token_offset=expert_first_token_offset,
)


@@ -772,15 +759,7 @@ def run_cutlass_moe_w4a8_fp8(
f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}"
)

# Translate info from expert_map to topk_ids
if expert_map is not None:
local_topk_ids = torch.where(
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
)
else:
local_topk_ids = topk_ids

topk = local_topk_ids.size(1)
topk = topk_ids.size(1)
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K))
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
act_out = _resize_cache(workspace2, (M * topk, N))
@@ -790,12 +769,8 @@
)
mm2_out = _resize_cache(workspace2, (M * topk, K))

problem_sizes1 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)

num_expert = global_num_experts if expert_map is None else expert_map.size(0)
# permuted a1q reuses workspace2
@@ -808,18 +783,11 @@
expert_map,
permuted_hidden_states=a1q_perm,
)
expert_offsets = expert_first_token_offset[:-1]

# For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
ops.get_cutlass_moe_mm_problem_sizes(
local_topk_ids,
problem_sizes1,
problem_sizes2,
global_num_experts,
N,
K,
force_swap_ab=True,
# for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape).
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, True
)
expert_offsets = expert_first_token_offset[:-1]

ops.cutlass_w4a8_moe_mm(
mm1_out,
@@ -866,9 +834,7 @@ def run_cutlass_moe_w4a8_fp8(
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm,
expert_first_token_offset=(
expert_first_token_offset if expert_map is not None else None
),
expert_first_token_offset=expert_first_token_offset,
)


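On the Python side, the fp8 path now derives swap_ab from the incoming token count, while the w4a8 RS-GEMM path always enables it. A tiny sketch of that decision; the helper name and tensor shapes are illustrative, not part of this PR.

```python
import torch

def pick_swap_ab_fp8(a1q: torch.Tensor) -> bool:
    # Heuristic used by run_cutlass_moe_fp8: with at most 64 input tokens,
    # swapping logical M and N in each grouped-GEMM problem reduces padding.
    return a1q.size(0) <= 64

a1q_small = torch.empty((48, 4096))   # 48 tokens
a1q_large = torch.empty((512, 4096))  # 512 tokens
assert pick_swap_ab_fp8(a1q_small) and not pick_swap_ab_fp8(a1q_large)

# run_cutlass_moe_w4a8_fp8 skips the heuristic and always passes swap_ab=True,
# since SwapAB is always enabled for the RS grouped GEMM.
```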