diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 55e659679701..46e05f18c8f0 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -15,15 +15,13 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x, const scalar_t& y) { return act_first ? ACT_FN(x) * y : x * ACT_FN(y); } -// Activation and gating kernel template. template -__global__ void act_and_mul_kernel( +__device__ void _act_and_mul_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] - const int d) { - const int64_t token_idx = blockIdx.x; + const int d, const int64_t token_idx) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); @@ -31,6 +29,18 @@ __global__ void act_and_mul_kernel( } } +// Activation and gating kernel template. + +template +__global__ void act_and_mul_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { + const int64_t token_idx = blockIdx.x; + _act_and_mul_kernel(out, input, d, token_idx); +} + template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) @@ -223,3 +233,72 @@ void gelu_quick(torch::Tensor& out, // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel); } + +namespace vllm { +// Batched act_and_mul kernel template +template +__global__ void batched_act_and_mul_kernel( + scalar_t* out, // [B, max_tokens, d] + const scalar_t* input, // [B, max_tokens, 2, d] + const int32_t* valid_tokens_array, // [B] + const int d, const int max_num_tokens) { + const int64_t batch_idx = blockIdx.x; + const int64_t num_tokens = valid_tokens_array[batch_idx]; + if (num_tokens == 0) { + return; + } + + int const col_offset = blockIdx.y * blockDim.x; + scalar_t* __restrict__ b_out = + &out[batch_idx * max_num_tokens * d + col_offset]; + const scalar_t* __restrict__ b_in = + &input[batch_idx * max_num_tokens * d * 2 + col_offset]; + + int token_idx = 0; + const int tidx = threadIdx.x; + while (token_idx < num_tokens) { + if (col_offset + tidx < d) { + const scalar_t x = VLLM_LDG(&b_in[tidx]); + const scalar_t y = VLLM_LDG(&b_in[tidx + d]); + b_out[tidx] = compute(x, y); + } + + b_out += d; + b_in += (2 * d); + + ++token_idx; + } +} +} // namespace vllm + +// Launch batched activation and gating kernel. +// Use ACT_FIRST (bool) indicating whether to apply the activation function +// first. 
+#define LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \ + int64_t const batch_size = input.size(0); \ + int64_t const max_num_tokens = input.size(1); \ + int const d = input.size(2) / 2; \ + int const block_size = std::min(d, 1024); \ + int const blocks_per_row = ((d - 1) / block_size) + 1; \ + dim3 grid(batch_size, blocks_per_row); \ + dim3 block(block_size); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "batched_act_and_mul_kernel", [&] { \ + vllm::batched_act_and_mul_kernel, \ + ACT_FIRST> \ + <<>>( \ + out.data_ptr(), input.data_ptr(), \ + valid_tokens_array.data_ptr(), d, max_num_tokens); \ + }); + +void batched_silu_and_mul(torch::Tensor& out, // [B, max_tokens, d] + torch::Tensor& input, // [B, max_tokens, 2, d] + torch::Tensor& valid_tokens_array) // [B] +{ + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + TORCH_CHECK(valid_tokens_array.dtype() == torch::kInt32); + LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); +} \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index f02f5083ac19..5710384f1613 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -130,6 +130,9 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); +void batched_silu_and_mul(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& valid_tokens_array); + void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 1a1896b4c1ee..daf962e44d33 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -111,6 +111,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); + ops.def( + "batched_silu_and_mul(Tensor! result, Tensor input, Tensor " + "valid_tokens_array) -> ()"); + ops.impl("batched_silu_and_mul", torch::kCUDA, &batched_silu_and_mul); + ops.def("mul_and_silu(Tensor! 
out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index 29c5e70a8ba8..33c479cf13ab 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -14,6 +14,21 @@ SiluAndMul) from vllm.platforms import current_platform + +def ref_batched_silu_mul(x, out, valid_tokens_array): + """ + Reference implementation of batched silu_and_mul + """ + valid_tokens_array = valid_tokens_array.to("cpu") + batch_size = x.size(0) + for b in range(batch_size): + # num valid tokens + n = valid_tokens_array[b] + if n == 0: + continue + torch.ops._C.silu_and_mul(out[b, :n, :], x[b, :n, :]) + + DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 13824] # Arbitrary values for testing @@ -106,3 +121,42 @@ def test_activation( out = torch.empty_like(x) opcheck(fn, (out, x)) + + +## Test Batched Implementaion #### + +BATCH_SIZES = [1, 13, 26, 32] +NUM_TOKENS = [7, 37, 64, 4096] +D = [128, 256, 384, 512, 1024, 1536, 13824] + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +def test_batched_silu_mul(batch_size: int, num_tokens: int, d: int, + dtype: torch.dtype): + + input = torch.randn( + (batch_size, num_tokens, d), device="cuda", dtype=dtype) / 10.0 + + out = torch.empty((batch_size, num_tokens, d // 2), + device="cuda", + dtype=dtype) + + ref_out = out.clone() + + # valid num_tokens per batch + valid_num_tokens = torch.randint(low=0, + high=num_tokens + 1, + size=(batch_size, ), + device="cuda").to(dtype=torch.int32) + + # reference + ref_batched_silu_mul(input, ref_out, valid_num_tokens) + + # impl + torch.ops._C.batched_silu_and_mul(out, input, valid_num_tokens) + + torch.testing.assert_close(ref_out, out) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 2d7cf39a8cca..037815ec15ee 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -15,7 +15,8 @@ from typing_extensions import ParamSpec from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -169,15 +170,16 @@ def make(config: TestConfig, rank) -> "TestTensors": block_k = block_size[1] _, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k) - topk_ids = torch.randint( - low=0, - high=config.num_experts, - size=(m, topk), - device=torch.cuda.current_device()).to(dtype=torch.int64) + score = torch.randn((m, config.num_experts), + device="cuda", + dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(rank_tokens, score, topk, False) - topk_weights = torch.randn(topk_ids.shape, - dtype=torch.float32, - device=torch.cuda.current_device()) + # overwrite topk_ids to distribute evenly. 
+ topk_ids = torch.empty((m, topk), device="cpu", dtype=torch.int64) + for mi in range(m): + topk_ids[mi] = torch.randperm(config.num_experts)[:topk] + topk_ids = topk_ids.to(device=torch.cuda.current_device()) return TestTensors(rank_tokens=rank_tokens, rank_token_scales=rank_token_scales, @@ -459,6 +461,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, w2, w1_scale, w2_scale) +TOPKS = [2, 6] MNKs = [ (1, 128, 2560), (2, 128, 2560), @@ -467,9 +470,16 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, (45, 512, 2560), (64, 1024, 2560), (222, 1024, 2560), + (45, 128, 2560), + (64, 128, 2560), + (222, 128, 2560), + (45, 2048, 2560), + (64, 2048, 2560), + (222, 2048, 2560), + (333, 2048, 2560), + (444, 2048, 2560), ] -# Fix tests for USE_FP8_DISPATCH=True -USE_FP8_DISPATCH = [False] +USE_FP8_DISPATCH = [False, True] @pytest.mark.parametrize("mnk", MNKs) diff --git a/tests/kernels/moe/test_masked_kernels.py b/tests/kernels/moe/test_masked_kernels.py new file mode 100644 index 000000000000..4dbff8f185a8 --- /dev/null +++ b/tests/kernels/moe/test_masked_kernels.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test for masked utility kernels. +""" + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.masked_kernels import ( + invoke_masked_silu_and_mul, masked_per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.platforms import current_platform + + +def ref_silu_mul(x, out, valid_tokens_array): + + valid_tokens_array = valid_tokens_array.to("cpu") + batch_size = x.size(0) + for b in range(batch_size): + # num valid tokens + n = valid_tokens_array[b] + if n == 0: + continue + torch.ops._C.silu_and_mul(out[b, :n, :], x[b, :n, :]) + + +def ref_per_token_group_quant( + x: torch.Tensor, x_q: torch.Tensor, valid_tokens_array: torch.Tensor, + group_size: int, + column_major_scales: bool) -> tuple[torch.Tensor, torch.Tensor]: + assert x.shape == x_q.shape + + # make scales tensor + B, NUM_TOKENS, HIDDEN_SIZE = x.shape + x_q_s = torch.empty((B, NUM_TOKENS, HIDDEN_SIZE // group_size), + device="cuda", + dtype=torch.float32) + + valid_tokens_array = valid_tokens_array.to("cpu") + batch_size = x.size(0) + for b in range(batch_size): + # num valid tokens + n = valid_tokens_array[b] + if n == 0: + continue + x_slice = x[b, :n, :] + xq_slice, xqs_slice = per_token_group_quant_fp8( + x_slice, group_size, column_major_scales=column_major_scales) + x_q[b, :n, :].copy_(xq_slice) + x_q_s[b, :n, :].copy_(xqs_slice) + + return x_q, x_q_s + + +BATCH_SIZES = [1, 13, 26, 32] +NUM_TOKENS = [7, 37, 64, 4096] + +## Tests for masked per_token_group_quant_fp8 #### + +HIDDEN_SIZES = [128, 256, 384, 512, 1024] + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("column_major_scales", [True]) +def test_masked_per_token_group_quant_fp8(batch_size: int, num_tokens: int, + hidden_size: int, dtype: torch.dtype, + column_major_scales: bool): + + DEEPGEMM_BLOCK_SIZE = 128 + + input = torch.randn( + (batch_size, num_tokens, hidden_size), device="cuda", + dtype=dtype) / 10.0 + + out_q = torch.randn((batch_size, num_tokens, hidden_size), device="cuda") + out_q = 
out_q.to(dtype=current_platform.fp8_dtype()) + + ref_out_q = torch.empty_like(out_q) + ref_out_q.copy_(out_q) + + # valid num_tokens per batch + valid_num_tokens = torch.randint(low=0, + high=num_tokens + 1, + size=(batch_size, ), + device="cuda").to(torch.int32) + + # Reference + ref_out_q, ref_out_scales = ref_per_token_group_quant( + x=input, + x_q=ref_out_q, + valid_tokens_array=valid_num_tokens, + group_size=DEEPGEMM_BLOCK_SIZE, + column_major_scales=column_major_scales) + + # Impl + out_q, out_scales = masked_per_token_group_quant_fp8( + x=input, + x_q=out_q, + valid_tokens_array=valid_num_tokens, + group_size=DEEPGEMM_BLOCK_SIZE, + column_major_scales=column_major_scales) + + torch.testing.assert_close(ref_out_q, out_q) + + valid_num_tokens_cpu = valid_num_tokens.to(device="cpu") + for b in range(valid_num_tokens_cpu.size(0)): + n = valid_num_tokens_cpu[b] + torch.testing.assert_close(ref_out_scales[b, :n, :], + out_scales[b, :n, :]) + + +## Tests for masked silu_and_mul #### + +HIDDEN_SIZES = [124, 1024, 2176, 2816] + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +def test_masked_silu_mul(batch_size: int, num_tokens: int, hidden_size: int, + dtype: torch.dtype): + + input = torch.randn( + (batch_size, num_tokens, hidden_size), device="cuda", + dtype=dtype) / 10.0 + + out = torch.empty((batch_size, num_tokens, hidden_size // 2), + device="cuda", + dtype=dtype) + + ref_out = torch.empty_like(out) + ref_out.copy_(out) + + # valid num_tokens per batch + valid_num_tokens = torch.randint(low=0, + high=num_tokens + 1, + size=(batch_size, ), + device="cuda").to(torch.int32) + + # reference + ref_silu_mul(input, ref_out, valid_num_tokens) + + # impl + invoke_masked_silu_and_mul(out, input, valid_num_tokens) + + torch.testing.assert_close(ref_out, out) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 5492399efdf8..93a211cad416 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -2,18 +2,181 @@ import importlib.util from typing import Optional +import triton +import triton.language as tl import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.masked_kernels import ( + masked_per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache logger = init_logger(__name__) has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + + +@triton.jit +def _per_token_group_quant_fp8_3d( + # Pointers ------------------------------------------------------------ + y_ptr, # FP16 activations (E, T, H) + y_q_ptr, # FP8 quantized activations (E, T, H) + + y_s_ptr, # FP32 scales (E, T, G) + counts_ptr, # INT32 number of tokens per expert (E) + + # Sizes --------------------------------------------------------------- + E: tl.constexpr, # num_experts + T: tl.constexpr, # max_num_tokens + H: tl.constexpr, # hidden dimension + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + + # Strides for y (elements) ------------------------------------------- + stride_y_e, + stride_y_t, + 
stride_y_h, + + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + + + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + + # Stride for counts (elements) + stride_counts_e, + + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, +): + """Dynamic FP8 quantisation over a 3‑D tensor laid out **(E, T, H)**. + + * Each program instance handles **one** `GROUP_SIZE`‑length slice along H + for a single (expert *e*, token *t*). + * Scales are produced with shape **(E, T, G)** where + `G = H // GROUP_SIZE` and with *element* strides + `(T*G, 1, T)` so that the *token* dimension is the fastest‑varying in + memory – matching the downstream reshape you showed. + * All strides are expressed **in elements**, not bytes. + """ + + G = H // GROUP_SIZE # groups per hidden dim + + # ----------------------- map program id -> (e, g) -------------------- + pid = tl.program_id(0) + e = pid // G + g = pid % G + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int32) + + # block for H dimension + cols = tl.arange(0, BLOCK) + mask_h = cols < BLOCK + + # iterate over tokens for this (expert, group) + t = tl.zeros([], tl.int32) + while t < n_tokens: + base_y_offset = e * stride_y_e + t * stride_y_t + g * GROUP_SIZE * stride_y_h + base_yq_offset = e * stride_yq_e + t * stride_yq_t + g * GROUP_SIZE * stride_yq_h + base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g + + mask = mask_h + y = tl.load(y_ptr + base_y_offset + cols * stride_y_h, + mask=mask, other=0.0).to(tl.float32) + + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, + y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset, y_s) + + t += 1 + + +def quant_fp8_3d( + y: torch.Tensor, # (E, T, H) + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + group_size: int = 128, + fp8_dtype = torch.float8_e4m3fn, + eps: float = 1e-6, +): + """Quantize y into FP8 with per‑(expert, token, group) scales. + + Only the first `tokens_per_expert[e]` tokens are quantized per expert; + the remaining positions in each (E, T, H) slice are treated as padding. + + Returns `(y_q, y_s)` where + * `y_q` is the FP8 tensor, same shape and **standard PyTorch order** as *y*. + * `y_s` has shape `(E, T, H // group_size)` and element strides + `(T * G, 1, T)` so that the *token* dimension is contiguous. 
+ """ + + assert y.ndim == 3, "y must be (E, T, H)" + E, T, H = y.shape + G = H // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ + "tokens_per_expert must be shape (E,)" + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) + + # ---------------- allocate outputs ---------------------------------- + y_q = torch.empty_like(y, dtype=fp8_dtype) + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + + # allocate scale buffer with proper shape and stride + y_s = torch.empty_strided((E, T, G), (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, device=y.device) + + # ---------------- stride bookkeeping (elements, not bytes) ---------- + stride_y_e, stride_y_t, stride_y_h = y.stride() + + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + + # stride for tokens_per_expert (elements) + stride_cnt_e = tokens_per_expert.stride()[0] + + # static grid over experts and H-groups; tokens loop is internal to the kernel + grid = (E * G,) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = -f_info.max + + _per_token_group_quant_fp8_3d[grid]( + y, y_q, y_s, tokens_per_expert, + E, T, H, group_size, + stride_y_e, stride_y_t, stride_y_h, + stride_yq_e, stride_yq_t, stride_yq_h, + stride_ys_e, stride_ys_t, stride_ys_g, + stride_cnt_e, + eps, fp8_min, fp8_max, + BLOCK=group_size, + num_warps=4, + ) + + return y_q, y_s + class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # The Deep Gemm kernels only support block size of 128 @@ -86,6 +249,8 @@ def apply( ): import deep_gemm as dg assert hidden_states.ndim == 3 + assert(w1_zp is None and w2_zp is None) + assert(a2_scale is None) a1q = hidden_states _, N, K = w1.size() @@ -109,19 +274,12 @@ def apply( masked_m=expert_num_tokens, expected_m=expected_m) - # TODO (varun) [Optimization]: Use a batched version of activation. - # Similarly for the quant below. 
- self.activation(activation, workspace2, workspace1.view(-1, N)) - - w2_hidden_size = workspace2.size(-1) - workspace2 = workspace2.view(-1, w2_hidden_size) + self.masked_activation(activation, workspace2, workspace1, + expert_num_tokens) - a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = per_token_group_quant_fp8(workspace2, - self.block_shape[1], - column_major_scales=False) - a2q = a2q.view(E, max_num_tokens, -1) - a2q_scale = a2q_scale.view(E, max_num_tokens, -1) + a2q, a2q_scale = quant_fp8_3d(workspace2, + tokens_per_expert=expert_num_tokens, + group_size=self.block_shape[1]) dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), (w2, w2_scale), diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index a12cfafd42ab..64cf1f2b221e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -734,10 +734,8 @@ def apply( config=config, block_shape=self.block_shape) - # TODO: would be nice to use expert_num_tokens here to reduce - # garbage compute - self.activation(activation, intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) + self.masked_activation(activation, intermediate_cache2, + intermediate_cache1, expert_num_tokens) ic2_hidden_size = intermediate_cache2.size(-1) intermediate_cache2 = intermediate_cache2.view(-1, ic2_hidden_size) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1fd8f2175886..98733f101acb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -45,7 +45,8 @@ from .pplx_prepare_finalize import PplxPrepareAndFinalize if has_deepep: from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize + from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE, + DeepEPLLPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -377,6 +378,12 @@ def init_prepare_finalize(self, moe: MoEConfig, all2all_manager.world_size) handle = all2all_manager.get_handle(all_to_all_args) + # Note : We may want to use FP8 dispatch even otherwise just to + # reduce datamovement + use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype() + and act_quant_block_size[1] + == DEEPEP_QUANT_BLOCK_SIZE) + # Note (varun): Whether to use FP8 dispatch or not needs some # profiling. Turning it off for now. prepare_finalize = DeepEPLLPrepareAndFinalize( @@ -386,7 +393,7 @@ def init_prepare_finalize(self, moe: MoEConfig, max_tokens_per_rank=moe.max_num_tokens, quant_dtype=quant_dtype, block_shape=act_quant_block_size, - use_fp8_dispatch=False, + use_fp8_dispatch=use_fp8_dispatch, ) self.topk_indices_dtype = None diff --git a/vllm/model_executor/layers/fused_moe/masked_kernels.py b/vllm/model_executor/layers/fused_moe/masked_kernels.py new file mode 100644 index 000000000000..68f28b11fa1f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/masked_kernels.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Masked kernels used in the fused_moe operation. In the batched versions +of ModularKernel.FusedMoEPermuteExpertsUnpermute, where batch_size +is the number-of-experts, only some tokens in each batch are valid. 
+The kernels in this file, account for that and only operate on the +valid tokens. +""" + +from typing import Optional + +import torch + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _do_per_token_group_quant_fp8, _do_per_token_group_quant_fp8_colmajor) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton + +## Masked Per Token Quant #### + + +@triton.jit +def _masked_per_token_group_quant_fp8( + valid_tokens_array, + # Batch dimension strides + stride_yb, + stride_yqb, + stride_ysb, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr): + + batch_id = tl.program_id(axis=0) + num_tokens = tl.load(valid_tokens_array + batch_id) + if num_tokens == 0: + # early exit + return + + groups_per_row = y_num_columns // group_size + valid_num_groups = num_tokens * groups_per_row + group_id = tl.program_id(axis=1) + if group_id >= valid_num_groups: + # early exit + return + + y_ptr = y_ptr + batch_id * stride_yb + y_q_ptr = y_q_ptr + batch_id * stride_yqb + y_s_ptr = y_s_ptr + batch_id * stride_ysb + + _do_per_token_group_quant_fp8( + group_id, # group id + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK) + + +@triton.jit +def _masked_per_token_group_quant_fp8_colmajor( + valid_tokens_array, + # Batch strides + stride_yb, + stride_yqb, + stride_ysb, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + batch_id = tl.program_id(axis=0) + num_tokens = tl.load(valid_tokens_array + batch_id) + if num_tokens == 0: + # early exit + return + + group_id = tl.program_id(axis=1) + groups_per_row = y_num_columns // group_size + valid_num_groups = num_tokens * groups_per_row + if group_id >= valid_num_groups: + # early exit + return + + y_ptr = y_ptr + batch_id * stride_yb + y_q_ptr = y_q_ptr + batch_id * stride_yqb + y_s_ptr = y_s_ptr + batch_id * stride_ysb + + _do_per_token_group_quant_fp8_colmajor( + group_id, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK) + + +def masked_per_token_group_quant_fp8( + x: torch.Tensor, # [B, MAX_TOKENS, HIDDEN_SIZE] + valid_tokens_array: torch.Tensor, # [B] + group_size: int, + column_major_scales: bool, + x_q: Optional[torch.Tensor] = None, # [B, MAX_TOKENS, HIDDEN_SIZE] + eps: float = 1e-10 +) -> tuple[torch.Tensor, torch.Tensor]: + + assert x.ndim == 3 + assert (x.size(-1) % group_size == 0), ( + f"the last dimension of `x` {x.size(-1)} must be divisible " + f"by `group_size` {group_size}") + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + dtype = current_platform.fp8_dtype() + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + assert x_q is None or x_q.shape == 
x.shape + if x_q is None: + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + + B, MAX_TOKENS, HIDDEN_SIZE = x.shape + shape = (B, MAX_TOKENS, HIDDEN_SIZE // group_size) + if column_major_scales: + cms_shape = (shape[0], shape[2], shape[1]) + x_s = torch.empty(cms_shape, device=x.device, dtype=torch.float32) + x_s = x_s.permute(0, 2, 1) + else: + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + M = (MAX_TOKENS * HIDDEN_SIZE) // group_size + N = group_size + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + grid = (B, M) + + if column_major_scales: + _masked_per_token_group_quant_fp8_colmajor[grid]( + valid_tokens_array, + x.stride(0), + x_q.stride(0), + x_s.stride(0), + x, + x_q, + x_s, + group_size, + x.size(2), # num_columns + x.stride(1), # row_stride + x_s.stride(2), # col_stride + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _masked_per_token_group_quant_fp8[grid]( + valid_tokens_array, + x.stride(0), + x_q.stride(0), + x_s.stride(0), + x, + x_q, + x_s, + group_size, + x.size(2), # num_columns + x.stride(1), # row_stride + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +## Batched Silu and Mul Kernel #### + + +@triton.jit +def silu(x_tile): + return x_tile * (1.0 / (1.0 + tl.exp(-x_tile))) + + +@triton.jit +def silu_and_mul( + pid_d, + output, # [M, D] + input, # [M, D * 2] + stride_om, + stride_im, + M, + D, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, + compute_type: tl.constexpr): + + remaining_d = D - (pid_d * BLOCK_D) + + offs_m = tl.arange(0, BLOCK_M)[:, None] + mask_m = offs_m < M + + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < remaining_d + + input_ptrs = input + offs_m * stride_im + pid_d * BLOCK_D + offs_d + output_ptrs = output + offs_m * stride_om + pid_d * BLOCK_D + offs_d + + mask_tile = mask_m & mask_d + x_tile = tl.load(input_ptrs, mask=mask_tile, + other=0.0).to(dtype=tl.float32) + + y_tile = tl.load(input_ptrs + D, mask=mask_tile, other=0.0) + + # silu and mul + out_tile = silu(x_tile).to(dtype=compute_type) + out_tile = out_tile * y_tile + + tl.store(output_ptrs, out_tile, mask=mask_tile) + + +@triton.jit +def masked_silu_and_mul_kernel( + output, # [B, MAX_NUM_TOKENS, D] + input, # [B, MAX_NUM_TOKENS, D * 2] + valid_tokens_array, # [B] + stride_oe, + stride_om, + stride_ie, + stride_im, + compute_type: tl.constexpr, + D, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr): + + batch_id = tl.program_id(axis=0) + num_tokens = tl.load(valid_tokens_array + batch_id) + if num_tokens == 0: + # early exit + return + + pid_m = tl.program_id(axis=1) + cta_m_start = pid_m * BLOCK_M + if cta_m_start >= num_tokens: + # early exit + return + + cta_input_ptr = input + batch_id * stride_ie + cta_m_start * stride_im + cta_output_ptr = output + batch_id * stride_oe + cta_m_start * stride_om + + cta_m_size = min(BLOCK_M, num_tokens - cta_m_start) + + pid_d = tl.program_id(axis=2) + silu_and_mul( + pid_d, + cta_output_ptr, + cta_input_ptr, + stride_om, + stride_im, + cta_m_size, # M + D, + BLOCK_M, + BLOCK_D, + compute_type) + + +def invoke_masked_silu_and_mul( + output: torch.Tensor, #[B, MAX_TOKENS, D] + input: torch.Tensor, #[B, MAX_TOKENS, D * 2] + valid_tokens_array: torch.Tensor): + assert input.ndim == 3 + batch_size, max_num_tokens, D = output.size() + + BLOCK_D = 1024 + BLOCK_M = 1 + + 
compute_tl_dtype = { + torch.float16: tl.float16, + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16 + }[output.dtype] + + grid = (batch_size, triton.cdiv(max_num_tokens, + BLOCK_M), triton.cdiv(D, BLOCK_D)) + masked_silu_and_mul_kernel[grid](output, input, valid_tokens_array, + output.stride(0), output.stride(1), + input.stride(0), input.stride(1), + compute_tl_dtype, D, BLOCK_M, BLOCK_D) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index ed3b6b8a1af4..300f56deeb5a 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -215,6 +215,27 @@ def workspace_shapes( """ raise NotImplementedError + def masked_activation(self, activation: str, output: torch.Tensor, + input: torch.Tensor, + expert_num_tokens: torch.Tensor) -> None: + """ + Given inputs and outputs of shape + [num_experts, max_tokens, hidden_size], and expert_num_tokens/mask + of shape [E], perform act_and_mul only on the inputs that are + actually valid. + Note that expert_num_tokens[i] is the number of tokens that are + actually valid for expert i. + """ + assert output.ndim == 3 and input.ndim == 3 + assert output.size(-1) * 2 == input.size(-1) + E, _, _ = input.shape + assert expert_num_tokens.size(0) == E, ( + f"expert_num_tokens.size(0)({expert_num_tokens.size(0)}) != E({E})" + ) + if activation != "silu": + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + torch.ops._C.batched_silu_and_mul(output, input, expert_num_tokens) + def activation(self, activation: str, output: torch.Tensor, input: torch.Tensor) -> None: assert output.size(-1) * 2 == input.size(-1) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 754650ebeffb..798562ba0741 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -245,7 +245,8 @@ def block_quant_to_tensor_quant( @triton.jit -def _per_token_group_quant_fp8( +def _do_per_token_group_quant_fp8( + g_id, # group id # Pointers to inputs and output y_ptr, y_q_ptr, @@ -268,8 +269,7 @@ def _per_token_group_quant_fp8( """ groups_per_row = y_num_columns // group_size - # Map the program id to the row of X and Y it should compute. - g_id = tl.program_id(0) + # Map the group ID to the row of X and Y it should compute. 
row = g_id // groups_per_row row_g_id = g_id % groups_per_row @@ -296,7 +296,47 @@ def _per_token_group_quant_fp8( @triton.jit -def _per_token_group_quant_fp8_colmajor( +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr): + + # group ID + g_id = tl.program_id(axis=0) + _do_per_token_group_quant_fp8( + g_id, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK) + + +@triton.jit +def _do_per_token_group_quant_fp8_colmajor( + g_id, # group_id # Pointers to inputs and output y_ptr, y_q_ptr, @@ -321,8 +361,7 @@ def _per_token_group_quant_fp8_colmajor( """ groups_per_row = y_num_columns // group_size - # Map the program id to the row of X and Y it should compute. - g_id = tl.program_id(0) + # Map the group id to the row of X and Y it should compute. row = g_id // groups_per_row row_g_id = g_id % groups_per_row @@ -357,6 +396,48 @@ def _per_token_group_quant_fp8_colmajor( tl.store(y_s_ptr, y_s) +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr): + + g_id = tl.program_id(axis=0) + _do_per_token_group_quant_fp8_colmajor( + g_id, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK) + + def per_token_group_quant_fp8( x: torch.Tensor, group_size: int,
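
A minimal usage sketch for the new batched activation op, assuming a vLLM build that includes this patch and a CUDA device. It mirrors the added unit test (test_batched_silu_mul) rather than defining any new behavior: it calls torch.ops._C.batched_silu_and_mul once over all batches and checks it against the existing per-batch silu_and_mul applied only to the valid token prefix. The helper name demo_batched_silu_and_mul and the default shapes are illustrative, not part of the patch.

import torch


def demo_batched_silu_and_mul(batch_size: int = 4,
                              max_tokens: int = 64,
                              d: int = 512,
                              dtype: torch.dtype = torch.bfloat16) -> None:
    # input packs two halves along the last dim: the first d columns go through
    # SiLU, the last d columns are the multiplier.
    # shapes: input [B, max_tokens, 2*d] -> out [B, max_tokens, d]
    x = torch.randn(
        (batch_size, max_tokens, 2 * d), device="cuda", dtype=dtype) / 10.0
    out = torch.empty((batch_size, max_tokens, d), device="cuda", dtype=dtype)
    ref = out.clone()  # same padding values as out, so padded rows compare equal

    # number of valid tokens per batch entry; rows past this count are
    # left untouched by the batched kernel
    valid = torch.randint(low=0,
                          high=max_tokens + 1,
                          size=(batch_size, ),
                          device="cuda").to(torch.int32)

    # batched kernel: one launch over all batches, masked by valid token counts
    torch.ops._C.batched_silu_and_mul(out, x, valid)

    # reference: existing 2-D silu_and_mul applied per batch on the valid prefix
    for b, n in enumerate(valid.tolist()):
        if n:
            torch.ops._C.silu_and_mul(ref[b, :n, :], x[b, :n, :])

    torch.testing.assert_close(ref, out)


if __name__ == "__main__":
    demo_batched_silu_and_mul()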