[Feature] Add TritonBF16MoEMethod for BF16 MoE inference #7734
base: develop
Changes from all commits
@@ -18,10 +18,23 @@

import paddle
import paddle.nn.functional as F
import triton.language as tl
from paddle import nn

import fastdeploy
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.moe.triton_moe_kernels import (
    fused_moe_kernel_bf16,
    fused_moe_kernel_paddle,
)
from fastdeploy.model_executor.layers.quantization.fp8_utils import (
    fused_stack_transpose_quant,
    quant_weight_ue8m0,
    transform_scale_ue8m0,
)
from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
from fastdeploy.model_executor.utils import (
    TensorTracker,
    free_tensor,
|
@@ -33,20 +46,7 @@

from fastdeploy.utils import ceil_div, register_custom_python_op

from ..quantization.quant_base import QuantMethodBase

try:
    from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func

    from .triton_moe_kernels import fused_moe_kernel_paddle
except ImportError:
    pass
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.fp8_utils import (
    fused_stack_transpose_quant,
    quant_weight_ue8m0,
    transform_scale_ue8m0,
)
from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant
from .fused_moe_backend_base import UnquantizedFusedMoEMethod


class TritonWeightOnlyMoEMethod(QuantMethodBase):
@@ -780,8 +780,8 @@ def apply(
            stride_am=x_q.strides[0],
            stride_ak=x_q.strides[1],
            stride_be=layer.up_gate_proj_weight.strides[0],
            stride_bk=layer.up_gate_proj_weight.strides[2],
            stride_bn=layer.up_gate_proj_weight.strides[1],
            stride_bk=layer.up_gate_proj_weight.strides[1],
            stride_bn=layer.up_gate_proj_weight.strides[2],
            stride_cm=up_gate_proj_out.strides[0],
            stride_cn=up_gate_proj_out.strides[1],
            #
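This hunk swaps `stride_bk` and `stride_bn` for the existing `apply` path: with `up_gate_proj_weight` laid out as `[E, K, 2N]`, the step along K is `strides[1]` and the step along N is `strides[2]`, while the replaced lines passed the two strides in the opposite order. A minimal sketch of the layout (shapes are illustrative, and it assumes `Tensor.strides` reports element strides, which is how the diff itself uses it):

```python
import paddle

# Illustrative shapes only: E = 4 experts, K (hidden) = 8, 2N (2 * inter) = 6.
E, K, N2 = 4, 8, 6
w = paddle.zeros([E, K, N2])  # dtype does not matter for strides

# For a contiguous [E, K, 2N] tensor the element strides are [K * 2N, 2N, 1],
# i.e. 48 along E, 6 along K, 1 along N for these shapes.
print(w.strides)

# Moving one step along K advances strides[1] elements and one step along N
# advances strides[2] elements, so the kernel should receive
#     stride_bk = w.strides[1], stride_bn = w.strides[2]
# which is what the corrected lines pass.
```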
@@ -1885,3 +1885,284 @@ def apply(
            fc1_latent_proj,
            fc2_latent_proj,
        )


class TritonMoEMethod(UnquantizedFusedMoEMethod):
    """
    Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE.

    Activated via: export FD_MOE_BACKEND=triton
    Weight layout (CUDA path): [E, K, 2N] for up_gate_proj, [E, N, K] for down_proj.
    This matches UnquantizedFusedMoEMethod.create_weights layout on CUDA.
    """

    def __init__(self, quant_config=None):
        super().__init__(quant_config)

    def process_loaded_weights(self, layer: nn.Layer, state_dict):
        """Stack individual expert weights into the stacked parameter."""
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
        layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_proj_weights, axis=0))
        layer.down_proj_weight.set_value(paddle.stack(down_proj_weights, axis=0))

    def _get_default_config(self, M: int, N: int, K: int, num_experts: int = 64) -> dict:
        """
        Heuristic tile config for BF16 MoE, aligned with vLLM's get_default_config logic.
        M: number of tokens used for config selection (the call site passes token_num, not token-expert pairs)
        N: output dimension of the GEMM
        K: input dimension of the GEMM
        num_experts: number of local experts (for GROUP_SIZE_M heuristic)
        """
        if M <= 32:
            block_m, block_n, block_k = 16, 64, 128
            num_warps, num_stages = 4, 4
        elif M <= 96:
            block_m, block_n, block_k = 32, 64, 128
            num_warps, num_stages = 4, 3
        elif M <= 512:
            block_m, block_n, block_k = 64, 128, 64
            num_warps, num_stages = 8, 3
        else:
            block_m, block_n, block_k = 128, 128, 64
            num_warps, num_stages = 8, 3

        tokens_per_expert = M // max(num_experts, 1)
        group_m = 16 if tokens_per_expert > 128 else 1

        return {
            "BLOCK_SIZE_M": block_m,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": group_m,
            "num_warps": num_warps,
            "num_stages": num_stages,
        }
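For intuition, here is the `GROUP_SIZE_M` rule from the heuristic above run standalone on a few illustrative token counts (the numbers are examples, not benchmarks):

```python
# Same rule as in _get_default_config: group M-tiles only when each expert
# sees enough tokens for the grouped launch order to pay off.
num_experts = 64
for M in (16, 512, 16384):  # decode-like, medium, prefill-like
    tokens_per_expert = M // max(num_experts, 1)
    group_m = 16 if tokens_per_expert > 128 else 1
    print(M, tokens_per_expert, group_m)
# 16 -> 0 -> GROUP_SIZE_M = 1; 512 -> 8 -> 1; 16384 -> 256 -> 16
```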
    def apply_tp(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
        fc1_latent_proj: nn.Layer = None,
        fc2_latent_proj: nn.Layer = None,
    ) -> paddle.Tensor:
        """
        BF16 Triton Fused MoE forward.

        Pipeline:
        1. Gate + topk routing
        2. tritonmoe_preprocess -> sorted_token_ids, expert_ids, num_tokens_post_padded
        3. fused_moe_kernel_bf16 GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N]
        4. SwiGLU activation
        5. fused_moe_kernel_bf16 GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K]
           (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication)
        6. Reshape + sum over topk dim
        """
        token_num = x.shape[0]
        if token_num == 0:
            return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype)

        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size

        # --- 1. Routing ---
        gate_out = gate(x)

        if layer.topk_method == "noaux_tc":
            from fastdeploy.model_executor.layers.moe.moe import get_moe_scores

            use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda()
            if not use_fused:
                gate_out = gate_out.cast("float32")

            _, topk_weights, topk_ids = get_moe_scores(
                gate_out,
                layer.n_group,
                layer.topk_group,
                top_k,
                layer.routed_scaling_factor,
                layer.gate_correction_bias,
                getattr(layer, "renormalize", True),
                use_fused_cast=use_fused,
                topk_reduce_func=getattr(layer, "topk_reduce_func", None),
            )
        else:
            gate_out = gate_out.cast("float32")
            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                gate_out,
                layer.gate_correction_bias,
                top_k,
🔴 Bug (verified) — suggested fix (just uncomment the guard below):

    # Ensure topk_ids is int64 (noaux_tc may return int32, tritonmoe_preprocess requires int64)
    if topk_ids.dtype != paddle.int64:
        topk_ids = topk_ids.cast("int64")
                True,  # apply_norm_weight
                False,
            )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        # # Ensure topk_ids is int64 (noaux_tc may return int32, tritonmoe_preprocess requires int64)
        # if topk_ids.dtype != paddle.int64:
        #     topk_ids = topk_ids.cast("int64")

        # --- 2. Preprocess: sort tokens by expert assignment ---
        # from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func

        num_token_expert_pairs = token_num * top_k
        # Use token_num (not pairs) for config selection, matching vLLM's heuristic:
        # M represents "how many unique tokens each expert sees on average", which
        # determines whether the workload is memory-bound (decode) or compute-bound (prefill).
        cfg = self._get_default_config(token_num, moe_intermediate_size * 2, hidden_size, num_local_experts)

        # Use naive_block_assignment when token count is very small (decode scenario).
        # Each M-block handles exactly one token-expert pair, skipping the expensive
        # preprocess sort kernel. Condition mirrors vLLM: num_pairs * 4 <= num_experts.
        _SPARSITY_FACTOR = 4
        use_naive = num_token_expert_pairs * _SPARSITY_FACTOR <= num_local_experts

        if use_naive:
            # Skip preprocess: use topk_ids directly as expert_ids (one per pair)
            expert_ids = topk_ids.reshape([-1]).cast("int32")
            num_tokens_post_padded = paddle.full([1], num_token_expert_pairs * cfg["BLOCK_SIZE_M"], dtype="int32")
            max_possible_num_post_padded = num_token_expert_pairs * cfg["BLOCK_SIZE_M"]
            # sorted_token_ids is not used in naive mode; pass expert_ids as a valid ptr
            sorted_token_ids = expert_ids
        else:
            sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
                topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"]
            )
            max_possible_num_post_padded = sorted_token_ids.shape[0]
            # Grid clipping: avoid launching blocks that will immediately early-return
            if token_num < cfg["BLOCK_SIZE_M"]:
                max_possible_num_post_padded = min(
                    max_possible_num_post_padded,
                    token_num * top_k * cfg["BLOCK_SIZE_M"],
                )
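To see when the naive path triggers, a quick check with example sizes (64 local experts, top_k = 8; purely illustrative):

```python
# The naive path is taken only when there are at most num_local_experts / 4
# token-expert pairs, i.e. batches so small that the preprocess sort kernel
# would cost more than it saves.
num_local_experts, top_k = 64, 8
for token_num in (1, 2, 3):
    pairs = token_num * top_k
    print(token_num, pairs, pairs * 4 <= num_local_experts)
# 1 -> 8 pairs -> True; 2 -> 16 pairs -> True; 3 -> 24 pairs -> False
```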
        # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) ---
        # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn
        up_gate_proj_out = paddle.empty(
            [num_token_expert_pairs, moe_intermediate_size * 2],
            dtype=x.dtype,
        )
        grid1 = (
            ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"])
            * ceil_div(moe_intermediate_size * 2, cfg["BLOCK_SIZE_N"]),
        )
        fused_moe_kernel_bf16[grid1](
            x,
            layer.up_gate_proj_weight,
            up_gate_proj_out,
            None,  # topk_weights_ptr (no weight mul on GEMM1)
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            N=moe_intermediate_size * 2,
            K=hidden_size,
            EM=max_possible_num_post_padded,
            num_valid_tokens=num_token_expert_pairs,
            stride_am=x.strides[0],
            stride_ak=x.strides[1],
            stride_be=layer.up_gate_proj_weight.strides[0],
            stride_bk=layer.up_gate_proj_weight.strides[1],
            stride_bn=layer.up_gate_proj_weight.strides[2],
            stride_cm=up_gate_proj_out.strides[0],
            stride_cn=up_gate_proj_out.strides[1],
            BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"],
            GROUP_SIZE_M=cfg["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=top_k,
            compute_type=tl.bfloat16,
            naive_block_assignment=use_naive,
            even_Ks=(hidden_size % cfg["BLOCK_SIZE_K"] == 0),
            num_warps=cfg["num_warps"],
            num_stages=cfg["num_stages"],
        )
        # --- 4. SwiGLU activation ---
        down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)
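The fused activation above is equivalent to splitting the GEMM1 output in half along the last axis, applying SiLU to the first half, and multiplying by the second half. A small sketch, assuming Paddle's single-tensor `swiglu` splits the last dimension in two (which is how it is used here):

```python
import paddle
import paddle.nn.functional as F

x = paddle.randn([4, 2 * 16], dtype="float32")  # [pairs, 2 * inter]
gate, up = paddle.split(x, 2, axis=-1)
ref = F.silu(gate) * up                          # SwiGLU written out
out = paddle.incubate.nn.functional.swiglu(x)    # fused form used above
print(paddle.allclose(ref, out))                 # expected: True
```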
        # --- 5. GEMM2: inter -> hidden, fuse router weight multiplication ---
        # Kernel loads topk_weights with flat offset (topk_weights_ptr + offs_token),
        # which assumes contiguous row-major layout (stride[-1] == 1).
        if not topk_weights.is_contiguous():
            topk_weights = topk_weights.contiguous()

        # down_proj_weight layout: [E, moe_intermediate_size, hidden_size] => stride_be, stride_bk, stride_bn
        down_proj_out = paddle.empty(
            (num_token_expert_pairs, hidden_size),
            dtype=x.dtype,
        )
        # Reuse the same config and preprocess results as GEMM1.
        # The preprocess output only depends on BLOCK_SIZE_M (the M-tile alignment),
        # which is determined solely by token_num and is identical for both GEMMs.
        # This matches vLLM's approach of using one config for both GEMMs.
        if use_naive:
            max_possible_num_post_padded_2 = num_token_expert_pairs * cfg["BLOCK_SIZE_M"]
            num_tokens_post_padded_2 = paddle.full([1], max_possible_num_post_padded_2, dtype="int32")
            expert_ids_2 = expert_ids
            sorted_token_ids_2 = expert_ids
        else:
            sorted_token_ids_2 = sorted_token_ids
            expert_ids_2 = expert_ids
            num_tokens_post_padded_2 = num_tokens_post_padded
            max_possible_num_post_padded_2 = max_possible_num_post_padded
            # Grid clipping for GEMM2
            if token_num < cfg["BLOCK_SIZE_M"]:
                max_possible_num_post_padded_2 = min(
                    max_possible_num_post_padded_2,
                    token_num * top_k * cfg["BLOCK_SIZE_M"],
                )

        grid2 = (
            ceil_div(max_possible_num_post_padded_2, cfg["BLOCK_SIZE_M"]) * ceil_div(hidden_size, cfg["BLOCK_SIZE_N"]),
        )
        fused_moe_kernel_bf16[grid2](
            down_proj_input,
            layer.down_proj_weight,
            down_proj_out,
            topk_weights,
            sorted_token_ids_2,
            expert_ids_2,
            num_tokens_post_padded_2,
            N=hidden_size,
            K=moe_intermediate_size,
            EM=max_possible_num_post_padded_2,
            num_valid_tokens=num_token_expert_pairs,
            stride_am=down_proj_input.strides[0],
            stride_ak=down_proj_input.strides[1],
            stride_be=layer.down_proj_weight.strides[0],
            stride_bk=layer.down_proj_weight.strides[1],
            stride_bn=layer.down_proj_weight.strides[2],
            stride_cm=down_proj_out.strides[0],
            stride_cn=down_proj_out.strides[1],
            BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"],
            GROUP_SIZE_M=cfg["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,  # fuse router weight * output
            top_k=1,
            compute_type=tl.bfloat16,
            naive_block_assignment=use_naive,
            even_Ks=(moe_intermediate_size % cfg["BLOCK_SIZE_K"] == 0),
            num_warps=cfg["num_warps"],
            num_stages=cfg["num_stages"],
        )

        # --- 6. Reduce over topk ---
        down_proj_out.reshape_([token_num, top_k, hidden_size])
        out = down_proj_out.sum(axis=1)
        return out

    def apply_ep_prefill(
        self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None
    ):
        raise NotImplementedError("TritonMoEMethod does not support EP prefill yet.")

    def apply_ep_decode(
        self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None
    ):
        raise NotImplementedError("TritonMoEMethod does not support EP decode yet.")
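As a correctness reference for this Triton path, the same BF16 MoE math can be written with plain Paddle ops. The sketch below is not part of the PR: it assumes weights shaped `[E, K, 2N]` and `[E, N, K]` as in the class docstring, routing outputs `topk_ids` / `topk_weights` already computed, and a gate-first SwiGLU split. It is only meant for small-shape, eager-mode checks against `apply_tp`.

```python
import paddle
import paddle.nn.functional as F


def moe_reference(x, up_gate_w, down_w, topk_ids, topk_weights):
    """Unfused MoE reference: loop over token-expert pairs with plain matmuls."""
    token_num, top_k = topk_ids.shape
    hidden_size = x.shape[1]
    out = paddle.zeros([token_num, hidden_size], dtype="float32")
    for t in range(token_num):
        for j in range(top_k):
            e = int(topk_ids[t, j])
            h = paddle.matmul(x[t], up_gate_w[e])   # [K] x [K, 2N] -> [2N]
            gate, up = paddle.split(h, 2, axis=-1)
            act = F.silu(gate) * up                  # SwiGLU
            y = paddle.matmul(act, down_w[e])        # [N] x [N, K] -> [K]
            out[t] += float(topk_weights[t, j]) * y.cast("float32")
    return out.cast(x.dtype)
```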
🟡 Suggestion
`import triton.language as tl` is moved out of the original `try/except ImportError` guard to module top level, so importing `fused_moe_triton_backend.py` in an environment without Triton (e.g. CPU-only tests, DCU) will now raise `ModuleNotFoundError` directly, which affects existing `TritonWeightOnlyMoEMethod` users. Suggest keeping the guarded-import pattern for compatibility, or importing lazily inside `TritonMoEMethod.apply_tp`.