From 4d518d150c701dbe91d27f84dcd249f9c6ff9da7 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 13 Jun 2025 19:20:14 -0700 Subject: [PATCH 01/26] add batched silu mul Signed-off-by: Varun Sundar Rabindranath --- .../layers/fused_moe/batched_utils.py | 120 ++++++++++++++++++ .../layers/fused_moe/modular_kernel.py | 19 +++ 2 files changed, 139 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/batched_utils.py diff --git a/vllm/model_executor/layers/fused_moe/batched_utils.py b/vllm/model_executor/layers/fused_moe/batched_utils.py new file mode 100644 index 000000000000..ccf204965b88 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/batched_utils.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Batched utility kernels used in the fused_moe operation. +""" + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def silu(x_tile): + return x_tile * (1.0 / (1.0 + tl.exp(-x_tile))) + + +@triton.jit +def silu_and_mul( + pid_d, + output, # [M, D] + input, # [M, D * 2] + stride_om, + stride_im, + M, + D, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, + compute_type: tl.constexpr): + + offs_m = tl.arange(0, BLOCK_M)[:, None] + mask_m = offs_m < M + + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < D + + input_ptrs = input + offs_m * stride_im + pid_d * BLOCK_D + offs_d + output_ptrs = output + offs_m * stride_om + pid_d * BLOCK_D + offs_d + + mask_tile = mask_m & mask_d + + x_tile = tl.load(input_ptrs, mask=mask_tile, + other=0.0).to(dtype=tl.float32) + y_tile = tl.load(input_ptrs + D, mask=mask_tile, other=0.0) + + # silu and mul + out_tile = silu(x_tile).to(dtype=compute_type) + out_tile = out_tile * y_tile + + tl.store(output_ptrs, out_tile, mask=mask_tile) + + +@triton.jit +def batched_silu_and_mul_kernel( + output, # [E, MAX_NUM_TOKENS, D] + input, # [E, MAX_NUM_TOKENS, D * 2] + expert_num_tokens, # [E] + stride_oe, + stride_om, + stride_ie, + stride_im, + compute_type: tl.constexpr, + D, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr): + + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # early exit + return + + pid_m = tl.program_id(axis=1) + cta_m_start = pid_m * BLOCK_M + if cta_m_start >= e_num_tokens: + # early exit + return + + cta_input_ptr = input + expert_id * stride_ie + cta_m_start * stride_im + cta_output_ptr = output + expert_id * stride_oe + cta_m_start * stride_om + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + + pid_d = tl.program_id(axis=2) + silu_and_mul( + pid_d, + cta_output_ptr, + cta_input_ptr, + stride_om, + stride_im, + cta_m_size, # M + D, + BLOCK_M, + BLOCK_D, + compute_type) + + +def invoke_batched_silu_and_mul( + output: torch.Tensor, #[E, MAX_TOKENS, D] + input: torch.Tensor, #[E, MAX_TOKENS, D * 2] + expert_num_tokens: torch.Tensor): + + num_experts = output.size(0) + max_num_tokens = output.size(1) + D = output.size(2) + + BLOCK_D = 1024 + BLOCK_M = 1 + + compute_tl_dtype = { + torch.float16: tl.float16, + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16 + }[output.dtype] + + grid = (num_experts, triton.cdiv(max_num_tokens, + BLOCK_M), triton.cdiv(D, BLOCK_D)) + batched_silu_and_mul_kernel[grid](output, input, expert_num_tokens, + output.stride(0), output.stride(1), + input.stride(0), input.stride(1), + compute_tl_dtype, D, BLOCK_M, BLOCK_D) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py 
b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 9409b59982d9..0a7437ffaf54 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -215,6 +215,25 @@ def workspace_shapes( """ raise NotImplementedError + def masked_activation(self, activation: str, output: torch.Tensor, + input: torch.Tensor, + expert_num_tokens: torch.Tensor) -> None: + """ + Given inputs and outputs of shape + [num_experts, max_tokens, hidden_size], and expert_num_tokens/mask + of shape [E], perform act_and_mul only on the inputs that are + actually valid. + Note that expert_num_tokens[i] is the number of tokens that are + actually valid for expert i. + """ + assert output.shape == input.shape + E, _, _ = input.shape + assert expert_num_tokens.size(0) == E, ( + f"expert_num_tokens.size(0)({expert_num_tokens.size(0)}) != E({E})" + ) + assert activation == "silu", "Only silu_and_mul is supported for now." + pass + def activation(self, activation: str, output: torch.Tensor, input: torch.Tensor) -> None: assert output.size(-1) * 2 == input.size(-1) From ea96dddf699566059fbfd56e504739e7ef0adbf7 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 13 Jun 2025 19:34:01 -0700 Subject: [PATCH 02/26] Refactor per_token_group_quant Signed-off-by: Varun Sundar Rabindranath --- .../layers/fused_moe/batched_utils.py | 4 + .../layers/quantization/utils/fp8_utils.py | 91 +++++++++++++++++-- 2 files changed, 89 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_utils.py b/vllm/model_executor/layers/fused_moe/batched_utils.py index ccf204965b88..e1e774143607 100644 --- a/vllm/model_executor/layers/fused_moe/batched_utils.py +++ b/vllm/model_executor/layers/fused_moe/batched_utils.py @@ -8,6 +8,10 @@ from vllm.triton_utils import tl, triton +## Batched Per Token Quant kernel #### + +## Batched Silu and Mul Kernel #### + @triton.jit def silu(x_tile): diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 754650ebeffb..27b065911c08 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -245,7 +245,8 @@ def block_quant_to_tensor_quant( @triton.jit -def _per_token_group_quant_fp8( +def _do_per_token_group_quant_fp8( + g_id, # group id # Pointers to inputs and output y_ptr, y_q_ptr, @@ -268,8 +269,7 @@ def _per_token_group_quant_fp8( """ groups_per_row = y_num_columns // group_size - # Map the program id to the row of X and Y it should compute. - g_id = tl.program_id(0) + # Map the group ID to the row of X and Y it should compute. 
row = g_id // groups_per_row row_g_id = g_id % groups_per_row @@ -295,8 +295,47 @@ def _per_token_group_quant_fp8( tl.store(y_s_ptr, y_s) +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr): + + # group ID + g_id = tl.program_id(axis=0) + _do_per_token_group_quant_fp8( + g_id, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK) + + @triton.jit -def _per_token_group_quant_fp8_colmajor( +def _do_per_token_group_quant_fp8_colmajor( + g_id, # group_id # Pointers to inputs and output y_ptr, y_q_ptr, @@ -321,8 +360,7 @@ def _per_token_group_quant_fp8_colmajor( """ groups_per_row = y_num_columns // group_size - # Map the program id to the row of X and Y it should compute. - g_id = tl.program_id(0) + # Map the group id to the row of X and Y it should compute. row = g_id // groups_per_row row_g_id = g_id % groups_per_row @@ -357,6 +395,47 @@ def _per_token_group_quant_fp8_colmajor( tl.store(y_s_ptr, y_s) +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr): + + g_id = tl.program_id(axis=0) + _do_per_token_group_quant_fp8_colmajor( + g_id, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK) + + def per_token_group_quant_fp8( x: torch.Tensor, group_size: int, From abcf846607548d0efdbdf29c3927203602cbfb14 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 13 Jun 2025 21:07:13 -0700 Subject: [PATCH 03/26] add batched per token quant Signed-off-by: Varun Sundar Rabindranath --- .../layers/fused_moe/batched_utils.py | 208 +++++++++++++++++- 1 file changed, 207 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_utils.py b/vllm/model_executor/layers/fused_moe/batched_utils.py index e1e774143607..3d23135698a9 100644 --- a/vllm/model_executor/layers/fused_moe/batched_utils.py +++ b/vllm/model_executor/layers/fused_moe/batched_utils.py @@ -4,11 +4,217 @@ Batched utility kernels used in the fused_moe operation. 
""" +from typing import Optional + import torch +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _do_per_token_group_quant_fp8, _do_per_token_group_quant_fp8_colmajor) +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -## Batched Per Token Quant kernel #### +## Batched Per Token Quant #### + + +@triton.jit +def _batched_per_token_group_quant_fp8( + expert_num_tokens, + stride_ye, + stride_yqe, + stride_yse, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr): + + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # early exit + return + + groups_per_row = y_num_columns // group_size + valid_groups_in_experts = e_num_tokens * groups_per_row + group_id = tl.program_id(axis=1) + if group_id >= valid_groups_in_experts: + # early exit + return + + y_ptr = y_ptr + expert_id * stride_ye + y_q_ptr = y_q_ptr + expert_id * stride_yqe + y_s_ptr = y_s_ptr + expert_id * stride_yse + + _do_per_token_group_quant_fp8( + group_id, # group id + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK) + + +@triton.jit +def _batched_per_token_group_quant_fp8_colmajor( + expert_num_tokens, + stride_ye, + stride_yqe, + stride_yse, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # early exit + return + + groups_per_row = y_num_columns // group_size + valid_groups_in_experts = e_num_tokens * groups_per_row + group_id = tl.program_id(axis=1) + if group_id >= valid_groups_in_experts: + # early exit + return + + y_ptr = y_ptr + expert_id * stride_ye + y_q_ptr = y_q_ptr + expert_id * stride_yqe + y_s_ptr = y_s_ptr + expert_id * stride_yse + + _do_per_token_group_quant_fp8_colmajor( + group_id, + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK) + + +def batched_per_token_group_quant_fp8( + x: torch.Tensor, + x_q: Optional[torch.Tensor], + expert_num_tokens: torch.Tensor, + group_size: int, + column_major_scales: bool, + eps: float = 1e-10) -> tuple[torch.Tensor, torch.Tensor]: + + assert (x.size(-1) % group_size == 0), ( + f"the last dimension of `x` {x.size(-1)} must be divisible " + f"by `group_size` {group_size}") + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + dtype = current_platform.fp8_dtype() + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + assert x_q is None or x_q.shape == x.shape + if x_q is None: + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + + E, MAX_TOKENS, HIDDEN_SIZE = x.shape + shape = (E, MAX_TOKENS, 
HIDDEN_SIZE // group_size) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + if column_major_scales: + x_s = x_s.permute(-1, -2) + + M = (MAX_TOKENS * HIDDEN_SIZE) // group_size + N = group_size + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + grid = (E, M) + + if column_major_scales: + _batched_per_token_group_quant_fp8_colmajor[grid]( + expert_num_tokens, + x.stride(0), + x_q.stride(0), + x_s.stride(0), + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _batched_per_token_group_quant_fp8[grid]( + expert_num_tokens, + x.stride(0), + x_q.stride(0), + x_s.stride(0), + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + ## Batched Silu and Mul Kernel #### From 50162ac83dcd700cabb785b6b57473a92581f7f3 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 13 Jun 2025 21:36:26 -0700 Subject: [PATCH 04/26] batched -> masked Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_batched_utils.py | 10 ++ .../{batched_utils.py => masked_kernels.py} | 108 +++++++++--------- 2 files changed, 66 insertions(+), 52 deletions(-) create mode 100644 tests/kernels/moe/test_batched_utils.py rename vllm/model_executor/layers/fused_moe/{batched_utils.py => masked_kernels.py} (72%) diff --git a/tests/kernels/moe/test_batched_utils.py b/tests/kernels/moe/test_batched_utils.py new file mode 100644 index 000000000000..3c3328ab14e2 --- /dev/null +++ b/tests/kernels/moe/test_batched_utils.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test for batched utility kernels. +""" + + +## Tests for batched silu_and_mul #### +def test_batched_silu_mul(): + pass diff --git a/vllm/model_executor/layers/fused_moe/batched_utils.py b/vllm/model_executor/layers/fused_moe/masked_kernels.py similarity index 72% rename from vllm/model_executor/layers/fused_moe/batched_utils.py rename to vllm/model_executor/layers/fused_moe/masked_kernels.py index 3d23135698a9..a2678878bd5f 100644 --- a/vllm/model_executor/layers/fused_moe/batched_utils.py +++ b/vllm/model_executor/layers/fused_moe/masked_kernels.py @@ -1,7 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Batched utility kernels used in the fused_moe operation. +Masked kernels used in the fused_moe operation. In the batched versions +of ModularKernel.FusedMoEPermuteExpertsUnpermute, where batch_size +is the number-of-experts, only some tokens in each batch are valid. +The kernels in this file, account for that and only operate on the +valid tokens. 
""" from typing import Optional @@ -18,10 +22,11 @@ @triton.jit def _batched_per_token_group_quant_fp8( - expert_num_tokens, - stride_ye, - stride_yqe, - stride_yse, + valid_tokens_array, + # Batch dimension strides + stride_yb, + stride_yqb, + stride_ysb, # Pointers to inputs and output y_ptr, y_q_ptr, @@ -38,22 +43,22 @@ def _batched_per_token_group_quant_fp8( # Meta-parameters BLOCK: tl.constexpr): - expert_id = tl.program_id(axis=0) - e_num_tokens = tl.load(expert_num_tokens + expert_id) - if e_num_tokens == 0: + batch_id = tl.program_id(axis=0) + num_tokens = tl.load(valid_tokens_array + batch_id) + if num_tokens == 0: # early exit return groups_per_row = y_num_columns // group_size - valid_groups_in_experts = e_num_tokens * groups_per_row + valid_num_groups = num_tokens * groups_per_row group_id = tl.program_id(axis=1) - if group_id >= valid_groups_in_experts: + if group_id >= valid_num_groups: # early exit return - y_ptr = y_ptr + expert_id * stride_ye - y_q_ptr = y_q_ptr + expert_id * stride_yqe - y_s_ptr = y_s_ptr + expert_id * stride_yse + y_ptr = y_ptr + batch_id * stride_yb + y_q_ptr = y_q_ptr + batch_id * stride_yqb + y_s_ptr = y_s_ptr + batch_id * stride_ysb _do_per_token_group_quant_fp8( group_id, # group id @@ -76,10 +81,11 @@ def _batched_per_token_group_quant_fp8( @triton.jit def _batched_per_token_group_quant_fp8_colmajor( - expert_num_tokens, - stride_ye, - stride_yqe, - stride_yse, + valid_tokens_array, + # Batch strides + stride_yb, + stride_yqb, + stride_ysb, # Pointers to inputs and output y_ptr, y_q_ptr, @@ -98,22 +104,22 @@ def _batched_per_token_group_quant_fp8_colmajor( # Meta-parameters BLOCK: tl.constexpr, ): - expert_id = tl.program_id(axis=0) - e_num_tokens = tl.load(expert_num_tokens + expert_id) - if e_num_tokens == 0: + batch_id = tl.program_id(axis=0) + num_tokens = tl.load(valid_tokens_array + batch_id) + if num_tokens == 0: # early exit return groups_per_row = y_num_columns // group_size - valid_groups_in_experts = e_num_tokens * groups_per_row + valid_num_groups = num_tokens * groups_per_row group_id = tl.program_id(axis=1) - if group_id >= valid_groups_in_experts: + if group_id >= valid_num_groups: # early exit return - y_ptr = y_ptr + expert_id * stride_ye - y_q_ptr = y_q_ptr + expert_id * stride_yqe - y_s_ptr = y_s_ptr + expert_id * stride_yse + y_ptr = y_ptr + batch_id * stride_yb + y_q_ptr = y_q_ptr + batch_id * stride_yqb + y_s_ptr = y_s_ptr + batch_id * stride_ysb _do_per_token_group_quant_fp8_colmajor( group_id, @@ -137,9 +143,9 @@ def _batched_per_token_group_quant_fp8_colmajor( def batched_per_token_group_quant_fp8( - x: torch.Tensor, - x_q: Optional[torch.Tensor], - expert_num_tokens: torch.Tensor, + x: torch.Tensor, # [B, MAX_TOKENS, HIDDEN_SIZE] + x_q: Optional[torch.Tensor], # [B, MAX_TOKENS, HIDDEN_SIZE] + valid_tokens_array: torch.Tensor, # [B] group_size: int, column_major_scales: bool, eps: float = 1e-10) -> tuple[torch.Tensor, torch.Tensor]: @@ -158,8 +164,8 @@ def batched_per_token_group_quant_fp8( if x_q is None: x_q = torch.empty_like(x, device=x.device, dtype=dtype) - E, MAX_TOKENS, HIDDEN_SIZE = x.shape - shape = (E, MAX_TOKENS, HIDDEN_SIZE // group_size) + B, MAX_TOKENS, HIDDEN_SIZE = x.shape + shape = (B, MAX_TOKENS, HIDDEN_SIZE // group_size) x_s = torch.empty(shape, device=x.device, dtype=torch.float32) if column_major_scales: x_s = x_s.permute(-1, -2) @@ -171,11 +177,11 @@ def batched_per_token_group_quant_fp8( num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 - grid = (E, M) + grid = (B, M) if column_major_scales: 
_batched_per_token_group_quant_fp8_colmajor[grid]( - expert_num_tokens, + valid_tokens_array, x.stride(0), x_q.stride(0), x_s.stride(0), @@ -195,7 +201,7 @@ def batched_per_token_group_quant_fp8( ) else: _batched_per_token_group_quant_fp8[grid]( - expert_num_tokens, + valid_tokens_array, x.stride(0), x_q.stride(0), x_s.stride(0), @@ -261,9 +267,9 @@ def silu_and_mul( @triton.jit def batched_silu_and_mul_kernel( - output, # [E, MAX_NUM_TOKENS, D] - input, # [E, MAX_NUM_TOKENS, D * 2] - expert_num_tokens, # [E] + output, # [B, MAX_NUM_TOKENS, D] + input, # [B, MAX_NUM_TOKENS, D * 2] + valid_tokens_array, # [B] stride_oe, stride_om, stride_ie, @@ -273,22 +279,22 @@ def batched_silu_and_mul_kernel( BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr): - expert_id = tl.program_id(axis=0) - e_num_tokens = tl.load(expert_num_tokens + expert_id) - if e_num_tokens == 0: + batch_id = tl.program_id(axis=0) + num_tokens = tl.load(valid_tokens_array + batch_id) + if num_tokens == 0: # early exit return pid_m = tl.program_id(axis=1) cta_m_start = pid_m * BLOCK_M - if cta_m_start >= e_num_tokens: + if cta_m_start >= num_tokens: # early exit return - cta_input_ptr = input + expert_id * stride_ie + cta_m_start * stride_im - cta_output_ptr = output + expert_id * stride_oe + cta_m_start * stride_om + cta_input_ptr = input + batch_id * stride_ie + cta_m_start * stride_im + cta_output_ptr = output + batch_id * stride_oe + cta_m_start * stride_om - cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + cta_m_size = min(BLOCK_M, num_tokens - cta_m_start) pid_d = tl.program_id(axis=2) silu_and_mul( @@ -305,13 +311,11 @@ def batched_silu_and_mul_kernel( def invoke_batched_silu_and_mul( - output: torch.Tensor, #[E, MAX_TOKENS, D] - input: torch.Tensor, #[E, MAX_TOKENS, D * 2] - expert_num_tokens: torch.Tensor): + output: torch.Tensor, #[B, MAX_TOKENS, D] + input: torch.Tensor, #[B, MAX_TOKENS, D * 2] + valid_tokens_array: torch.Tensor): - num_experts = output.size(0) - max_num_tokens = output.size(1) - D = output.size(2) + batch_size, max_num_tokens, D = output.size() BLOCK_D = 1024 BLOCK_M = 1 @@ -322,9 +326,9 @@ def invoke_batched_silu_and_mul( torch.bfloat16: tl.bfloat16 }[output.dtype] - grid = (num_experts, triton.cdiv(max_num_tokens, - BLOCK_M), triton.cdiv(D, BLOCK_D)) - batched_silu_and_mul_kernel[grid](output, input, expert_num_tokens, + grid = (batch_size, triton.cdiv(max_num_tokens, + BLOCK_M), triton.cdiv(D, BLOCK_D)) + batched_silu_and_mul_kernel[grid](output, input, valid_tokens_array, output.stride(0), output.stride(1), input.stride(0), input.stride(1), compute_tl_dtype, D, BLOCK_M, BLOCK_D) From 9f13fb0cfa8662f056e561e243880a3be867c786 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 13 Jun 2025 21:37:32 -0700 Subject: [PATCH 05/26] batched_utils -> masked_kernels Signed-off-by: Varun Sundar Rabindranath --- .../kernels/moe/{test_batched_utils.py => test_masked_kernels.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/kernels/moe/{test_batched_utils.py => test_masked_kernels.py} (100%) diff --git a/tests/kernels/moe/test_batched_utils.py b/tests/kernels/moe/test_masked_kernels.py similarity index 100% rename from tests/kernels/moe/test_batched_utils.py rename to tests/kernels/moe/test_masked_kernels.py From b82cbe5f12726cb4a15dd5cb8ffe5c0e4dd9a652 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 13 Jun 2025 21:42:43 -0700 Subject: [PATCH 06/26] batched -> masked Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_masked_kernels.py | 
6 ++--- .../layers/fused_moe/masked_kernels.py | 24 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/kernels/moe/test_masked_kernels.py b/tests/kernels/moe/test_masked_kernels.py index 3c3328ab14e2..c3209e726bd8 100644 --- a/tests/kernels/moe/test_masked_kernels.py +++ b/tests/kernels/moe/test_masked_kernels.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Test for batched utility kernels. +Test for masked utility kernels. """ -## Tests for batched silu_and_mul #### -def test_batched_silu_mul(): +## Tests for masked silu_and_mul #### +def test_masked_silu_mul(): pass diff --git a/vllm/model_executor/layers/fused_moe/masked_kernels.py b/vllm/model_executor/layers/fused_moe/masked_kernels.py index a2678878bd5f..e4b703b76037 100644 --- a/vllm/model_executor/layers/fused_moe/masked_kernels.py +++ b/vllm/model_executor/layers/fused_moe/masked_kernels.py @@ -17,11 +17,11 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -## Batched Per Token Quant #### +## Masked Per Token Quant #### @triton.jit -def _batched_per_token_group_quant_fp8( +def _masked_per_token_group_quant_fp8( valid_tokens_array, # Batch dimension strides stride_yb, @@ -80,7 +80,7 @@ def _batched_per_token_group_quant_fp8( @triton.jit -def _batched_per_token_group_quant_fp8_colmajor( +def _masked_per_token_group_quant_fp8_colmajor( valid_tokens_array, # Batch strides stride_yb, @@ -142,7 +142,7 @@ def _batched_per_token_group_quant_fp8_colmajor( BLOCK) -def batched_per_token_group_quant_fp8( +def masked_per_token_group_quant_fp8( x: torch.Tensor, # [B, MAX_TOKENS, HIDDEN_SIZE] x_q: Optional[torch.Tensor], # [B, MAX_TOKENS, HIDDEN_SIZE] valid_tokens_array: torch.Tensor, # [B] @@ -180,7 +180,7 @@ def batched_per_token_group_quant_fp8( grid = (B, M) if column_major_scales: - _batched_per_token_group_quant_fp8_colmajor[grid]( + _masked_per_token_group_quant_fp8_colmajor[grid]( valid_tokens_array, x.stride(0), x_q.stride(0), @@ -200,7 +200,7 @@ def batched_per_token_group_quant_fp8( num_stages=num_stages, ) else: - _batched_per_token_group_quant_fp8[grid]( + _masked_per_token_group_quant_fp8[grid]( valid_tokens_array, x.stride(0), x_q.stride(0), @@ -266,7 +266,7 @@ def silu_and_mul( @triton.jit -def batched_silu_and_mul_kernel( +def masked_silu_and_mul_kernel( output, # [B, MAX_NUM_TOKENS, D] input, # [B, MAX_NUM_TOKENS, D * 2] valid_tokens_array, # [B] @@ -310,7 +310,7 @@ def batched_silu_and_mul_kernel( compute_type) -def invoke_batched_silu_and_mul( +def invoke_masked_silu_and_mul( output: torch.Tensor, #[B, MAX_TOKENS, D] input: torch.Tensor, #[B, MAX_TOKENS, D * 2] valid_tokens_array: torch.Tensor): @@ -328,7 +328,7 @@ def invoke_batched_silu_and_mul( grid = (batch_size, triton.cdiv(max_num_tokens, BLOCK_M), triton.cdiv(D, BLOCK_D)) - batched_silu_and_mul_kernel[grid](output, input, valid_tokens_array, - output.stride(0), output.stride(1), - input.stride(0), input.stride(1), - compute_tl_dtype, D, BLOCK_M, BLOCK_D) + masked_silu_and_mul_kernel[grid](output, input, valid_tokens_array, + output.stride(0), output.stride(1), + input.stride(0), input.stride(1), + compute_tl_dtype, D, BLOCK_M, BLOCK_D) From b2b365dc1b259a39c83f575d9a863d57cf0862e1 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 13 Jun 2025 22:17:09 -0700 Subject: [PATCH 07/26] add masked silu-mul test Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_masked_kernels.py | 50 
+++++++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_masked_kernels.py b/tests/kernels/moe/test_masked_kernels.py index c3209e726bd8..9dea5bf4cdfd 100644 --- a/tests/kernels/moe/test_masked_kernels.py +++ b/tests/kernels/moe/test_masked_kernels.py @@ -4,7 +4,53 @@ Test for masked utility kernels. """ +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.masked_kernels import ( + invoke_masked_silu_and_mul) + + +def ref_silu_mul(x, out, valid_tokens_array): + + valid_tokens_array = valid_tokens_array.to("cpu") + batch_size = x.size(0) + for b in range(batch_size): + # num valid tokens + n = valid_tokens_array[b] + torch.ops._C.silu_and_mul(out[b, :n, :], x[b, :n, :]) + ## Tests for masked silu_and_mul #### -def test_masked_silu_mul(): - pass +@pytest.mark.parametrize("batch_size", [1, 13, 26, 32, 64]) +@pytest.mark.parametrize("num_tokens", [32, 64]) +@pytest.mark.parametrize("hidden_size", [512]) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +def test_masked_silu_mul(batch_size: int, num_tokens: int, hidden_size: int, + dtype: torch.dtype): + + input = torch.randn((batch_size, num_tokens, hidden_size), + device="cuda", + dtype=dtype) + + out = torch.empty((batch_size, num_tokens, hidden_size // 2), + device="cuda", + dtype=dtype) + + ref_out = torch.empty_like(out) + ref_out.copy_(out) + + # valid num_tokens per batch + valid_num_tokens = torch.randint(low=0, + high=batch_size + 1, + size=(batch_size, ), + device="cuda") + + # reference + ref_silu_mul(input, ref_out, valid_num_tokens) + + # impl + invoke_masked_silu_and_mul(out, input, valid_num_tokens) + + torch.testing.assert_close(ref_out, out, atol=1e-3, rtol=1e-2) From 57dc316bb2e03aadc52ee75c9aac2df37d55e5f2 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 14 Jun 2025 14:44:31 -0700 Subject: [PATCH 08/26] fixes and add batched per-token-quant tests Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_masked_kernels.py | 111 ++++++++++++++++-- .../layers/fused_moe/masked_kernels.py | 22 ++-- .../layers/quantization/utils/fp8_utils.py | 2 + 3 files changed, 118 insertions(+), 17 deletions(-) diff --git a/tests/kernels/moe/test_masked_kernels.py b/tests/kernels/moe/test_masked_kernels.py index 9dea5bf4cdfd..e17a0be70e34 100644 --- a/tests/kernels/moe/test_masked_kernels.py +++ b/tests/kernels/moe/test_masked_kernels.py @@ -8,7 +8,10 @@ import torch from vllm.model_executor.layers.fused_moe.masked_kernels import ( - invoke_masked_silu_and_mul) + invoke_masked_silu_and_mul, masked_per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.platforms import current_platform def ref_silu_mul(x, out, valid_tokens_array): @@ -18,21 +21,113 @@ def ref_silu_mul(x, out, valid_tokens_array): for b in range(batch_size): # num valid tokens n = valid_tokens_array[b] + if n == 0: + continue torch.ops._C.silu_and_mul(out[b, :n, :], x[b, :n, :]) +def ref_per_token_group_quant( + x: torch.Tensor, x_q: torch.Tensor, valid_tokens_array: torch.Tensor, + group_size: int, + column_major_scales: bool) -> tuple[torch.Tensor, torch.Tensor]: + assert x.shape == x_q.shape + + # make scales tensor + B, NUM_TOKENS, HIDDEN_SIZE = x.shape + x_q_s = torch.empty((B, NUM_TOKENS, HIDDEN_SIZE // group_size), + device="cuda", + dtype=torch.float32) + + valid_tokens_array = valid_tokens_array.to("cpu") + batch_size = x.size(0) + for b in 
range(batch_size): + # num valid tokens + n = valid_tokens_array[b] + if n == 0: + continue + x_slice = x[b, :n, :] + xq_slice, xqs_slice = per_token_group_quant_fp8( + x_slice, group_size, column_major_scales=column_major_scales) + x_q[b, :n, :].copy_(xq_slice) + x_q_s[b, :n, :].copy_(xqs_slice) + + return x_q, x_q_s + + +BATCH_SIZES = [1, 13, 26, 32] +NUM_TOKENS = [32, 64, 4096] +HIDDEN_SIZES = [1024] + +## Tests for masked per_token_group_quant_fp8 #### + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("column_major_scales", [True]) +def test_masked_per_token_group_quant_fp8(batch_size: int, num_tokens: int, + hidden_size: int, dtype: torch.dtype, + column_major_scales: bool): + + DEEPGEMM_BLOCK_SIZE = 128 + + input = torch.randn( + (batch_size, num_tokens, hidden_size), device="cuda", + dtype=dtype) / 10.0 + + out_q = torch.randn((batch_size, num_tokens, hidden_size), device="cuda") + out_q = out_q.to(dtype=current_platform.fp8_dtype()) + + ref_out_q = torch.empty_like(out_q) + ref_out_q.copy_(out_q) + + # valid num_tokens per batch + valid_num_tokens = torch.randint(low=0, + high=num_tokens + 1, + size=(batch_size, ), + device="cuda") + + # Reference + ref_out_q, ref_out_scales = ref_per_token_group_quant( + x=input, + x_q=ref_out_q, + valid_tokens_array=valid_num_tokens, + group_size=DEEPGEMM_BLOCK_SIZE, + column_major_scales=column_major_scales) + + # Impl + out_q, out_scales = masked_per_token_group_quant_fp8( + x=input, + x_q=out_q, + valid_tokens_array=valid_num_tokens, + group_size=DEEPGEMM_BLOCK_SIZE, + column_major_scales=column_major_scales) + + torch.testing.assert_close(ref_out_q, out_q) + + valid_num_tokens_cpu = valid_num_tokens.to(device="cpu") + for b in range(valid_num_tokens_cpu.size(0)): + n = valid_num_tokens_cpu[b] + torch.testing.assert_close(ref_out_scales[b, :n, :], + out_scales[b, :n, :]) + + ## Tests for masked silu_and_mul #### -@pytest.mark.parametrize("batch_size", [1, 13, 26, 32, 64]) -@pytest.mark.parametrize("num_tokens", [32, 64]) -@pytest.mark.parametrize("hidden_size", [512]) + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) def test_masked_silu_mul(batch_size: int, num_tokens: int, hidden_size: int, dtype: torch.dtype): - input = torch.randn((batch_size, num_tokens, hidden_size), - device="cuda", - dtype=dtype) + input = torch.randn( + (batch_size, num_tokens, hidden_size), device="cuda", + dtype=dtype) / 10.0 out = torch.empty((batch_size, num_tokens, hidden_size // 2), device="cuda", @@ -43,7 +138,7 @@ def test_masked_silu_mul(batch_size: int, num_tokens: int, hidden_size: int, # valid num_tokens per batch valid_num_tokens = torch.randint(low=0, - high=batch_size + 1, + high=num_tokens + 1, size=(batch_size, ), device="cuda") diff --git a/vllm/model_executor/layers/fused_moe/masked_kernels.py b/vllm/model_executor/layers/fused_moe/masked_kernels.py index e4b703b76037..97ccbec3e483 100644 --- a/vllm/model_executor/layers/fused_moe/masked_kernels.py +++ b/vllm/model_executor/layers/fused_moe/masked_kernels.py @@ -110,9 +110,9 @@ def _masked_per_token_group_quant_fp8_colmajor( # early exit return + group_id = tl.program_id(axis=1) 
groups_per_row = y_num_columns // group_size valid_num_groups = num_tokens * groups_per_row - group_id = tl.program_id(axis=1) if group_id >= valid_num_groups: # early exit return @@ -150,6 +150,7 @@ def masked_per_token_group_quant_fp8( column_major_scales: bool, eps: float = 1e-10) -> tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 3 assert (x.size(-1) % group_size == 0), ( f"the last dimension of `x` {x.size(-1)} must be divisible " f"by `group_size` {group_size}") @@ -166,9 +167,12 @@ def masked_per_token_group_quant_fp8( B, MAX_TOKENS, HIDDEN_SIZE = x.shape shape = (B, MAX_TOKENS, HIDDEN_SIZE // group_size) - x_s = torch.empty(shape, device=x.device, dtype=torch.float32) if column_major_scales: - x_s = x_s.permute(-1, -2) + cms_shape = (shape[0], shape[2], shape[1]) + x_s = torch.empty(cms_shape, device=x.device, dtype=torch.float32) + x_s = x_s.permute(0, 2, 1) + else: + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) M = (MAX_TOKENS * HIDDEN_SIZE) // group_size N = group_size @@ -189,9 +193,9 @@ def masked_per_token_group_quant_fp8( x_q, x_s, group_size, - x.shape[1], - x.stride(0), - x_s.stride(1), + x.size(2), # num_columns + x.stride(1), # row_stride + x_s.stride(2), # col_stride eps, fp8_min=fp8_min, fp8_max=fp8_max, @@ -209,8 +213,8 @@ def masked_per_token_group_quant_fp8( x_q, x_s, group_size, - x.shape[1], - x.stride(0), + x.size(2), # num_columns + x.stride(1), # row_stride eps, fp8_min=fp8_min, fp8_max=fp8_max, @@ -314,7 +318,7 @@ def invoke_masked_silu_and_mul( output: torch.Tensor, #[B, MAX_TOKENS, D] input: torch.Tensor, #[B, MAX_TOKENS, D * 2] valid_tokens_array: torch.Tensor): - + assert input.ndim == 3 batch_size, max_num_tokens, D = output.size() BLOCK_D = 1024 diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 27b065911c08..798562ba0741 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -295,6 +295,7 @@ def _do_per_token_group_quant_fp8( tl.store(y_s_ptr, y_s) +@triton.jit def _per_token_group_quant_fp8( # Pointers to inputs and output y_ptr, @@ -395,6 +396,7 @@ def _do_per_token_group_quant_fp8_colmajor( tl.store(y_s_ptr, y_s) +@triton.jit def _per_token_group_quant_fp8_colmajor( # Pointers to inputs and output y_ptr, From dcace534bb72411c2e0a30b169f0b9cbe423fe24 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 14 Jun 2025 14:48:10 -0700 Subject: [PATCH 09/26] relax silu mul tolerance Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_masked_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_masked_kernels.py b/tests/kernels/moe/test_masked_kernels.py index e17a0be70e34..91079fec46e8 100644 --- a/tests/kernels/moe/test_masked_kernels.py +++ b/tests/kernels/moe/test_masked_kernels.py @@ -148,4 +148,4 @@ def test_masked_silu_mul(batch_size: int, num_tokens: int, hidden_size: int, # impl invoke_masked_silu_and_mul(out, input, valid_num_tokens) - torch.testing.assert_close(ref_out, out, atol=1e-3, rtol=1e-2) + torch.testing.assert_close(ref_out, out) From 7ba83357b4fa92f464edcc682258b31fa066cca1 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 14 Jun 2025 15:11:08 -0700 Subject: [PATCH 10/26] plugin masked kernels Signed-off-by: Varun Sundar Rabindranath --- .../layers/fused_moe/batched_deep_gemm_moe.py | 25 ++++++++----------- .../layers/fused_moe/fused_batched_moe.py | 6 
++--- .../layers/fused_moe/modular_kernel.py | 6 ++++- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 5492399efdf8..aa04b9131e63 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -6,8 +6,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.masked_kernels import ( + masked_per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache logger = init_logger(__name__) @@ -109,19 +110,15 @@ def apply( masked_m=expert_num_tokens, expected_m=expected_m) - # TODO (varun) [Optimization]: Use a batched version of activation. - # Similarly for the quant below. - self.activation(activation, workspace2, workspace1.view(-1, N)) + self.masked_activation(activation, workspace2, workspace1, + expert_num_tokens) - w2_hidden_size = workspace2.size(-1) - workspace2 = workspace2.view(-1, w2_hidden_size) - - a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = per_token_group_quant_fp8(workspace2, - self.block_shape[1], - column_major_scales=False) - a2q = a2q.view(E, max_num_tokens, -1) - a2q_scale = a2q_scale.view(E, max_num_tokens, -1) + # TODO (varun) : Pass in an output tensor derived from workspace + # as a memory optimization. + a2q, a2q_scale = masked_per_token_group_quant_fp8( + x=workspace2, + group_size=self.block_shape[1], + column_major_scales=False) dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), (w2, w2_scale), diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 3bbae4e57ba3..9790fd8ec9da 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -731,10 +731,8 @@ def apply( config=config, block_shape=self.block_shape) - # TODO: would be nice to use expert_num_tokens here to reduce - # garbage compute - self.activation(activation, intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) + self.masked_activation(activation, intermediate_cache2, + intermediate_cache1, expert_num_tokens) ic2_hidden_size = intermediate_cache2.size(-1) intermediate_cache2 = intermediate_cache2.view(-1, ic2_hidden_size) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 0a7437ffaf54..47e16afd2a98 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -7,6 +7,8 @@ import torch import vllm.envs as envs +from vllm.model_executor.layers.fused_moe.masked_kernels import ( + invoke_masked_silu_and_mul) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.utils import cdiv @@ -232,7 +234,9 @@ def masked_activation(self, activation: str, output: torch.Tensor, f"expert_num_tokens.size(0)({expert_num_tokens.size(0)}) != E({E})" ) assert activation == "silu", "Only silu_and_mul is supported for now." 
- pass + invoke_masked_silu_and_mul(output=output, + input=input, + valid_tokens_array=expert_num_tokens) def activation(self, activation: str, output: torch.Tensor, input: torch.Tensor) -> None: From 06d28b2b8cfa4649238a74b67850d026861cbf3e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 14 Jun 2025 17:51:43 -0700 Subject: [PATCH 11/26] fix D blocking Signed-off-by: Varun Sundar Rabindranath --- vllm/model_executor/layers/fused_moe/masked_kernels.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/masked_kernels.py b/vllm/model_executor/layers/fused_moe/masked_kernels.py index 97ccbec3e483..35343da1c925 100644 --- a/vllm/model_executor/layers/fused_moe/masked_kernels.py +++ b/vllm/model_executor/layers/fused_moe/masked_kernels.py @@ -247,19 +247,21 @@ def silu_and_mul( BLOCK_D: tl.constexpr, compute_type: tl.constexpr): + remaining_d = D - (pid_d * BLOCK_D) + offs_m = tl.arange(0, BLOCK_M)[:, None] mask_m = offs_m < M offs_d = tl.arange(0, BLOCK_D) - mask_d = offs_d < D + mask_d = offs_d < remaining_d input_ptrs = input + offs_m * stride_im + pid_d * BLOCK_D + offs_d output_ptrs = output + offs_m * stride_om + pid_d * BLOCK_D + offs_d mask_tile = mask_m & mask_d - x_tile = tl.load(input_ptrs, mask=mask_tile, other=0.0).to(dtype=tl.float32) + y_tile = tl.load(input_ptrs + D, mask=mask_tile, other=0.0) # silu and mul From 8f9cb3d7d2c11eca256fa9d4bd15d33ed8a3b937 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 14 Jun 2025 18:02:31 -0700 Subject: [PATCH 12/26] better testing Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_masked_kernels.py | 7 +++++-- vllm/model_executor/layers/fused_moe/modular_kernel.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_masked_kernels.py b/tests/kernels/moe/test_masked_kernels.py index 91079fec46e8..a73622e11596 100644 --- a/tests/kernels/moe/test_masked_kernels.py +++ b/tests/kernels/moe/test_masked_kernels.py @@ -55,11 +55,12 @@ def ref_per_token_group_quant( BATCH_SIZES = [1, 13, 26, 32] -NUM_TOKENS = [32, 64, 4096] -HIDDEN_SIZES = [1024] +NUM_TOKENS = [7, 37, 64, 4096] ## Tests for masked per_token_group_quant_fp8 #### +HIDDEN_SIZES = [128, 256, 384, 512, 1024] + @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -116,6 +117,8 @@ def test_masked_per_token_group_quant_fp8(batch_size: int, num_tokens: int, ## Tests for masked silu_and_mul #### +HIDDEN_SIZES = [124, 1024, 2176, 2816] + @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 47e16afd2a98..d802c6f69348 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -228,7 +228,8 @@ def masked_activation(self, activation: str, output: torch.Tensor, Note that expert_num_tokens[i] is the number of tokens that are actually valid for expert i. 
""" - assert output.shape == input.shape + assert output.ndim == 3 and input.ndim == 3 + assert output.size(-1) * 2 == input.size(-1) E, _, _ = input.shape assert expert_num_tokens.size(0) == E, ( f"expert_num_tokens.size(0)({expert_num_tokens.size(0)}) != E({E})" From c98c2e22641da336bbc06dbcf34f35faad28ed96 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 14 Jun 2025 18:35:29 -0700 Subject: [PATCH 13/26] make out_q optional Signed-off-by: Varun Sundar Rabindranath --- .../layers/fused_moe/masked_kernels.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/masked_kernels.py b/vllm/model_executor/layers/fused_moe/masked_kernels.py index 35343da1c925..68f28b11fa1f 100644 --- a/vllm/model_executor/layers/fused_moe/masked_kernels.py +++ b/vllm/model_executor/layers/fused_moe/masked_kernels.py @@ -143,12 +143,13 @@ def _masked_per_token_group_quant_fp8_colmajor( def masked_per_token_group_quant_fp8( - x: torch.Tensor, # [B, MAX_TOKENS, HIDDEN_SIZE] - x_q: Optional[torch.Tensor], # [B, MAX_TOKENS, HIDDEN_SIZE] - valid_tokens_array: torch.Tensor, # [B] - group_size: int, - column_major_scales: bool, - eps: float = 1e-10) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, # [B, MAX_TOKENS, HIDDEN_SIZE] + valid_tokens_array: torch.Tensor, # [B] + group_size: int, + column_major_scales: bool, + x_q: Optional[torch.Tensor] = None, # [B, MAX_TOKENS, HIDDEN_SIZE] + eps: float = 1e-10 +) -> tuple[torch.Tensor, torch.Tensor]: assert x.ndim == 3 assert (x.size(-1) % group_size == 0), ( From c20487e7e82d9a073c3e8a544e1879d4b7be43a9 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 14 Jun 2025 18:41:45 -0700 Subject: [PATCH 14/26] fixes Signed-off-by: Varun Sundar Rabindranath --- vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index aa04b9131e63..9d0069f740a7 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -117,6 +117,7 @@ def apply( # as a memory optimization. a2q, a2q_scale = masked_per_token_group_quant_fp8( x=workspace2, + valid_tokens_array=expert_num_tokens, group_size=self.block_shape[1], column_major_scales=False) From 2fb3d5f8810942f08a8b303daee3e01703f3afb8 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 16 Jun 2025 19:49:48 -0700 Subject: [PATCH 15/26] add batched cuda silu and mul Signed-off-by: Varun Sundar Rabindranath --- csrc/activation_kernels.cu | 75 ++++++++++++++++++++++++++++++++++++-- csrc/ops.h | 3 ++ csrc/torch_bindings.cpp | 5 +++ 3 files changed, 80 insertions(+), 3 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 55e659679701..e2980cf70e5d 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -15,14 +15,13 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x, const scalar_t& y) { return act_first ? ACT_FN(x) * y : x * ACT_FN(y); } -// Activation and gating kernel template. 
 template
-__global__ void act_and_mul_kernel(
+__device__ void _act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
-    const int d) {
+    const int d, const int64_t token_idx) {
   const int64_t token_idx = blockIdx.x;
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
@@ -31,6 +30,18 @@ __global__ void act_and_mul_kernel(
   }
 }
 
+// Activation and gating kernel template.
+
+template
+__global__ void act_and_mul_kernel(
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const int d) {
+  const int64_t token_idx = blockIdx.x;
+  _act_and_mul_kernel(out, input, d, token_idx)
+}
+
 template
 __device__ __forceinline__ T silu_kernel(const T& x) {
   // x * sigmoid(x)
@@ -223,3 +234,61 @@ void gelu_quick(torch::Tensor& out,          // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
 }
+
+namespace vllm {
+// Batched act_and_mul kernel template
+template
+__global__ void batched_act_and_mul_kernel(
+    scalar_t* out,                      // [B, max_tokens, d]
+    const scalar_t* input,              // [B, max_tokens, 2, d]
+    const int64_t* valid_tokens_array,  // [B]
+    const int d) {
+  ;
+  const int64_t batch_idx = blockIdx.x;
+  const int64_t num_tokens = valid_tokens_array[batch_idx];
+  if (num_tokens == 0) {
+    return;
+  }
+
+  const int64_t token_idx = blockIdx.y;
+  if (token_idx >= num_tokens) {
+    return;
+  }
+
+  const int64_t max_num_tokens = gridDim.y;
+  scalar_t* __restrict__ batch_out = &out[batch_idx * max_num_tokens * d];
+  const scalar_t* __restrict__ batch_input =
+      &input[batch_idx * max_num_tokens * d * 2];
+  _act_and_mul_kernel(batch_out, batch_input, d, token_idx)
+}
+}  // namespace vllm
+
+// Launch batched activation and gating kernel.
+// Use ACT_FIRST (bool) indicating whether to apply the activation function
+// first.
+#define LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
+  int64_t const batch_size = input.size(0); \
+  int64_t const max_num_tokens = input.size(1); \
+  int const d = input.size(2) / 2; \
+  dim3 grid(batch_size, max_num_tokens); \
+  dim3 block(std::min(d, 1024)); \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), \
+                               "batched_act_and_mul_kernel", [&] { \
+    vllm::batched_act_and_mul_kernel, ACT_FIRST> \
+        <<>>(out.data_ptr(), \
+                             input.data_ptr(), \
+                             valid_tokens_array.data_ptr ()");
   ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
+  ops.def(
+      "batched_silu_and_mul(Tensor! result, Tensor input, Tensor "
+      "valid_tokens_array) -> ()");
+  ops.impl("batched_silu_and_mul", torch::kCUDA, &batched_silu_and_mul);
+
   ops.def("mul_and_silu(Tensor!
out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); From 97bda02babfb8cc6828e5b3d69ab1e94d30a803e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 16 Jun 2025 20:20:29 -0700 Subject: [PATCH 16/26] fixes Signed-off-by: Varun Sundar Rabindranath --- csrc/activation_kernels.cu | 12 ++++++------ csrc/ops.h | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index e2980cf70e5d..999f52f50999 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -22,7 +22,6 @@ __device__ void _act_and_mul_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] const int d, const int64_t token_idx) { - const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); @@ -39,7 +38,7 @@ __global__ void act_and_mul_kernel( const scalar_t* __restrict__ input, // [..., 2, d] const int d) { const int64_t token_idx = blockIdx.x; - _act_and_mul_kernel(out, input, d, token_idx) + _act_and_mul_kernel(out, input, d, token_idx); } template @@ -242,7 +241,7 @@ template (batch_out, batch_input, d, + token_idx); } } // namespace vllm @@ -280,7 +280,7 @@ __global__ void batched_act_and_mul_kernel( vllm::batched_act_and_mul_kernel, ACT_FIRST> \ <<>>(out.data_ptr(), \ input.data_ptr(), \ - valid_tokens_array.data_ptr Date: Mon, 16 Jun 2025 20:21:48 -0700 Subject: [PATCH 17/26] fixes Signed-off-by: Varun Sundar Rabindranath --- csrc/activation_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 999f52f50999..b5f86721ea35 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -290,5 +290,5 @@ void batched_silu_and_mul(torch::Tensor& out, // [..., d] { TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); TORCH_CHECK(valid_tokens_array.dtype() == torch::kInt32); - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); + LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); } \ No newline at end of file From fc5bc04a9bddd63dde0a8a74c67db3b03a52baff Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 16 Jun 2025 20:56:57 -0700 Subject: [PATCH 18/26] add batched impl tests Signed-off-by: Varun Sundar Rabindranath --- csrc/activation_kernels.cu | 40 +++++++++--------- tests/kernels/core/test_activation.py | 54 ++++++++++++++++++++++++ tests/kernels/moe/test_masked_kernels.py | 4 +- 3 files changed, 76 insertions(+), 22 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index b5f86721ea35..a2083f884c1e 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -267,26 +267,26 @@ __global__ void batched_act_and_mul_kernel( // Launch batched activation and gating kernel. // Use ACT_FIRST (bool) indicating whether to apply the activation function // first. 
-#define LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
-  int64_t const batch_size = input.size(0); \
-  int64_t const max_num_tokens = input.size(1); \
-  int const d = input.size(2) / 2; \
-  dim3 grid(batch_size, max_num_tokens); \
-  dim3 block(std::min(d, 1024)); \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), \
-                               "batched_act_and_mul_kernel", [&] { \
-    vllm::batched_act_and_mul_kernel, ACT_FIRST> \
-        <<>>(out.data_ptr(), \
-                             input.data_ptr(), \
-                             valid_tokens_array.data_ptr, \
+                                         ACT_FIRST> \
+            <<>>( \
+                out.data_ptr(), input.data_ptr(), \
+                valid_tokens_array.data_ptr(), d); \
+      });
+
+void batched_silu_and_mul(torch::Tensor& out,    // [B, max_tokens, d]
+                          torch::Tensor& input,  // [B, max_tokens, 2, d]
+                          torch::Tensor& valid_tokens_array)  // [B]
 {
   TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
   TORCH_CHECK(valid_tokens_array.dtype() == torch::kInt32);
diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py
index 29c5e70a8ba8..d4c68a554697 100644
--- a/tests/kernels/core/test_activation.py
+++ b/tests/kernels/core/test_activation.py
@@ -14,6 +14,21 @@
     SiluAndMul)
 from vllm.platforms import current_platform
 
+
+def ref_batched_silu_mul(x, out, valid_tokens_array):
+    """
+    Reference implementation of batched silu_and_mul
+    """
+    valid_tokens_array = valid_tokens_array.to("cpu")
+    batch_size = x.size(0)
+    for b in range(batch_size):
+        # num valid tokens
+        n = valid_tokens_array[b]
+        if n == 0:
+            continue
+        torch.ops._C.silu_and_mul(out[b, :n, :], x[b, :n, :])
+
+
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
 D = [512, 13824]  # Arbitrary values for testing
@@ -106,3 +121,42 @@ def test_activation(
 
     out = torch.empty_like(x)
     opcheck(fn, (out, x))
+
+
+## Test Batched Implementation ####
+
+BATCH_SIZES = [1, 13, 26, 32]
+NUM_TOKENS = [7, 37, 64, 4096]
+D = [128, 256, 384, 512, 1024, 13824]
+
+
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype",
+                         [torch.float16, torch.bfloat16, torch.float32])
+def test_batched_silu_mul(batch_size: int, num_tokens: int, d: int,
+                          dtype: torch.dtype):
+
+    input = torch.randn(
+        (batch_size, num_tokens, d), device="cuda", dtype=dtype) / 10.0
+
+    out = torch.empty((batch_size, num_tokens, d // 2),
+                      device="cuda",
+                      dtype=dtype)
+
+    ref_out = out.clone()
+
+    # valid num_tokens per batch
+    valid_num_tokens = torch.randint(low=0,
+                                     high=num_tokens + 1,
+                                     size=(batch_size, ),
+                                     device="cuda").to(dtype=torch.int32)
+
+    # reference
+    ref_batched_silu_mul(input, ref_out, valid_num_tokens)
+
+    # impl
+    torch.ops._C.batched_silu_and_mul(out, input, valid_num_tokens)
+
+    torch.testing.assert_close(ref_out, out)
diff --git a/tests/kernels/moe/test_masked_kernels.py b/tests/kernels/moe/test_masked_kernels.py
index a73622e11596..4dbff8f185a8 100644
--- a/tests/kernels/moe/test_masked_kernels.py
+++ b/tests/kernels/moe/test_masked_kernels.py
@@ -88,7 +88,7 @@ def test_masked_per_token_group_quant_fp8(batch_size: int, num_tokens: int,
     valid_num_tokens = torch.randint(low=0,
                                      high=num_tokens + 1,
                                      size=(batch_size, ),
-                                     device="cuda")
+                                     device="cuda").to(torch.int32)
 
     # Reference
     ref_out_q, ref_out_scales = ref_per_token_group_quant(
@@ -143,7 +143,7 @@ def test_masked_silu_mul(batch_size: int, num_tokens:
int, hidden_size: int, valid_num_tokens = torch.randint(low=0, high=num_tokens + 1, size=(batch_size, ), - device="cuda") + device="cuda").to(torch.int32) # reference ref_silu_mul(input, ref_out, valid_num_tokens) From 67e76b52b3e7da226b75eee0390f481343ca00ab Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 16 Jun 2025 21:05:05 -0700 Subject: [PATCH 19/26] use cuda silu mul Signed-off-by: Varun Sundar Rabindranath --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index d802c6f69348..5d88ada34d91 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -7,8 +7,6 @@ import torch import vllm.envs as envs -from vllm.model_executor.layers.fused_moe.masked_kernels import ( - invoke_masked_silu_and_mul) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.utils import cdiv @@ -234,10 +232,9 @@ def masked_activation(self, activation: str, output: torch.Tensor, assert expert_num_tokens.size(0) == E, ( f"expert_num_tokens.size(0)({expert_num_tokens.size(0)}) != E({E})" ) - assert activation == "silu", "Only silu_and_mul is supported for now." - invoke_masked_silu_and_mul(output=output, - input=input, - valid_tokens_array=expert_num_tokens) + if activation != "silu": + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + torch.ops._C.batched_silu_and_mul(output, input, expert_num_tokens) def activation(self, activation: str, output: torch.Tensor, input: torch.Tensor) -> None: From 8de2fd39fcb60f1a9cb84a34c5245b2b991561fe Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 18 Jun 2025 07:32:15 -0700 Subject: [PATCH 20/26] deep_ep + use_fp8_dispatch Signed-off-by: Varun Sundar Rabindranath --- vllm/model_executor/layers/fused_moe/layer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1fd8f2175886..c6c908f73a25 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -45,7 +45,8 @@ from .pplx_prepare_finalize import PplxPrepareAndFinalize if has_deepep: from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize + from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE, + DeepEPLLPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -377,6 +378,12 @@ def init_prepare_finalize(self, moe: MoEConfig, all2all_manager.world_size) handle = all2all_manager.get_handle(all_to_all_args) + # Note : We may want to use FP8 dispatch even otherwise just to + # reduce datamovement + use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype() + and act_quant_block_size + == DEEPEP_QUANT_BLOCK_SIZE) + # Note (varun): Whether to use FP8 dispatch or not needs some # profiling. Turning it off for now. 
prepare_finalize = DeepEPLLPrepareAndFinalize( @@ -386,7 +393,7 @@ def init_prepare_finalize(self, moe: MoEConfig, max_tokens_per_rank=moe.max_num_tokens, quant_dtype=quant_dtype, block_shape=act_quant_block_size, - use_fp8_dispatch=False, + use_fp8_dispatch=use_fp8_dispatch, ) self.topk_indices_dtype = None From b2178bed14746258c2624c53cc9a60abcb034c95 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 18 Jun 2025 10:39:49 -0400 Subject: [PATCH 21/26] Quantize kernel with the layout that deepgemm wants Signed-off-by: Tyler Michael Smith --- .../layers/fused_moe/batched_deep_gemm_moe.py | 175 +++++++++++++++++- 1 file changed, 168 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 9d0069f740a7..47fb2d968d19 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -2,6 +2,8 @@ import importlib.util from typing import Optional +import triton +import triton.language as tl import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -15,6 +17,166 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + + +@triton.jit +def _per_token_group_quant_fp8_3d( + # Pointers ------------------------------------------------------------ + y_ptr, # FP16 activations (E, T, H) + y_q_ptr, # FP8 quantized activations (E, T, H) + + y_s_ptr, # FP32 scales (E, T, G) + counts_ptr, # INT32 number of tokens per expert (E) + + # Sizes --------------------------------------------------------------- + E: tl.constexpr, # num_experts + T: tl.constexpr, # max_num_tokens + H: tl.constexpr, # hidden dimension + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + + # Strides for y (elements) ------------------------------------------- + stride_y_e, + stride_y_t, + stride_y_h, + + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + + + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + + # Stride for counts (elements) + stride_counts_e, + + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, +): + """Dynamic FP8 quantisation over a 3‑D tensor laid out **(E, T, H)**. + + * Each program instance handles **one** `GROUP_SIZE`‑length slice along H + for a single (expert *e*, token *t*). + * Scales are produced with shape **(E, T, G)** where + `G = H // GROUP_SIZE` and with *element* strides + `(T*G, 1, T)` so that the *token* dimension is the fastest‑varying in + memory – matching the downstream reshape you showed. + * All strides are expressed **in elements**, not bytes. 
+ """ + + G = H // GROUP_SIZE # groups per hidden dim + + # ----------------------- map program id -> (e, g) -------------------- + pid = tl.program_id(0) + e = pid // G + g = pid % G + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int32) + + # block for H dimension + cols = tl.arange(0, BLOCK) + mask_h = cols < BLOCK + + # iterate over tokens for this (expert, group) + t = tl.zeros([], tl.int32) + while t < n_tokens: + base_y_offset = e * stride_y_e + t * stride_y_t + g * GROUP_SIZE * stride_y_h + base_yq_offset = e * stride_yq_e + t * stride_yq_t + g * GROUP_SIZE * stride_yq_h + base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g + + mask = mask_h + y = tl.load(y_ptr + base_y_offset + cols * stride_y_h, + mask=mask, other=0.0).to(tl.float32) + + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, + y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset, y_s) + + t += 1 + + +def quant_fp8_3d( + y: torch.Tensor, # (E, T, H) + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + group_size: int = 128, + fp8_dtype = torch.float8_e4m3fn, + eps: float = 1e-6, +): + """Quantize y into FP8 with per‑(expert, token, group) scales. + + Only the first `tokens_per_expert[e]` tokens are quantized per expert; + the remaining positions in each (E, T, H) slice are treated as padding. + + Returns `(y_q, y_s)` where + * `y_q` is the FP8 tensor, same shape and **standard PyTorch order** as *y*. + * `y_s` has shape `(E, T, H // group_size)` and element strides + `(T * G, 1, T)` so that the *token* dimension is contiguous. + """ + + assert y.ndim == 3, "y must be (E, T, H)" + E, T, H = y.shape + G = H // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ + "tokens_per_expert must be shape (E,)" + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) + + # ---------------- allocate outputs ---------------------------------- + y_q = torch.empty_like(y, dtype=fp8_dtype) + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + + # allocate scale buffer with proper shape and stride + y_s = torch.empty_strided((E, T, G), (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, device=y.device) + + # ---------------- stride bookkeeping (elements, not bytes) ---------- + stride_y_e, stride_y_t, stride_y_h = y.stride() + + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + + # stride for tokens_per_expert (elements) + stride_cnt_e = tokens_per_expert.stride()[0] + + # static grid over experts and H-groups; tokens loop is internal to the kernel + grid = (E * G,) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = -f_info.max + + _per_token_group_quant_fp8_3d[grid]( + y, y_q, y_s, tokens_per_expert, + E, T, H, group_size, + stride_y_e, stride_y_t, stride_y_h, + stride_yq_e, stride_yq_t, stride_yq_h, + stride_ys_e, stride_ys_t, stride_ys_g, + stride_cnt_e, + eps, fp8_min, fp8_max, + BLOCK=group_size, + num_warps=4, + ) + + return y_q, y_s + class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # The Deep Gemm kernels only support block size of 128 @@ -87,6 +249,9 @@ def apply( ): import deep_gemm as dg assert hidden_states.ndim == 3 + assert(w1_zp is None 
and w2_zp is None) + assert(a2_scale is None) + assert(expert_map is None) a1q = hidden_states _, N, K = w1.size() @@ -113,13 +278,9 @@ def apply( self.masked_activation(activation, workspace2, workspace1, expert_num_tokens) - # TODO (varun) : Pass in an output tensor derived from workspace - # as a memory optimization. - a2q, a2q_scale = masked_per_token_group_quant_fp8( - x=workspace2, - valid_tokens_array=expert_num_tokens, - group_size=self.block_shape[1], - column_major_scales=False) + a2q, a2q_scale = quant_fp8_3d(workspace2, + tokens_per_expert=expert_num_tokens, + group_size=self.block_shape[1]) dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), (w2, w2_scale), From a2b4f8ec3b9222eca1b10eba75d3b5d20dbff0a6 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 18 Jun 2025 10:45:29 -0400 Subject: [PATCH 22/26] rm bad assert Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 47fb2d968d19..93a211cad416 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -251,7 +251,6 @@ def apply( assert hidden_states.ndim == 3 assert(w1_zp is None and w2_zp is None) assert(a2_scale is None) - assert(expert_map is None) a1q = hidden_states _, N, K = w1.size() From 041a9e968ac4201073bcd439047a1031211a4c58 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 18 Jun 2025 11:15:48 -0700 Subject: [PATCH 23/26] fixes - use-fp8-dispatch Signed-off-by: Varun Sundar Rabindranath --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c6c908f73a25..98733f101acb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -381,7 +381,7 @@ def init_prepare_finalize(self, moe: MoEConfig, # Note : We may want to use FP8 dispatch even otherwise just to # reduce datamovement use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype() - and act_quant_block_size + and act_quant_block_size[1] == DEEPEP_QUANT_BLOCK_SIZE) # Note (varun): Whether to use FP8 dispatch or not needs some From 2a7e5376de0cb7c0dc3d2341bb38646cefd98cd9 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 18 Jun 2025 12:07:55 -0700 Subject: [PATCH 24/26] fix topk ids Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_deepep_deepgemm_moe.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 2d7cf39a8cca..77820d12a6f3 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -169,11 +169,11 @@ def make(config: TestConfig, rank) -> "TestTensors": block_k = block_size[1] _, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k) - topk_ids = torch.randint( - low=0, - high=config.num_experts, - size=(m, topk), - device=torch.cuda.current_device()).to(dtype=torch.int64) + # distribute topk_ids evenly + topk_ids = torch.empty((m, topk), device="cpu", dtype=torch.int64) + for mi in range(m): + topk_ids[mi] = torch.randperm(config.num_experts)[:topk] + topk_ids = 
topk_ids.to(device=torch.cuda.current_device()) topk_weights = torch.randn(topk_ids.shape, dtype=torch.float32, @@ -459,17 +459,23 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, w2, w1_scale, w2_scale) +TOPKS = [6] MNKs = [ - (1, 128, 2560), - (2, 128, 2560), - (3, 1024, 2560), - (32, 128, 2560), - (45, 512, 2560), - (64, 1024, 2560), - (222, 1024, 2560), + #(1, 128, 2560), + #(2, 128, 2560), + #(3, 1024, 2560), + #(32, 128, 2560), + + #(45, 512, 2560), + #(64, 1024, 2560), + #(222, 1024, 2560), + + #(45, 128, 2560), + #(64, 128, 2560), + (222, 128, 2560), ] # Fix tests for USE_FP8_DISPATCH=True -USE_FP8_DISPATCH = [False] +USE_FP8_DISPATCH = [True] @pytest.mark.parametrize("mnk", MNKs) From a2eb4f94e244a456f96de7fbc7b91a6f94e45deb Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 18 Jun 2025 13:50:37 -0700 Subject: [PATCH 25/26] update batched silu mul kernel Signed-off-by: Varun Sundar Rabindranath --- csrc/activation_kernels.cu | 64 ++++++++++++++++----------- tests/kernels/core/test_activation.py | 2 +- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index a2083f884c1e..46e05f18c8f0 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -242,46 +242,56 @@ __global__ void batched_act_and_mul_kernel( scalar_t* out, // [B, max_tokens, d] const scalar_t* input, // [B, max_tokens, 2, d] const int32_t* valid_tokens_array, // [B] - const int d) { - ; + const int d, const int max_num_tokens) { const int64_t batch_idx = blockIdx.x; const int64_t num_tokens = valid_tokens_array[batch_idx]; if (num_tokens == 0) { return; } - const int64_t token_idx = blockIdx.y; - if (token_idx >= num_tokens) { - return; + int const col_offset = blockIdx.y * blockDim.x; + scalar_t* __restrict__ b_out = + &out[batch_idx * max_num_tokens * d + col_offset]; + const scalar_t* __restrict__ b_in = + &input[batch_idx * max_num_tokens * d * 2 + col_offset]; + + int token_idx = 0; + const int tidx = threadIdx.x; + while (token_idx < num_tokens) { + if (col_offset + tidx < d) { + const scalar_t x = VLLM_LDG(&b_in[tidx]); + const scalar_t y = VLLM_LDG(&b_in[tidx + d]); + b_out[tidx] = compute(x, y); + } + + b_out += d; + b_in += (2 * d); + + ++token_idx; } - - const int64_t max_num_tokens = gridDim.y; - scalar_t* __restrict__ batch_out = &out[batch_idx * max_num_tokens * d]; - const scalar_t* __restrict__ batch_input = - &input[batch_idx * max_num_tokens * d * 2]; - _act_and_mul_kernel(batch_out, batch_input, d, - token_idx); } } // namespace vllm // Launch batched activation and gating kernel. // Use ACT_FIRST (bool) indicating whether to apply the activation function // first. 
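The kernel rewrite above changes the work decomposition: instead of one thread block per (batch, token) pair — which launched max_num_tokens blocks per batch entry even when most tokens were padding — the grid in the macro below becomes (batch, column tiles of min(d, 1024)) and each block walks only the valid tokens of its batch entry. A rough PyTorch emulation of what a single (batch, column-tile) block computes, offered as a sketch for reasoning about the indexing (the helper name emulate_block is illustrative, not part of the patch):

    import torch

    def emulate_block(out, inp, valid, b, tile, block_size):
        # out: [B, max_tokens, d], inp: [B, max_tokens, 2 * d], valid: [B]
        d = out.size(2)
        n = int(valid[b])  # number of valid tokens for this batch entry
        if n == 0:
            return
        lo = tile * block_size
        hi = min(lo + block_size, d)
        x = inp[b, :n, lo:hi]            # gate half -> SiLU
        y = inp[b, :n, d + lo:d + hi]    # up half   -> multiplier
        out[b, :n, lo:hi] = torch.nn.functional.silu(x) * y
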
-#define LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \ - int64_t const batch_size = input.size(0); \ - int64_t const max_num_tokens = input.size(1); \ - int const d = input.size(2) / 2; \ - dim3 grid(batch_size, max_num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "batched_act_and_mul_kernel", [&] { \ - vllm::batched_act_and_mul_kernel, \ - ACT_FIRST> \ - <<>>( \ - out.data_ptr(), input.data_ptr(), \ - valid_tokens_array.data_ptr(), d); \ +#define LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \ + int64_t const batch_size = input.size(0); \ + int64_t const max_num_tokens = input.size(1); \ + int const d = input.size(2) / 2; \ + int const block_size = std::min(d, 1024); \ + int const blocks_per_row = ((d - 1) / block_size) + 1; \ + dim3 grid(batch_size, blocks_per_row); \ + dim3 block(block_size); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "batched_act_and_mul_kernel", [&] { \ + vllm::batched_act_and_mul_kernel, \ + ACT_FIRST> \ + <<>>( \ + out.data_ptr(), input.data_ptr(), \ + valid_tokens_array.data_ptr(), d, max_num_tokens); \ }); void batched_silu_and_mul(torch::Tensor& out, // [B, max_tokens, d] diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index d4c68a554697..33c479cf13ab 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -127,7 +127,7 @@ def test_activation( BATCH_SIZES = [1, 13, 26, 32] NUM_TOKENS = [7, 37, 64, 4096] -D = [128, 256, 384, 512, 1024, 13824] +D = [128, 256, 384, 512, 1024, 1536, 13824] @pytest.mark.parametrize("batch_size", BATCH_SIZES) From fffaf974bdc24e06a4d20b1246eb4d0fb223afdc Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 19 Jun 2025 08:21:27 -0700 Subject: [PATCH 26/26] fix fp8 dispatch tests Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_deepep_deepgemm_moe.py | 44 ++++++++++--------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 77820d12a6f3..037815ec15ee 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -15,7 +15,8 @@ from typing_extensions import ParamSpec from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -169,16 +170,17 @@ def make(config: TestConfig, rank) -> "TestTensors": block_k = block_size[1] _, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k) - # distribute topk_ids evenly + score = torch.randn((m, config.num_experts), + device="cuda", + dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(rank_tokens, score, topk, False) + + # overwrite topk_ids to distribute evenly. 
topk_ids = torch.empty((m, topk), device="cpu", dtype=torch.int64) for mi in range(m): topk_ids[mi] = torch.randperm(config.num_experts)[:topk] topk_ids = topk_ids.to(device=torch.cuda.current_device()) - topk_weights = torch.randn(topk_ids.shape, - dtype=torch.float32, - device=torch.cuda.current_device()) - return TestTensors(rank_tokens=rank_tokens, rank_token_scales=rank_token_scales, topk=topk_ids, @@ -459,23 +461,25 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, w2, w1_scale, w2_scale) -TOPKS = [6] +TOPKS = [2, 6] MNKs = [ - #(1, 128, 2560), - #(2, 128, 2560), - #(3, 1024, 2560), - #(32, 128, 2560), - - #(45, 512, 2560), - #(64, 1024, 2560), - #(222, 1024, 2560), - - #(45, 128, 2560), - #(64, 128, 2560), + (1, 128, 2560), + (2, 128, 2560), + (3, 1024, 2560), + (32, 128, 2560), + (45, 512, 2560), + (64, 1024, 2560), + (222, 1024, 2560), + (45, 128, 2560), + (64, 128, 2560), (222, 128, 2560), + (45, 2048, 2560), + (64, 2048, 2560), + (222, 2048, 2560), + (333, 2048, 2560), + (444, 2048, 2560), ] -# Fix tests for USE_FP8_DISPATCH=True -USE_FP8_DISPATCH = [True] +USE_FP8_DISPATCH = [False, True] @pytest.mark.parametrize("mnk", MNKs)
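
The per-row randperm overwrite above presumably exists to guarantee distinct expert ids within each token's top-k (torch.randint can pick the same expert twice, which a real router never does) while still spreading load evenly across experts. A hypothetical check, not part of the patch, that shows the difference:

    import torch

    num_experts, topk, m = 32, 6, 4

    # torch.randint may repeat an expert within a token's top-k ...
    ids_rand = torch.randint(0, num_experts, (m, topk))
    may_have_dup = bool((ids_rand.sort(dim=-1).values.diff(dim=-1) == 0).any())

    # ... while a per-row randperm always yields distinct expert ids.
    ids_perm = torch.stack(
        [torch.randperm(num_experts)[:topk] for _ in range(m)])
    assert not (ids_perm.sort(dim=-1).values.diff(dim=-1) == 0).any()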