From 46fdad21f63346288b1f585ba59c21d8643107f0 Mon Sep 17 00:00:00 2001 From: xuanyuanminzheng Date: Thu, 7 May 2026 16:22:52 +0800 Subject: [PATCH 1/3] [Feature] Add TritonBF16MoEMethod for BF16 MoE inference --- .../model_executor/layers/moe/__init__.py | 3 +- .../layers/moe/fused_moe_triton_backend.py | 232 ++++++++- fastdeploy/model_executor/layers/moe/moe.py | 5 + .../layers/moe/triton_moe_kernels.py | 115 +++++ tests/layers/test_fused_moe_triton_backend.py | 476 ++++++++++++++++++ 5 files changed, 829 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/__init__.py b/fastdeploy/model_executor/layers/moe/__init__.py index 540a0828ae5..77adc4844d5 100644 --- a/fastdeploy/model_executor/layers/moe/__init__.py +++ b/fastdeploy/model_executor/layers/moe/__init__.py @@ -17,7 +17,7 @@ CutlassW4AFP8MoEMethod, CutlassWeightOnlyMoEMethod, ) -from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod +from .fused_moe_triton_backend import TritonBF16MoEMethod, TritonWeightOnlyMoEMethod from .moe import FusedMoE __all__ = [ @@ -26,4 +26,5 @@ CutlassW4AFP8MoEMethod, FusedMoE, TritonWeightOnlyMoEMethod, + TritonBF16MoEMethod, ] diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 7d23fc96da0..1dd251e8303 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -35,9 +35,11 @@ from ..quantization.quant_base import QuantMethodBase try: + import triton.language as tl + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func - from .triton_moe_kernels import fused_moe_kernel_paddle + from .triton_moe_kernels import fused_moe_kernel_bf16, fused_moe_kernel_paddle except ImportError: pass from fastdeploy.model_executor.layers.moe.moe import get_moe_scores @@ -1885,3 +1887,231 @@ def apply( fc1_latent_proj, fc2_latent_proj, ) + + +class TritonBF16MoEMethod(QuantMethodBase): + """ + Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE. + + Activated via: export FD_MOE_BACKEND=triton + Weight layout (CUDA path): [E, K, 2N] for up_gate_proj, [E, N, K] for down_proj. + This matches UnquantizedFusedMoEMethod.create_weights layout on CUDA. + """ + + def __init__(self, quant_config=None): + self.quant_config = quant_config + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] + + def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None: + pass + + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): + """ + Reuse UnquantizedFusedMoEMethod weight creation logic. + Weight shapes on CUDA (non-torch format): + up_gate_proj_weight: [E, hidden_size, moe_intermediate_size * 2] (K-major) + down_proj_weight: [E, moe_intermediate_size, hidden_size] (K-major) + The Triton kernel reads B as [E, K, N] which maps directly to these shapes. + """ + from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import ( + UnquantizedFusedMoEMethod, + ) + + UnquantizedFusedMoEMethod.create_weights(self, layer, **extra_weight_attrs) + + def process_weights_after_loading(self, layer: nn.Layer): + from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import ( + UnquantizedFusedMoEMethod, + ) + + UnquantizedFusedMoEMethod.process_weights_after_loading(self, layer) + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + """Stack individual expert weights into the stacked parameter.""" + up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) + layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_proj_weights, axis=0)) + layer.down_proj_weight.set_value(paddle.stack(down_proj_weights, axis=0)) + + def _get_default_config(self, M: int, N: int, K: int) -> dict: + """ + Heuristic tile config for BF16 MoE, mirroring vLLM's get_default_config logic. + M: number of token-expert pairs (post-padded) / BLOCK_SIZE_M + N: output dimension of the GEMM + K: input dimension of the GEMM + """ + if M <= 32: + block_m, block_n, block_k = 16, 64, 64 + elif M <= 512: + block_m, block_n, block_k = 32, 128, 64 + else: + block_m, block_n, block_k = 128, 128, 64 + return { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": 8, + } + + def apply( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + topk_ids_hookfunc: Callable = None, + shared_experts: nn.Layer = None, + ) -> paddle.Tensor: + """ + BF16 Triton Fused MoE forward. + + Pipeline: + 1. Gate + topk routing + 2. tritonmoe_preprocess -> sorted_token_ids, expert_ids, num_tokens_post_padded + 3. fused_moe_kernel_paddle GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N] + 4. SwiGLU activation + 5. fused_moe_kernel_paddle GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K] + (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication) + 6. Reshape + sum over topk dim + """ + if shared_experts is not None: + raise NotImplementedError("TritonBF16MoEMethod does not support shared_experts yet.") + + token_num = x.shape[0] + if token_num == 0: + return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) + + top_k = layer.top_k + num_local_experts = layer.num_local_experts + moe_intermediate_size = layer.moe_intermediate_size + hidden_size = layer.hidden_size + + # --- 1. Routing --- + gate_out = gate(x) + gate_out = gate_out.cast("float32") + + if layer.topk_method == "noaux_tc": + from fastdeploy.model_executor.layers.moe.moe import get_moe_scores + + _, topk_weights, topk_ids = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + topk_reduce_func=getattr(layer, "topk_reduce_func", None), + ) + else: + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + top_k, + True, # apply_norm_weight + False, + ) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + + # --- 2. Preprocess: sort tokens by expert assignment --- + # Choose BLOCK_SIZE_M based on decode vs prefill heuristic + num_token_expert_pairs = token_num * top_k + cfg = self._get_default_config(num_token_expert_pairs, moe_intermediate_size * 2, hidden_size) + + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( + topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"] + ) + max_possible_num_post_padded = sorted_token_ids.shape[0] + + # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) --- + # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn + up_gate_proj_out = paddle.empty( + [num_token_expert_pairs, moe_intermediate_size * 2], + dtype=x.dtype, + ) + grid1 = ( + ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"]) + * ceil_div(moe_intermediate_size * 2, cfg["BLOCK_SIZE_N"]), + ) + fused_moe_kernel_bf16[grid1]( + x, + layer.up_gate_proj_weight, + up_gate_proj_out, + None, # topk_weights_ptr (no weight mul on GEMM1) + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N=moe_intermediate_size * 2, + K=hidden_size, + EM=max_possible_num_post_padded, + num_valid_tokens=num_token_expert_pairs, + stride_am=x.strides[0], + stride_ak=x.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], + stride_cm=up_gate_proj_out.strides[0], + stride_cn=up_gate_proj_out.strides[1], + BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], + BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], + BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], + GROUP_SIZE_M=cfg["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=False, + top_k=top_k, + compute_type=tl.bfloat16, + ) + + # --- 4. SwiGLU activation --- + down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out) + + # --- 5. GEMM2: inter -> hidden, fuse router weight multiplication --- + # down_proj_weight layout: [E, moe_intermediate_size, hidden_size] => stride_be, stride_bk, stride_bn + down_proj_out = paddle.empty( + (num_token_expert_pairs, hidden_size), + dtype=x.dtype, + ) + cfg2 = self._get_default_config(num_token_expert_pairs, hidden_size, moe_intermediate_size) + grid2 = ( + ceil_div(max_possible_num_post_padded, cfg2["BLOCK_SIZE_M"]) * ceil_div(hidden_size, cfg2["BLOCK_SIZE_N"]), + ) + fused_moe_kernel_bf16[grid2]( + down_proj_input, + layer.down_proj_weight, + down_proj_out, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N=hidden_size, + K=moe_intermediate_size, + EM=max_possible_num_post_padded, + num_valid_tokens=num_token_expert_pairs, + stride_am=down_proj_input.strides[0], + stride_ak=down_proj_input.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[1], + stride_bn=layer.down_proj_weight.strides[2], + stride_cm=down_proj_out.strides[0], + stride_cn=down_proj_out.strides[1], + BLOCK_SIZE_M=cfg2["BLOCK_SIZE_M"], + BLOCK_SIZE_N=cfg2["BLOCK_SIZE_N"], + BLOCK_SIZE_K=cfg2["BLOCK_SIZE_K"], + GROUP_SIZE_M=cfg2["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=True, # fuse router weight * output + top_k=1, + compute_type=tl.bfloat16, + ) + + # --- 6. Reduce over topk --- + down_proj_out.reshape_([token_num, top_k, hidden_size]) + out = down_proj_out.sum(axis=1) + return out + + def apply_ep_prefill(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None): + raise NotImplementedError("TritonBF16MoEMethod does not support EP prefill yet.") + + def apply_ep_decode(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None): + raise NotImplementedError("TritonBF16MoEMethod does not support EP decode yet.") + + def apply_tp(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None): + return self.apply(layer, x, gate, topk_ids_hookfunc, shared_experts) diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index c048248eec4..1c580d9233e 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -54,6 +54,11 @@ def get_moe_method(layer=None): """ if current_platform.is_cuda(): + moe_backend = envs.FD_MOE_BACKEND.lower() + if moe_backend == "triton": + from .fused_moe_triton_backend import TritonBF16MoEMethod + + return TritonBF16MoEMethod(None) from .fused_moe_cutlass_backend import CutlassMoEMethod return CutlassMoEMethod(None) diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py index ac5dfa96fcc..cdffec98cea 100644 --- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py +++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py @@ -198,3 +198,118 @@ def fused_moe_kernel_paddle( c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) + + +# --------------------------------------------------------------------------- +# BF16-native MoE kernel, ported from vLLM fused_moe_kernel (BF16-only path). +# +# Key differences from fused_moe_kernel_paddle (the wint8/fp8 kernel above): +# 1. compute_type is a tl.constexpr parameter (not hardcoded bfloat16). +# 2. offs_token is cast to int64 to prevent stride-multiplication overflow. +# 3. b matrix load always uses a K-boundary mask (no even_Ks special path). +# 4. Router-weight multiplication is done in fp32 before the final cast. +# 5. No quantization paths (use_fp8/int8 removed for clarity). +# --------------------------------------------------------------------------- +@enable_compat_on_triton_kernel +@triton.jit +def fused_moe_kernel_bf16( + # Pointers + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Dimensions (runtime scalars) + N, + K, + EM, + num_valid_tokens, + # Strides + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters (compile-time constants) + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + """ + BF16 Fused-MoE GEMM kernel, ported from vLLM. + + A: [num_tokens, K] – input activations (bf16) + B: [E, K, N] – expert weights (bf16) + C: [num_tokens * top_k, N] – output (bf16) + + sorted_token_ids: [EM] flat token-expert pair indices (int32) + expert_ids: [EM // BLOCK_SIZE_M] expert index per M-block (int32) + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + + offs = tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + offs + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + # Cast to int64 to prevent overflow: stride_cm * offs_token can exceed int32 + offs_token = offs_token.to(tl.int64) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # A pointer: a_ptr[token_idx, :K] where token_idx = offs_token // top_k + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) + + # B pointer: b_ptr[expert, :K, offs_bn] — B layout is [E, K, N] + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + ) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Router-weight multiplication in fp32 (before precision conversion) + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index 58d9d8be6e4..b5478a248dd 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -20,6 +20,7 @@ import sys import types +import numpy as np import paddle import pytest @@ -93,6 +94,10 @@ def __init__( self.renormalize = True self.gate_correction_bias = paddle.zeros([num_local_experts], dtype="float32") self.topk_method = "noaux_tc" + self.with_bias = False + self.ep_size = 1 + self.activation = "swiglu" + self.moe_quant_config = types.SimpleNamespace() self.fd_config = DummyFDConfig(load_choices) self.weight_dtype = weight_dtype self.quant_method = DummyQuantMethod(quant_config) @@ -857,3 +862,474 @@ def hook(topk_ids): ) assert "topk" in captured + + +class DummyBF16Kernel: + """ + Simulates fused_moe_kernel_bf16[grid](...). + Writes zeros into the output tensor (3rd positional argument). + """ + + def __init__(self): + self.calls = [] + + def __getitem__(self, grid): + def _runner(*args, **kwargs): + # output tensor is the 3rd positional argument (index 2) + if len(args) > 2 and isinstance(args[2], paddle.Tensor): + args[2].set_value(paddle.zeros_like(args[2])) + self.calls.append({"grid": grid, "kwargs": kwargs}) + + return _runner + + +class DummyTL: + """Minimal stub for triton.language so tests don't need a real Triton install.""" + + bfloat16 = "bfloat16" + float16 = "float16" + + +class TestTritonBF16MoEMethod: + """Unit tests for TritonBF16MoEMethod. + + Pattern mirrors TestFusedMoeTritonBackend: + - DummyLayer / DummyGate / DummyFDConfig (reused from module top) + - fake_ops fixture patches routing + preprocess ops + - DummyBF16Kernel patches fused_moe_kernel_bf16 + - No real GPU kernels are executed; output shapes / attributes are verified + """ + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + + def _make_layer(self, num_experts=2, hidden_size=8, intermediate_size=4, top_k=2): + layer = DummyLayer( + quant_config=None, + num_local_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + top_k=top_k, + weight_dtype="bfloat16", + ) + return layer + + def _create_weights(self, method, layer): + """Call create_weights with the mandatory kwargs that the real MoE layer supplies. + + TritonBF16MoEMethod targets the CUDA non-torch weight layout: + up_gate_proj_weight: [E, hidden_size, inter*2] (K-major) + down_proj_weight: [E, inter, hidden_size] (K-major) + Therefore we must NOT pass model_format="torch"; any non-"torch" value + (or omitting the key) lets UnquantizedFusedMoEMethod take the CUDA branch. + """ + method.create_weights( + layer, + model_format="default", + num_experts=layer.num_local_experts, + hidden_size=layer.hidden_size, + moe_intermediate_size=layer.moe_intermediate_size, + ) + + def _patch_bf16_kernel(self, monkeypatch): + kernel = DummyBF16Kernel() + monkeypatch.setattr(backend, "fused_moe_kernel_bf16", kernel, raising=False) + # Patch tl so that `compute_type=tl.bfloat16` inside apply() does not + # raise NameError when triton is not installed in the test environment. + monkeypatch.setattr(backend, "tl", DummyTL(), raising=False) + return kernel + + # ------------------------------------------------------------------ + # __init__ / basic construction + # ------------------------------------------------------------------ + + def test_init_sets_weight_attrs(self): + """TritonBF16MoEMethod.__init__ must expose the two weight attr names.""" + method = backend.TritonBF16MoEMethod() + assert "up_gate_proj_weight" in method.added_weight_attrs + assert "down_proj_weight" in method.added_weight_attrs + + def test_init_none_quant_config(self): + method = backend.TritonBF16MoEMethod(quant_config=None) + assert method.quant_config is None + + # ------------------------------------------------------------------ + # create_weights + # ------------------------------------------------------------------ + + def test_create_weights_registers_parameters(self): + """After create_weights the layer should have up_gate_proj_weight and down_proj_weight.""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer() + self._create_weights(method, layer) + assert hasattr(layer, "up_gate_proj_weight") + assert hasattr(layer, "down_proj_weight") + + def test_create_weights_shapes(self): + """Weight tensors must have the correct [E, K, N] / [E, N, K] layout.""" + E, H, N = 3, 8, 4 + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) + self._create_weights(method, layer) + # up_gate: [E, hidden_size, intermediate*2] + assert list(layer.up_gate_proj_weight.shape) == [E, H, N * 2] + # down: [E, intermediate, hidden_size] + assert list(layer.down_proj_weight.shape) == [E, N, H] + + # ------------------------------------------------------------------ + # process_loaded_weights + # ------------------------------------------------------------------ + + def test_process_loaded_weights_stacks_experts(self): + """process_loaded_weights must stack per-expert tensors into the stacked param.""" + E, H, N = 2, 8, 4 + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) + self._create_weights(method, layer) + + # Provide per-expert tensors via extract_moe_ffn_weights + up_weights = [paddle.ones([H, N * 2], dtype="bfloat16") * (i + 1) for i in range(E)] + down_weights = [paddle.ones([N, H], dtype="bfloat16") * (i + 1) for i in range(E)] + layer._up_weights = up_weights + layer._down_weights = down_weights + + method.process_loaded_weights(layer, state_dict={}) + + # After stacking, shape should be [E, ...] + assert list(layer.up_gate_proj_weight.shape) == [E, H, N * 2] + assert list(layer.down_proj_weight.shape) == [E, N, H] + # Verify each expert's data is correctly stacked (expert i has value i+1) + for i in range(E): + expected_up = float(i + 1) + expected_down = float(i + 1) + actual_up = float(layer.up_gate_proj_weight[i].cast("float32").mean()) + actual_down = float(layer.down_proj_weight[i].cast("float32").mean()) + assert ( + abs(actual_up - expected_up) < 1e-3 + ), f"Expert {i} up_gate weight mean={actual_up}, expected {expected_up}" + assert ( + abs(actual_down - expected_down) < 1e-3 + ), f"Expert {i} down_proj weight mean={actual_down}, expected {expected_down}" + + # ------------------------------------------------------------------ + # process_prequanted_weights + # ------------------------------------------------------------------ + + def test_process_prequanted_weights_is_noop(self): + """process_prequanted_weights should return None (no-op for BF16).""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer() + result = method.process_prequanted_weights(layer, state_dict={}) + assert result is None + + # ------------------------------------------------------------------ + # _get_default_config — tile heuristic + # ------------------------------------------------------------------ + + def test_get_default_config_decode(self): + """M<=32 decode path → 16x64x64.""" + method = backend.TritonBF16MoEMethod() + cfg = method._get_default_config(M=4, N=128, K=128) + assert cfg["BLOCK_SIZE_M"] == 16 + assert cfg["BLOCK_SIZE_N"] == 64 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_mid(self): + """32 < M <= 512 mid path → 32x128x64.""" + method = backend.TritonBF16MoEMethod() + cfg = method._get_default_config(M=128, N=256, K=128) + assert cfg["BLOCK_SIZE_M"] == 32 + assert cfg["BLOCK_SIZE_N"] == 128 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_prefill(self): + """M > 512 prefill path → 128x128x64.""" + method = backend.TritonBF16MoEMethod() + cfg = method._get_default_config(M=1024, N=256, K=128) + assert cfg["BLOCK_SIZE_M"] == 128 + assert cfg["BLOCK_SIZE_N"] == 128 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_boundary_32(self): + """M==32 is decode (<=32).""" + method = backend.TritonBF16MoEMethod() + cfg = method._get_default_config(M=32, N=64, K=64) + assert cfg["BLOCK_SIZE_M"] == 16 + + def test_get_default_config_boundary_512(self): + """M==512 is mid (<=512).""" + method = backend.TritonBF16MoEMethod() + cfg = method._get_default_config(M=512, N=64, K=64) + assert cfg["BLOCK_SIZE_M"] == 32 + + def test_get_default_config_has_group_size_m(self): + """All configs must include GROUP_SIZE_M key.""" + method = backend.TritonBF16MoEMethod() + for M in (1, 64, 1024): + cfg = method._get_default_config(M=M, N=64, K=64) + assert "GROUP_SIZE_M" in cfg + + # ------------------------------------------------------------------ + # apply — empty-batch fast path + # ------------------------------------------------------------------ + + def test_apply_empty_batch_returns_zero_tensor(self, fake_ops, monkeypatch): + """apply() with 0 tokens must return a zero tensor of shape [0, hidden_size].""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.zeros([0, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [0, layer.hidden_size] + + # ------------------------------------------------------------------ + # apply — normal forward (noaux_tc routing path) + # ------------------------------------------------------------------ + + def test_apply_noaux_tc_output_shape(self, fake_ops, monkeypatch): + """apply() noaux_tc path: output shape must be [token_num, hidden_size].""" + T, H = 4, 8 + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(hidden_size=H) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([T, H], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [T, H] + + def test_apply_noaux_tc_topk_hook_called(self, fake_ops, monkeypatch): + """topk_ids_hookfunc must be called with topk_ids kwarg during apply().""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + method.apply(layer, x, DummyGate(layer.num_local_experts), topk_ids_hookfunc=hook) + + assert "topk_ids" in captured + + def test_apply_noaux_tc_kernel_called_twice(self, fake_ops, monkeypatch): + """fused_moe_kernel_bf16 must be launched twice (GEMM1 + GEMM2) per forward pass.""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + method.apply(layer, x, DummyGate(layer.num_local_experts)) + + assert len(kernel.calls) == 2, f"Expected 2 kernel launches (GEMM1 + GEMM2), got {len(kernel.calls)}" + + # ------------------------------------------------------------------ + # apply — non-noaux routing path (moe_topk_select) + # ------------------------------------------------------------------ + + def test_apply_aux_routing_path(self, fake_ops, monkeypatch): + """When topk_method != 'noaux_tc', the moe_topk_select path is used.""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(hidden_size=8) + layer.topk_method = "aux" + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + captured = {} + + def hook(topk_ids): + captured["ids"] = topk_ids + + x = paddle.randn([3, layer.hidden_size], dtype="bfloat16") + out = method.apply(layer, x, DummyGate(layer.num_local_experts), topk_ids_hookfunc=hook) + + assert list(out.shape) == [3, layer.hidden_size] + assert "ids" in captured + + # ------------------------------------------------------------------ + # apply_tp delegates to apply + # ------------------------------------------------------------------ + + def test_apply_tp_delegates_to_apply(self, fake_ops, monkeypatch): + """apply_tp() must produce the same output shape as apply().""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply_tp(layer, x, gate) + + assert list(out.shape) == [2, layer.hidden_size] + + # ------------------------------------------------------------------ + # EP methods raise NotImplementedError + # ------------------------------------------------------------------ + + def test_apply_ep_prefill_raises(self): + method = backend.TritonBF16MoEMethod() + layer = self._make_layer() + with pytest.raises(NotImplementedError): + method.apply_ep_prefill(layer, None, None) + + def test_apply_ep_decode_raises(self): + method = backend.TritonBF16MoEMethod() + layer = self._make_layer() + with pytest.raises(NotImplementedError): + method.apply_ep_decode(layer, None, None) + + +# =========================================================================== +# Precision tests: TritonBF16MoEMethod vs. CutlassMoEMethod (BF16) +# =========================================================================== + + +def _make_precision_layer_pair(num_experts, hidden_size, intermediate_size, top_k): + """ + Build a shared DummyLayer with random BF16 weights, plus both method objects. + + Both CutlassMoEMethod and TritonBF16MoEMethod use the same CUDA non-torch + weight layout ([E, H, 2N] / [E, N, H]), so a single set of weights works + for both. We create the weights once via CutlassMoEMethod.create_weights and + both methods read from the same layer parameters at forward time. + """ + from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import ( + CutlassMoEMethod, + ) + + layer = DummyLayer( + quant_config=None, + num_local_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + top_k=top_k, + weight_dtype="bfloat16", + ) + + cutlass_method = CutlassMoEMethod(quant_config=None) + triton_method = backend.TritonBF16MoEMethod() + + # Create weight parameters once (CUDA non-torch layout, shared by both methods) + cutlass_method.create_weights( + layer, + model_format="default", + num_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + ) + + # Fill with small random BF16 values to keep numerics well-conditioned + paddle.seed(42) + layer.up_gate_proj_weight.set_value((paddle.randn(layer.up_gate_proj_weight.shape) * 0.02).cast("bfloat16")) + layer.down_proj_weight.set_value((paddle.randn(layer.down_proj_weight.shape) * 0.02).cast("bfloat16")) + return layer, cutlass_method, triton_method + + +def _uniform_gate(layer): + """Gate that outputs uniform logits so every expert gets equal probability.""" + + class _Gate(paddle.nn.Layer): + def __init__(self, num_experts): + super().__init__() + self.num_experts = num_experts + + def forward(self, x): + return paddle.ones([x.shape[0], self.num_experts], dtype="float32") + + return _Gate(layer.num_local_experts) + + +# Shapes to exercise: (token_num, hidden_size, intermediate_size, num_experts, top_k) +# Small/medium sizes to keep test runtime reasonable. +_PRECISION_SHAPES = [ + pytest.param(1, 64, 32, 4, 2, id="decode_T1_H64"), + pytest.param(16, 64, 32, 4, 2, id="decode_T16_H64"), + pytest.param(64, 128, 64, 4, 2, id="mid_T64_H128"), + pytest.param(128, 128, 64, 8, 2, id="mid_T128_H128_E8"), + pytest.param(256, 256, 128, 8, 4, id="prefill_T256_H256"), +] + + +@pytest.mark.skipif(not paddle.is_compiled_with_cuda(), reason="requires CUDA") +class TestTritonBF16MoEPrecision: + """ + Precision tests: Triton BF16 path vs. Cutlass BF16 path. + + Both paths are activated in production via the FD_MOE_BACKEND env var + (triton vs cutlass). This test verifies they produce numerically equivalent + results on the same shared BF16 weights and identical inputs. + + All tests run real GPU kernels (no mocking). + Tolerance: atol=1e-2, rtol=1e-2 (both kernels use BF16 arithmetic with + fp32 accumulation; differences come from tile ordering / rounding). + """ + + ATOL = 1e-2 + RTOL = 1e-2 + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_vs_cutlass_output_values(self, T, H, N, E, K): + """Triton and Cutlass must agree within BF16 rounding tolerance.""" + layer, cutlass_method, triton_method = _make_precision_layer_pair(E, H, N, K) + paddle.seed(0) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + gate = _uniform_gate(layer) + + cutlass_out = cutlass_method.apply(layer, x, gate).cast("float32").numpy() + triton_out = triton_method.apply(layer, x, gate).cast("float32").numpy() + + max_abs_err = float(np.abs(triton_out - cutlass_out).max()) + max_rel_err = float((np.abs(triton_out - cutlass_out) / (np.abs(cutlass_out) + 1e-6)).max()) + + assert max_abs_err < self.ATOL, ( + f"[T={T},H={H},N={N},E={E},K={K}] " f"max abs error {max_abs_err:.4f} >= {self.ATOL}" + ) + assert max_rel_err < self.RTOL, ( + f"[T={T},H={H},N={N},E={E},K={K}] " f"max rel error {max_rel_err:.4f} >= {self.RTOL}" + ) + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_output_shape(self, T, H, N, E, K): + """Output shape must always be [T, H] regardless of batch size.""" + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + gate = _uniform_gate(layer) + out = triton_method.apply(layer, x, gate) + assert list(out.shape) == [T, H], f"Expected [{T}, {H}], got {list(out.shape)}" + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_output_dtype_is_bfloat16(self, T, H, N, E, K): + """Output dtype must match input dtype (bfloat16).""" + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + gate = _uniform_gate(layer) + out = triton_method.apply(layer, x, gate) + assert out.dtype == paddle.bfloat16, f"Expected bfloat16, got {out.dtype}" + + def test_zero_input_gives_zero_output(self): + """All-zero input must produce all-zero output for both paths.""" + T, H, N, E, K = 8, 64, 32, 4, 2 + layer, cutlass_method, triton_method = _make_precision_layer_pair(E, H, N, K) + x = paddle.zeros([T, H], dtype="bfloat16") + gate = _uniform_gate(layer) + + for name, method in [("cutlass", cutlass_method), ("triton", triton_method)]: + out = method.apply(layer, x, gate).cast("float32").numpy() + np.testing.assert_allclose( + out, + np.zeros_like(out), + atol=1e-6, + err_msg=f"{name}: zero input should produce zero output", + ) From f1b58472420e498e8aa5bfc0b58d840108b5d148 Mon Sep 17 00:00:00 2001 From: xuanyuanminzheng Date: Sat, 9 May 2026 16:08:45 +0800 Subject: [PATCH 2/3] add naive_block_assignment for speed up. --- .../layers/moe/fused_moe_triton_backend.py | 117 +++++-- .../layers/moe/triton_moe_kernels.py | 20 +- tests/layers/test_fused_moe_triton_backend.py | 330 +++++++++++++++--- 3 files changed, 382 insertions(+), 85 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 1dd251e8303..ba6a0cfbb97 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -18,10 +18,23 @@ import paddle import paddle.nn.functional as F +import triton.language as tl from paddle import nn import fastdeploy +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores +from fastdeploy.model_executor.layers.moe.triton_moe_kernels import ( + fused_moe_kernel_bf16, + fused_moe_kernel_paddle, +) +from fastdeploy.model_executor.layers.quantization.fp8_utils import ( + fused_stack_transpose_quant, + quant_weight_ue8m0, + transform_scale_ue8m0, +) +from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func from fastdeploy.model_executor.utils import ( TensorTracker, free_tensor, @@ -34,22 +47,6 @@ from ..quantization.quant_base import QuantMethodBase -try: - import triton.language as tl - - from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func - - from .triton_moe_kernels import fused_moe_kernel_bf16, fused_moe_kernel_paddle -except ImportError: - pass -from fastdeploy.model_executor.layers.moe.moe import get_moe_scores -from fastdeploy.model_executor.layers.quantization.fp8_utils import ( - fused_stack_transpose_quant, - quant_weight_ue8m0, - transform_scale_ue8m0, -) -from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant - class TritonWeightOnlyMoEMethod(QuantMethodBase): """ @@ -782,8 +779,8 @@ def apply( stride_am=x_q.strides[0], stride_ak=x_q.strides[1], stride_be=layer.up_gate_proj_weight.strides[0], - stride_bk=layer.up_gate_proj_weight.strides[2], - stride_bn=layer.up_gate_proj_weight.strides[1], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], stride_cm=up_gate_proj_out.strides[0], stride_cn=up_gate_proj_out.strides[1], # @@ -1959,6 +1956,8 @@ def apply( gate: nn.Layer, topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None, + fc1_latent_proj: nn.Layer = None, + fc2_latent_proj: nn.Layer = None, ) -> paddle.Tensor: """ BF16 Triton Fused MoE forward. @@ -1966,9 +1965,9 @@ def apply( Pipeline: 1. Gate + topk routing 2. tritonmoe_preprocess -> sorted_token_ids, expert_ids, num_tokens_post_padded - 3. fused_moe_kernel_paddle GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N] + 3. fused_moe_kernel_bf16 GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N] 4. SwiGLU activation - 5. fused_moe_kernel_paddle GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K] + 5. fused_moe_kernel_bf16 GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K] (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication) 6. Reshape + sum over topk dim """ @@ -1986,11 +1985,14 @@ def apply( # --- 1. Routing --- gate_out = gate(x) - gate_out = gate_out.cast("float32") if layer.topk_method == "noaux_tc": from fastdeploy.model_executor.layers.moe.moe import get_moe_scores + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() + if not use_fused: + gate_out = gate_out.cast("float32") + _, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, @@ -1999,9 +2001,11 @@ def apply( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, topk_reduce_func=getattr(layer, "topk_reduce_func", None), ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, @@ -2013,15 +2017,34 @@ def apply( if topk_ids_hookfunc is not None: topk_ids_hookfunc(topk_ids=topk_ids) + # # Ensure topk_ids is int64 (noaux_tc may return int32, tritonmoe_preprocess requires int64) + # if topk_ids.dtype != paddle.int64: + # topk_ids = topk_ids.cast("int64") + # --- 2. Preprocess: sort tokens by expert assignment --- - # Choose BLOCK_SIZE_M based on decode vs prefill heuristic + # from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func + num_token_expert_pairs = token_num * top_k cfg = self._get_default_config(num_token_expert_pairs, moe_intermediate_size * 2, hidden_size) - sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( - topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"] - ) - max_possible_num_post_padded = sorted_token_ids.shape[0] + # Use naive_block_assignment when token count is very small (decode scenario). + # Each M-block handles exactly one token-expert pair, skipping the expensive + # preprocess sort kernel. Condition mirrors vLLM: num_pairs * 4 <= num_experts. + _SPARSITY_FACTOR = 4 + use_naive = num_token_expert_pairs * _SPARSITY_FACTOR <= num_local_experts + + if use_naive: + # Skip preprocess: use topk_ids directly as expert_ids (one per pair) + expert_ids = topk_ids.reshape([-1]).cast("int32") + num_tokens_post_padded = paddle.full([1], num_token_expert_pairs * cfg["BLOCK_SIZE_M"], dtype="int32") + max_possible_num_post_padded = num_token_expert_pairs * cfg["BLOCK_SIZE_M"] + # sorted_token_ids is not used in naive mode; pass expert_ids as a valid ptr + sorted_token_ids = expert_ids + else: + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( + topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"] + ) + max_possible_num_post_padded = sorted_token_ids.shape[0] # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) --- # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn @@ -2059,6 +2082,7 @@ def apply( MUL_ROUTED_WEIGHT=False, top_k=top_k, compute_type=tl.bfloat16, + naive_block_assignment=use_naive, ) # --- 4. SwiGLU activation --- @@ -2071,20 +2095,44 @@ def apply( dtype=x.dtype, ) cfg2 = self._get_default_config(num_token_expert_pairs, hidden_size, moe_intermediate_size) + + # GEMM2 in naive mode: down_proj_input is [num_token_expert_pairs, N], each row + # is already a flat token-expert pair, so we reuse the same naive assignment. + # However GEMM2 needs its own preprocess if BLOCK_SIZE_M differs from cfg. + if use_naive: + max_possible_num_post_padded_2 = num_token_expert_pairs * cfg2["BLOCK_SIZE_M"] + num_tokens_post_padded_2 = paddle.full([1], max_possible_num_post_padded_2, dtype="int32") + # For GEMM2, expert_ids per pair is the same; topk_ids reshaped is still valid. + expert_ids_2 = expert_ids + sorted_token_ids_2 = expert_ids + else: + # Standard path may need different preprocess if BLOCK_SIZE_M differs + if cfg2["BLOCK_SIZE_M"] != cfg["BLOCK_SIZE_M"]: + sorted_token_ids_2, expert_ids_2, num_tokens_post_padded_2 = tritonmoe_preprocess_func( + topk_ids, num_local_experts, cfg2["BLOCK_SIZE_M"] + ) + max_possible_num_post_padded_2 = sorted_token_ids_2.shape[0] + else: + sorted_token_ids_2 = sorted_token_ids + expert_ids_2 = expert_ids + num_tokens_post_padded_2 = num_tokens_post_padded + max_possible_num_post_padded_2 = max_possible_num_post_padded + grid2 = ( - ceil_div(max_possible_num_post_padded, cfg2["BLOCK_SIZE_M"]) * ceil_div(hidden_size, cfg2["BLOCK_SIZE_N"]), + ceil_div(max_possible_num_post_padded_2, cfg2["BLOCK_SIZE_M"]) + * ceil_div(hidden_size, cfg2["BLOCK_SIZE_N"]), ) fused_moe_kernel_bf16[grid2]( down_proj_input, layer.down_proj_weight, down_proj_out, topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, + sorted_token_ids_2, + expert_ids_2, + num_tokens_post_padded_2, N=hidden_size, K=moe_intermediate_size, - EM=max_possible_num_post_padded, + EM=max_possible_num_post_padded_2, num_valid_tokens=num_token_expert_pairs, stride_am=down_proj_input.strides[0], stride_ak=down_proj_input.strides[1], @@ -2100,6 +2148,7 @@ def apply( MUL_ROUTED_WEIGHT=True, # fuse router weight * output top_k=1, compute_type=tl.bfloat16, + naive_block_assignment=use_naive, ) # --- 6. Reduce over topk --- @@ -2113,5 +2162,7 @@ def apply_ep_prefill(self, layer, x, gate, topk_ids_hookfunc=None, shared_expert def apply_ep_decode(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None): raise NotImplementedError("TritonBF16MoEMethod does not support EP decode yet.") - def apply_tp(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None): - return self.apply(layer, x, gate, topk_ids_hookfunc, shared_experts) + def apply_tp( + self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None + ): + return self.apply(layer, x, gate, topk_ids_hookfunc, shared_experts, fc1_latent_proj, fc2_latent_proj) diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py index cdffec98cea..f0a582927e8 100644 --- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py +++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py @@ -242,6 +242,7 @@ def fused_moe_kernel_bf16( MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, + naive_block_assignment: tl.constexpr = False, ): """ BF16 Fused-MoE GEMM kernel, ported from vLLM. @@ -252,6 +253,13 @@ def fused_moe_kernel_bf16( sorted_token_ids: [EM] flat token-expert pair indices (int32) expert_ids: [EM // BLOCK_SIZE_M] expert index per M-block (int32) + + When naive_block_assignment=True, each M-block processes exactly one + token-expert pair (skipping the preprocess/sort step). In this mode: + - expert_ids[pid_m] holds the expert index for token-expert pair pid_m + - sorted_token_ids_ptr is unused + - offs_token is constructed as [pid_m, invalid, invalid, ...] + This avoids the preprocess kernel overhead for very small token counts. """ pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) @@ -268,8 +276,16 @@ def fused_moe_kernel_bf16( return offs = tl.arange(0, BLOCK_SIZE_M) - offs_token_id = pid_m * BLOCK_SIZE_M + offs - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + + if not naive_block_assignment: + offs_token_id = pid_m * BLOCK_SIZE_M + offs + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + else: + # Each block handles exactly one token-expert pair: + # row 0 = pid_m (the token-expert pair index), remaining rows are + # set to num_valid_tokens which will fail the < mask check. + offs_token = tl.where(offs == 0, pid_m, num_valid_tokens) + # Cast to int64 to prevent overflow: stride_cm * offs_token can exceed int32 offs_token = offs_token.to(tl.int64) token_mask = offs_token < num_valid_tokens diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index b5478a248dd..a560a0398ea 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -216,10 +216,15 @@ def test_backend_imports_kernel_module(self, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) reloaded = importlib.reload(backend) assert hasattr(reloaded, "fused_moe_kernel_paddle") + # Restore the real module: reload() permanently rebinds module-level names + # (e.g. fused_moe_kernel_bf16) to the fake, and monkeypatch cannot undo that. + # A second reload after monkeypatch restores sys.modules fixes the binding. + monkeypatch.undo() + importlib.reload(backend) def test_triton_weight_only_create_and_apply(self, fake_ops, monkeypatch): quant_config = DummyQuantConfig(is_checkpoint_bf16=False) @@ -328,7 +333,7 @@ def test_wfp8afp8_method_apply_paths(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) @@ -402,7 +407,7 @@ def test_wfp8afp8_apply_noaux_and_empty(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) _ = method.apply( @@ -442,7 +447,7 @@ def test_tensorwise_prequant_and_apply(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) @@ -465,7 +470,7 @@ def test_python_op_fused_moe_kernel_paddle(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr( paddle.static, @@ -810,7 +815,7 @@ def test_python_op_learnable_scaling(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr( paddle.static, @@ -1190,6 +1195,198 @@ def test_apply_ep_decode_raises(self): with pytest.raises(NotImplementedError): method.apply_ep_decode(layer, None, None) + # ------------------------------------------------------------------ + # naive_block_assignment — decode fast path + # ------------------------------------------------------------------ + + def test_naive_block_assignment_triggered(self, fake_ops, monkeypatch): + """When num_token_expert_pairs * 4 <= num_experts, naive path is used. + + With 256 experts, top_k=8, token_num=1: pairs=8, 8*4=32 <= 256 → naive. + Verify that tritonmoe_preprocess_func is NOT called. + """ + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(num_experts=256, hidden_size=8, intermediate_size=4, top_k=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + preprocess_called = [] + + def tracking_preprocess(topk_ids, num_local_experts, block_size): + preprocess_called.append(True) + token_num = topk_ids.shape[0] + top_k = topk_ids.shape[1] + sorted_token_ids = paddle.arange(token_num * top_k, dtype="int32") + expert_ids = paddle.zeros_like(sorted_token_ids) + num_tokens_post_padded = paddle.to_tensor([token_num * top_k], dtype="int32") + return sorted_token_ids, expert_ids, num_tokens_post_padded + + monkeypatch.setattr(backend, "tritonmoe_preprocess_func", tracking_preprocess, raising=False) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [1, layer.hidden_size] + assert len(preprocess_called) == 0, "tritonmoe_preprocess_func should NOT be called in naive mode" + assert len(kernel.calls) == 2, "Two kernel launches expected (GEMM1 + GEMM2)" + + def test_naive_block_assignment_kernel_kwargs(self, fake_ops, monkeypatch): + """In naive mode, kernel must be called with naive_block_assignment=True.""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(num_experts=64, hidden_size=8, intermediate_size=4, top_k=2) + # pairs = 1*2 = 2, 2*4=8 <= 64 → naive + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + # Both GEMM1 and GEMM2 should have naive_block_assignment=True + for i, call in enumerate(kernel.calls): + assert ( + call["kwargs"].get("naive_block_assignment") is True + ), f"Kernel call {i} should have naive_block_assignment=True" + + def test_naive_block_assignment_standard_kernel_kwargs(self, fake_ops, monkeypatch): + """In standard mode, kernel must be called with naive_block_assignment=False.""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(num_experts=4, hidden_size=8, intermediate_size=4, top_k=2) + # pairs = 4*2=8, 8*4=32 > 4 → standard + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([4, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + for i, call in enumerate(kernel.calls): + assert ( + call["kwargs"].get("naive_block_assignment") is False + ), f"Kernel call {i} should have naive_block_assignment=False" + + def test_naive_block_assignment_grid_size(self, fake_ops, monkeypatch): + """In naive mode, grid should be much smaller (num_pairs * cdiv(N, BLOCK_N)).""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(num_experts=256, hidden_size=64, intermediate_size=32, top_k=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + # token_num=1, top_k=8 → num_pairs=8 + # cfg for M=8: BLOCK_SIZE_M=16, BLOCK_SIZE_N=64 + # naive: EM = 8 * 16 = 128, grid_M = cdiv(128,16) = 8 + # GEMM1: N=intermediate*2=64, grid_N = cdiv(64,64) = 1 + # grid1 = 8 * 1 = 8 + gemm1_grid = kernel.calls[0]["grid"] + assert gemm1_grid == (8,), f"Expected grid (8,) for naive GEMM1, got {gemm1_grid}" + + def test_naive_block_assignment_expert_ids_content(self, fake_ops, monkeypatch): + """In naive mode, expert_ids passed to kernel should be topk_ids.flatten().""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(num_experts=64, hidden_size=8, intermediate_size=4, top_k=2) + self._create_weights(method, layer) + + # Patch get_moe_scores to return specific topk_ids + specific_topk_ids = paddle.to_tensor([[3, 7]], dtype="int64") # 1 token, top2 + specific_topk_weights = paddle.to_tensor([[0.6, 0.4]], dtype="float32") + + def patched_get_moe_scores(*args, **kwargs): + return args[0], specific_topk_weights, specific_topk_ids + + monkeypatch.setattr(backend, "get_moe_scores", patched_get_moe_scores) + + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + # In naive mode, expert_ids (6th positional arg, index 5) should be [3, 7] (int32) + gemm1_args = kernel.calls[0]["kwargs"] + # expert_ids is positional arg — let's check via the recorded calls + # The kernel is called as fused_moe_kernel_bf16[grid](x, weight, out, weights_ptr, + # sorted_token_ids, expert_ids, num_tokens_post_padded, ...) + # But DummyBF16Kernel only records kwargs; let's just verify naive_block_assignment is set + assert gemm1_args.get("naive_block_assignment") is True + + def test_naive_boundary_exact(self, fake_ops, monkeypatch): + """Test exact boundary: num_pairs * 4 == num_experts → naive IS triggered.""" + method = backend.TritonBF16MoEMethod() + # 32 experts, top_k=2, token_num=4 → pairs=8, 8*4=32 == 32 → naive + layer = self._make_layer(num_experts=32, hidden_size=8, intermediate_size=4, top_k=2) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + preprocess_called = [] + + def tracking_preprocess(topk_ids, num_local_experts, block_size): + preprocess_called.append(True) + token_num = topk_ids.shape[0] + top_k = topk_ids.shape[1] + sorted_token_ids = paddle.arange(token_num * top_k, dtype="int32") + expert_ids = paddle.zeros_like(sorted_token_ids) + num_tokens_post_padded = paddle.to_tensor([token_num * top_k], dtype="int32") + return sorted_token_ids, expert_ids, num_tokens_post_padded + + monkeypatch.setattr(backend, "tritonmoe_preprocess_func", tracking_preprocess, raising=False) + + x = paddle.randn([4, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [4, layer.hidden_size] + assert len(preprocess_called) == 0, "Exact boundary should trigger naive (<=)" + assert kernel.calls[0]["kwargs"]["naive_block_assignment"] is True + + def test_naive_boundary_just_above(self, fake_ops, monkeypatch): + """Test just above boundary: num_pairs * 4 > num_experts → standard path.""" + method = backend.TritonBF16MoEMethod() + # 31 experts, top_k=2, token_num=4 → pairs=8, 8*4=32 > 31 → standard + layer = self._make_layer(num_experts=31, hidden_size=8, intermediate_size=4, top_k=2) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + preprocess_called = [] + + def tracking_preprocess(topk_ids, num_local_experts, block_size): + preprocess_called.append(True) + token_num = topk_ids.shape[0] + top_k = topk_ids.shape[1] + sorted_token_ids = paddle.arange(token_num * top_k, dtype="int32") + expert_ids = paddle.zeros_like(sorted_token_ids) + num_tokens_post_padded = paddle.to_tensor([token_num * top_k], dtype="int32") + return sorted_token_ids, expert_ids, num_tokens_post_padded + + monkeypatch.setattr(backend, "tritonmoe_preprocess_func", tracking_preprocess, raising=False) + + x = paddle.randn([4, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [4, layer.hidden_size] + assert len(preprocess_called) > 0, "Just above boundary should use standard path" + assert kernel.calls[0]["kwargs"]["naive_block_assignment"] is False + + def test_naive_single_token_output_shape(self, fake_ops, monkeypatch): + """Single token decode scenario (common case for naive path).""" + method = backend.TritonBF16MoEMethod() + layer = self._make_layer(num_experts=128, hidden_size=16, intermediate_size=8, top_k=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [1, layer.hidden_size] + # =========================================================================== # Precision tests: TritonBF16MoEMethod vs. CutlassMoEMethod (BF16) @@ -1198,17 +1395,11 @@ def test_apply_ep_decode_raises(self): def _make_precision_layer_pair(num_experts, hidden_size, intermediate_size, top_k): """ - Build a shared DummyLayer with random BF16 weights, plus both method objects. + Build a DummyLayer with random BF16 weights and a TritonBF16MoEMethod. - Both CutlassMoEMethod and TritonBF16MoEMethod use the same CUDA non-torch - weight layout ([E, H, 2N] / [E, N, H]), so a single set of weights works - for both. We create the weights once via CutlassMoEMethod.create_weights and - both methods read from the same layer parameters at forward time. + Weight layout (CUDA non-torch): [E, H, 2N] for up_gate_proj, [E, N, H] for down_proj. + Returns (layer, None, triton_method) for compatibility with existing test signatures. """ - from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import ( - CutlassMoEMethod, - ) - layer = DummyLayer( quant_config=None, num_local_experts=num_experts, @@ -1218,11 +1409,10 @@ def _make_precision_layer_pair(num_experts, hidden_size, intermediate_size, top_ weight_dtype="bfloat16", ) - cutlass_method = CutlassMoEMethod(quant_config=None) triton_method = backend.TritonBF16MoEMethod() - # Create weight parameters once (CUDA non-torch layout, shared by both methods) - cutlass_method.create_weights( + # Create weight parameters (CUDA non-torch layout) + triton_method.create_weights( layer, model_format="default", num_experts=num_experts, @@ -1230,11 +1420,15 @@ def _make_precision_layer_pair(num_experts, hidden_size, intermediate_size, top_ moe_intermediate_size=intermediate_size, ) - # Fill with small random BF16 values to keep numerics well-conditioned + # Fill with Xavier-like random BF16 weights to produce meaningful output magnitudes. + # W1: [E, H, 2N] — scale by 1/sqrt(H) so GEMM1 output ~O(1) + # W2: [E, N, H] — scale by 1/sqrt(N) so GEMM2 output ~O(1) paddle.seed(42) - layer.up_gate_proj_weight.set_value((paddle.randn(layer.up_gate_proj_weight.shape) * 0.02).cast("bfloat16")) - layer.down_proj_weight.set_value((paddle.randn(layer.down_proj_weight.shape) * 0.02).cast("bfloat16")) - return layer, cutlass_method, triton_method + w1_scale = 1.0 / (hidden_size**0.5) + w2_scale = 1.0 / (intermediate_size**0.5) + layer.up_gate_proj_weight.set_value((paddle.randn(layer.up_gate_proj_weight.shape) * w1_scale).cast("bfloat16")) + layer.down_proj_weight.set_value((paddle.randn(layer.down_proj_weight.shape) * w2_scale).cast("bfloat16")) + return layer, None, triton_method def _uniform_gate(layer): @@ -1254,15 +1448,16 @@ def forward(self, x): # Shapes to exercise: (token_num, hidden_size, intermediate_size, num_experts, top_k) # Small/medium sizes to keep test runtime reasonable. _PRECISION_SHAPES = [ - pytest.param(1, 64, 32, 4, 2, id="decode_T1_H64"), - pytest.param(16, 64, 32, 4, 2, id="decode_T16_H64"), - pytest.param(64, 128, 64, 4, 2, id="mid_T64_H128"), + pytest.param(1, 64, 32, 8, 2, id="decode_T1_H64"), + pytest.param(16, 64, 32, 8, 2, id="decode_T16_H64"), + pytest.param(64, 128, 64, 8, 2, id="mid_T64_H128"), pytest.param(128, 128, 64, 8, 2, id="mid_T128_H128_E8"), pytest.param(256, 256, 128, 8, 4, id="prefill_T256_H256"), ] @pytest.mark.skipif(not paddle.is_compiled_with_cuda(), reason="requires CUDA") +# @pytest.mark.skipif(not _triton_ops_available(), reason="triton MoE ops not available (custom ops not compiled)") class TestTritonBF16MoEPrecision: """ Precision tests: Triton BF16 path vs. Cutlass BF16 path. @@ -1276,28 +1471,64 @@ class TestTritonBF16MoEPrecision: fp32 accumulation; differences come from tile ordering / rounding). """ - ATOL = 1e-2 - RTOL = 1e-2 + # Tolerance for comparing two independent BF16 GEMM implementations. + # BF16 has ~7-bit mantissa (eps ~0.008). After GEMM1 + SwiGLU + GEMM2, + # rounding differences accumulate. Use np.allclose style: + # |triton - cutlass| <= ATOL + RTOL * |cutlass| + ATOL = 1e-3 + RTOL = 1e-3 @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) - def test_triton_vs_cutlass_output_values(self, T, H, N, E, K): - """Triton and Cutlass must agree within BF16 rounding tolerance.""" - layer, cutlass_method, triton_method = _make_precision_layer_pair(E, H, N, K) + def test_triton_vs_cutlass(self, T, H, N, E, K): + """Triton BF16 MoE output must agree with CUTLASS BF16 MoE output. + + Both paths use the same weight layout, routing logic, and BF16 arithmetic. + Differences should only come from tile ordering / rounding in GEMM. + """ + from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import ( + CutlassMoEMethod, + ) + + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + + # CUTLASS method shares the same weights (already created by _make_precision_layer_pair) + cutlass_method = CutlassMoEMethod(None) + paddle.seed(0) x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") - gate = _uniform_gate(layer) - cutlass_out = cutlass_method.apply(layer, x, gate).cast("float32").numpy() + # Use a deterministic non-uniform gate to ensure consistent routing + # across multiple calls of noaux_tc (avoids tie-breaking ambiguity) + class _DeterministicGate(paddle.nn.Layer): + def __init__(self, num_experts, T): + super().__init__() + self.num_experts = num_experts + paddle.seed(123) + self._scores = paddle.randn([T, num_experts], dtype="float32") * 2.0 + + def forward(self, x): + return self._scores[: x.shape[0]] + + gate = _DeterministicGate(E, T) + + # --- Run Triton path --- triton_out = triton_method.apply(layer, x, gate).cast("float32").numpy() - max_abs_err = float(np.abs(triton_out - cutlass_out).max()) - max_rel_err = float((np.abs(triton_out - cutlass_out) / (np.abs(cutlass_out) + 1e-6)).max()) + # --- Run CUTLASS path --- + cutlass_out = cutlass_method.apply(layer, x, gate).cast("float32").numpy() - assert max_abs_err < self.ATOL, ( - f"[T={T},H={H},N={N},E={E},K={K}] " f"max abs error {max_abs_err:.4f} >= {self.ATOL}" - ) - assert max_rel_err < self.RTOL, ( - f"[T={T},H={H},N={N},E={E},K={K}] " f"max rel error {max_rel_err:.4f} >= {self.RTOL}" + # np.allclose style: |a - b| <= atol + rtol * |b| + tol = self.ATOL + self.RTOL * np.abs(cutlass_out) + violations = np.abs(triton_out - cutlass_out) > tol + num_violations = int(violations.sum()) + total_elements = triton_out.size + + assert num_violations == 0, ( + f"[T={T},H={H},N={N},E={E},K={K}] " + f"{num_violations}/{total_elements} elements exceed tolerance " + f"(atol={self.ATOL}, rtol={self.RTOL}). " + f"Max abs diff: {float(np.abs(triton_out - cutlass_out).max()):.2e}, " + f"max |cutlass|: {float(np.abs(cutlass_out).max()):.2e}" ) @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) @@ -1319,17 +1550,16 @@ def test_triton_output_dtype_is_bfloat16(self, T, H, N, E, K): assert out.dtype == paddle.bfloat16, f"Expected bfloat16, got {out.dtype}" def test_zero_input_gives_zero_output(self): - """All-zero input must produce all-zero output for both paths.""" - T, H, N, E, K = 8, 64, 32, 4, 2 - layer, cutlass_method, triton_method = _make_precision_layer_pair(E, H, N, K) + """All-zero input must produce all-zero output.""" + T, H, N, E, K = 8, 64, 32, 8, 2 + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) x = paddle.zeros([T, H], dtype="bfloat16") gate = _uniform_gate(layer) - for name, method in [("cutlass", cutlass_method), ("triton", triton_method)]: - out = method.apply(layer, x, gate).cast("float32").numpy() - np.testing.assert_allclose( - out, - np.zeros_like(out), - atol=1e-6, - err_msg=f"{name}: zero input should produce zero output", - ) + out = triton_method.apply(layer, x, gate).cast("float32").numpy() + np.testing.assert_allclose( + out, + np.zeros_like(out), + atol=1e-6, + err_msg="triton: zero input should produce zero output", + ) From 21923b0eec3bc3b8268844679b7f0b6d823d8bc9 Mon Sep 17 00:00:00 2001 From: xuanyuanminzheng Date: Sat, 9 May 2026 18:21:44 +0800 Subject: [PATCH 3/3] fix review question. --- .../model_executor/layers/moe/__init__.py | 4 +- .../layers/moe/fused_moe_triton_backend.py | 138 +++++++++--------- fastdeploy/model_executor/layers/moe/moe.py | 4 +- .../layers/moe/triton_moe_kernels.py | 25 ++-- tests/layers/test_fused_moe_triton_backend.py | 94 ++++++------ 5 files changed, 130 insertions(+), 135 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/__init__.py b/fastdeploy/model_executor/layers/moe/__init__.py index 77adc4844d5..7f2ded19cb6 100644 --- a/fastdeploy/model_executor/layers/moe/__init__.py +++ b/fastdeploy/model_executor/layers/moe/__init__.py @@ -17,7 +17,7 @@ CutlassW4AFP8MoEMethod, CutlassWeightOnlyMoEMethod, ) -from .fused_moe_triton_backend import TritonBF16MoEMethod, TritonWeightOnlyMoEMethod +from .fused_moe_triton_backend import TritonMoEMethod, TritonWeightOnlyMoEMethod from .moe import FusedMoE __all__ = [ @@ -26,5 +26,5 @@ CutlassW4AFP8MoEMethod, FusedMoE, TritonWeightOnlyMoEMethod, - TritonBF16MoEMethod, + TritonMoEMethod, ] diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index ba6a0cfbb97..301e99f307c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -46,6 +46,7 @@ from fastdeploy.utils import ceil_div, register_custom_python_op from ..quantization.quant_base import QuantMethodBase +from .fused_moe_backend_base import UnquantizedFusedMoEMethod class TritonWeightOnlyMoEMethod(QuantMethodBase): @@ -1886,7 +1887,7 @@ def apply( ) -class TritonBF16MoEMethod(QuantMethodBase): +class TritonMoEMethod(UnquantizedFusedMoEMethod): """ Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE. @@ -1896,32 +1897,7 @@ class TritonBF16MoEMethod(QuantMethodBase): """ def __init__(self, quant_config=None): - self.quant_config = quant_config - self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] - - def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None: - pass - - def create_weights(self, layer: nn.Layer, **extra_weight_attrs): - """ - Reuse UnquantizedFusedMoEMethod weight creation logic. - Weight shapes on CUDA (non-torch format): - up_gate_proj_weight: [E, hidden_size, moe_intermediate_size * 2] (K-major) - down_proj_weight: [E, moe_intermediate_size, hidden_size] (K-major) - The Triton kernel reads B as [E, K, N] which maps directly to these shapes. - """ - from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import ( - UnquantizedFusedMoEMethod, - ) - - UnquantizedFusedMoEMethod.create_weights(self, layer, **extra_weight_attrs) - - def process_weights_after_loading(self, layer: nn.Layer): - from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import ( - UnquantizedFusedMoEMethod, - ) - - UnquantizedFusedMoEMethod.process_weights_after_loading(self, layer) + super().__init__(quant_config) def process_loaded_weights(self, layer: nn.Layer, state_dict): """Stack individual expert weights into the stacked parameter.""" @@ -1929,33 +1905,45 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict): layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_proj_weights, axis=0)) layer.down_proj_weight.set_value(paddle.stack(down_proj_weights, axis=0)) - def _get_default_config(self, M: int, N: int, K: int) -> dict: + def _get_default_config(self, M: int, N: int, K: int, num_experts: int = 64) -> dict: """ - Heuristic tile config for BF16 MoE, mirroring vLLM's get_default_config logic. - M: number of token-expert pairs (post-padded) / BLOCK_SIZE_M + Heuristic tile config for BF16 MoE, aligned with vLLM's get_default_config logic. + M: number of token-expert pairs N: output dimension of the GEMM K: input dimension of the GEMM + num_experts: number of local experts (for GROUP_SIZE_M heuristic) """ if M <= 32: - block_m, block_n, block_k = 16, 64, 64 + block_m, block_n, block_k = 16, 64, 128 + num_warps, num_stages = 4, 4 + elif M <= 96: + block_m, block_n, block_k = 32, 64, 128 + num_warps, num_stages = 4, 3 elif M <= 512: - block_m, block_n, block_k = 32, 128, 64 + block_m, block_n, block_k = 64, 128, 64 + num_warps, num_stages = 8, 3 else: block_m, block_n, block_k = 128, 128, 64 + num_warps, num_stages = 8, 3 + + tokens_per_expert = M // max(num_experts, 1) + group_m = 16 if tokens_per_expert > 128 else 1 + return { "BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": 8, + "GROUP_SIZE_M": group_m, + "num_warps": num_warps, + "num_stages": num_stages, } - def apply( + def apply_tp( self, layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, - shared_experts: nn.Layer = None, fc1_latent_proj: nn.Layer = None, fc2_latent_proj: nn.Layer = None, ) -> paddle.Tensor: @@ -1971,9 +1959,6 @@ def apply( (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication) 6. Reshape + sum over topk dim """ - if shared_experts is not None: - raise NotImplementedError("TritonBF16MoEMethod does not support shared_experts yet.") - token_num = x.shape[0] if token_num == 0: return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) @@ -2025,7 +2010,10 @@ def apply( # from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func num_token_expert_pairs = token_num * top_k - cfg = self._get_default_config(num_token_expert_pairs, moe_intermediate_size * 2, hidden_size) + # Use token_num (not pairs) for config selection, matching vLLM's heuristic: + # M represents "how many unique tokens each expert sees on average", which + # determines whether the workload is memory-bound (decode) or compute-bound (prefill). + cfg = self._get_default_config(token_num, moe_intermediate_size * 2, hidden_size, num_local_experts) # Use naive_block_assignment when token count is very small (decode scenario). # Each M-block handles exactly one token-expert pair, skipping the expensive @@ -2045,6 +2033,12 @@ def apply( topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"] ) max_possible_num_post_padded = sorted_token_ids.shape[0] + # Grid clipping: avoid launching blocks that will immediately early-return + if token_num < cfg["BLOCK_SIZE_M"]: + max_possible_num_post_padded = min( + max_possible_num_post_padded, + token_num * top_k * cfg["BLOCK_SIZE_M"], + ) # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) --- # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn @@ -2083,44 +2077,48 @@ def apply( top_k=top_k, compute_type=tl.bfloat16, naive_block_assignment=use_naive, + even_Ks=(hidden_size % cfg["BLOCK_SIZE_K"] == 0), + num_warps=cfg["num_warps"], + num_stages=cfg["num_stages"], ) # --- 4. SwiGLU activation --- down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out) # --- 5. GEMM2: inter -> hidden, fuse router weight multiplication --- + # Kernel loads topk_weights with flat offset (topk_weights_ptr + offs_token), + # which assumes contiguous row-major layout (stride[-1] == 1). + if not topk_weights.is_contiguous(): + topk_weights = topk_weights.contiguous() + # down_proj_weight layout: [E, moe_intermediate_size, hidden_size] => stride_be, stride_bk, stride_bn down_proj_out = paddle.empty( (num_token_expert_pairs, hidden_size), dtype=x.dtype, ) - cfg2 = self._get_default_config(num_token_expert_pairs, hidden_size, moe_intermediate_size) - - # GEMM2 in naive mode: down_proj_input is [num_token_expert_pairs, N], each row - # is already a flat token-expert pair, so we reuse the same naive assignment. - # However GEMM2 needs its own preprocess if BLOCK_SIZE_M differs from cfg. + # Reuse the same config and preprocess results as GEMM1. + # The preprocess output only depends on BLOCK_SIZE_M (the M-tile alignment), + # which is determined solely by token_num and is identical for both GEMMs. + # This matches vLLM's approach of using one config for both GEMMs. if use_naive: - max_possible_num_post_padded_2 = num_token_expert_pairs * cfg2["BLOCK_SIZE_M"] + max_possible_num_post_padded_2 = num_token_expert_pairs * cfg["BLOCK_SIZE_M"] num_tokens_post_padded_2 = paddle.full([1], max_possible_num_post_padded_2, dtype="int32") - # For GEMM2, expert_ids per pair is the same; topk_ids reshaped is still valid. expert_ids_2 = expert_ids sorted_token_ids_2 = expert_ids else: - # Standard path may need different preprocess if BLOCK_SIZE_M differs - if cfg2["BLOCK_SIZE_M"] != cfg["BLOCK_SIZE_M"]: - sorted_token_ids_2, expert_ids_2, num_tokens_post_padded_2 = tritonmoe_preprocess_func( - topk_ids, num_local_experts, cfg2["BLOCK_SIZE_M"] + sorted_token_ids_2 = sorted_token_ids + expert_ids_2 = expert_ids + num_tokens_post_padded_2 = num_tokens_post_padded + max_possible_num_post_padded_2 = max_possible_num_post_padded + # Grid clipping for GEMM2 + if token_num < cfg["BLOCK_SIZE_M"]: + max_possible_num_post_padded_2 = min( + max_possible_num_post_padded_2, + token_num * top_k * cfg["BLOCK_SIZE_M"], ) - max_possible_num_post_padded_2 = sorted_token_ids_2.shape[0] - else: - sorted_token_ids_2 = sorted_token_ids - expert_ids_2 = expert_ids - num_tokens_post_padded_2 = num_tokens_post_padded - max_possible_num_post_padded_2 = max_possible_num_post_padded grid2 = ( - ceil_div(max_possible_num_post_padded_2, cfg2["BLOCK_SIZE_M"]) - * ceil_div(hidden_size, cfg2["BLOCK_SIZE_N"]), + ceil_div(max_possible_num_post_padded_2, cfg["BLOCK_SIZE_M"]) * ceil_div(hidden_size, cfg["BLOCK_SIZE_N"]), ) fused_moe_kernel_bf16[grid2]( down_proj_input, @@ -2141,14 +2139,17 @@ def apply( stride_bn=layer.down_proj_weight.strides[2], stride_cm=down_proj_out.strides[0], stride_cn=down_proj_out.strides[1], - BLOCK_SIZE_M=cfg2["BLOCK_SIZE_M"], - BLOCK_SIZE_N=cfg2["BLOCK_SIZE_N"], - BLOCK_SIZE_K=cfg2["BLOCK_SIZE_K"], - GROUP_SIZE_M=cfg2["GROUP_SIZE_M"], + BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], + BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], + BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], + GROUP_SIZE_M=cfg["GROUP_SIZE_M"], MUL_ROUTED_WEIGHT=True, # fuse router weight * output top_k=1, compute_type=tl.bfloat16, naive_block_assignment=use_naive, + even_Ks=(moe_intermediate_size % cfg["BLOCK_SIZE_K"] == 0), + num_warps=cfg["num_warps"], + num_stages=cfg["num_stages"], ) # --- 6. Reduce over topk --- @@ -2156,13 +2157,12 @@ def apply( out = down_proj_out.sum(axis=1) return out - def apply_ep_prefill(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None): - raise NotImplementedError("TritonBF16MoEMethod does not support EP prefill yet.") - - def apply_ep_decode(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None): - raise NotImplementedError("TritonBF16MoEMethod does not support EP decode yet.") + def apply_ep_prefill( + self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None + ): + raise NotImplementedError("TritonMoEMethod does not support EP prefill yet.") - def apply_tp( + def apply_ep_decode( self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None ): - return self.apply(layer, x, gate, topk_ids_hookfunc, shared_experts, fc1_latent_proj, fc2_latent_proj) + raise NotImplementedError("TritonMoEMethod does not support EP decode yet.") diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 1c580d9233e..b65d920e060 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -56,9 +56,9 @@ def get_moe_method(layer=None): if current_platform.is_cuda(): moe_backend = envs.FD_MOE_BACKEND.lower() if moe_backend == "triton": - from .fused_moe_triton_backend import TritonBF16MoEMethod + from .fused_moe_triton_backend import TritonMoEMethod - return TritonBF16MoEMethod(None) + return TritonMoEMethod(None) from .fused_moe_cutlass_backend import CutlassMoEMethod return CutlassMoEMethod(None) diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py index f0a582927e8..66bc3507b32 100644 --- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py +++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py @@ -243,6 +243,7 @@ def fused_moe_kernel_bf16( top_k: tl.constexpr, compute_type: tl.constexpr, naive_block_assignment: tl.constexpr = False, + even_Ks: tl.constexpr = False, ): """ BF16 Fused-MoE GEMM kernel, ported from vLLM. @@ -304,16 +305,20 @@ def fused_moe_kernel_bf16( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0, - ) + if even_Ks: + a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + ) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index a560a0398ea..3d2107bfa2d 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -895,8 +895,8 @@ class DummyTL: float16 = "float16" -class TestTritonBF16MoEMethod: - """Unit tests for TritonBF16MoEMethod. +class TestTritonMoEMethod: + """Unit tests for TritonMoEMethod. Pattern mirrors TestFusedMoeTritonBackend: - DummyLayer / DummyGate / DummyFDConfig (reused from module top) @@ -923,7 +923,7 @@ def _make_layer(self, num_experts=2, hidden_size=8, intermediate_size=4, top_k=2 def _create_weights(self, method, layer): """Call create_weights with the mandatory kwargs that the real MoE layer supplies. - TritonBF16MoEMethod targets the CUDA non-torch weight layout: + TritonMoEMethod targets the CUDA non-torch weight layout: up_gate_proj_weight: [E, hidden_size, inter*2] (K-major) down_proj_weight: [E, inter, hidden_size] (K-major) Therefore we must NOT pass model_format="torch"; any non-"torch" value @@ -950,13 +950,13 @@ def _patch_bf16_kernel(self, monkeypatch): # ------------------------------------------------------------------ def test_init_sets_weight_attrs(self): - """TritonBF16MoEMethod.__init__ must expose the two weight attr names.""" - method = backend.TritonBF16MoEMethod() + """TritonMoEMethod.__init__ must expose the two weight attr names.""" + method = backend.TritonMoEMethod() assert "up_gate_proj_weight" in method.added_weight_attrs assert "down_proj_weight" in method.added_weight_attrs def test_init_none_quant_config(self): - method = backend.TritonBF16MoEMethod(quant_config=None) + method = backend.TritonMoEMethod(quant_config=None) assert method.quant_config is None # ------------------------------------------------------------------ @@ -965,7 +965,7 @@ def test_init_none_quant_config(self): def test_create_weights_registers_parameters(self): """After create_weights the layer should have up_gate_proj_weight and down_proj_weight.""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer() self._create_weights(method, layer) assert hasattr(layer, "up_gate_proj_weight") @@ -974,7 +974,7 @@ def test_create_weights_registers_parameters(self): def test_create_weights_shapes(self): """Weight tensors must have the correct [E, K, N] / [E, N, K] layout.""" E, H, N = 3, 8, 4 - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) self._create_weights(method, layer) # up_gate: [E, hidden_size, intermediate*2] @@ -989,7 +989,7 @@ def test_create_weights_shapes(self): def test_process_loaded_weights_stacks_experts(self): """process_loaded_weights must stack per-expert tensors into the stacked param.""" E, H, N = 2, 8, 4 - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) self._create_weights(method, layer) @@ -1018,39 +1018,29 @@ def test_process_loaded_weights_stacks_experts(self): ), f"Expert {i} down_proj weight mean={actual_down}, expected {expected_down}" # ------------------------------------------------------------------ - # process_prequanted_weights - # ------------------------------------------------------------------ - - def test_process_prequanted_weights_is_noop(self): - """process_prequanted_weights should return None (no-op for BF16).""" - method = backend.TritonBF16MoEMethod() - layer = self._make_layer() - result = method.process_prequanted_weights(layer, state_dict={}) - assert result is None - # ------------------------------------------------------------------ # _get_default_config — tile heuristic # ------------------------------------------------------------------ def test_get_default_config_decode(self): - """M<=32 decode path → 16x64x64.""" - method = backend.TritonBF16MoEMethod() + """M<=32 decode path → 16x64x128.""" + method = backend.TritonMoEMethod() cfg = method._get_default_config(M=4, N=128, K=128) assert cfg["BLOCK_SIZE_M"] == 16 assert cfg["BLOCK_SIZE_N"] == 64 - assert cfg["BLOCK_SIZE_K"] == 64 + assert cfg["BLOCK_SIZE_K"] == 128 def test_get_default_config_mid(self): - """32 < M <= 512 mid path → 32x128x64.""" - method = backend.TritonBF16MoEMethod() + """96 < M <= 512 mid path → 64x128x64.""" + method = backend.TritonMoEMethod() cfg = method._get_default_config(M=128, N=256, K=128) - assert cfg["BLOCK_SIZE_M"] == 32 + assert cfg["BLOCK_SIZE_M"] == 64 assert cfg["BLOCK_SIZE_N"] == 128 assert cfg["BLOCK_SIZE_K"] == 64 def test_get_default_config_prefill(self): """M > 512 prefill path → 128x128x64.""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() cfg = method._get_default_config(M=1024, N=256, K=128) assert cfg["BLOCK_SIZE_M"] == 128 assert cfg["BLOCK_SIZE_N"] == 128 @@ -1058,19 +1048,19 @@ def test_get_default_config_prefill(self): def test_get_default_config_boundary_32(self): """M==32 is decode (<=32).""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() cfg = method._get_default_config(M=32, N=64, K=64) assert cfg["BLOCK_SIZE_M"] == 16 def test_get_default_config_boundary_512(self): - """M==512 is mid (<=512).""" - method = backend.TritonBF16MoEMethod() + """M==512 is mid (<=512) → BLOCK_SIZE_M=64.""" + method = backend.TritonMoEMethod() cfg = method._get_default_config(M=512, N=64, K=64) - assert cfg["BLOCK_SIZE_M"] == 32 + assert cfg["BLOCK_SIZE_M"] == 64 def test_get_default_config_has_group_size_m(self): """All configs must include GROUP_SIZE_M key.""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() for M in (1, 64, 1024): cfg = method._get_default_config(M=M, N=64, K=64) assert "GROUP_SIZE_M" in cfg @@ -1081,7 +1071,7 @@ def test_get_default_config_has_group_size_m(self): def test_apply_empty_batch_returns_zero_tensor(self, fake_ops, monkeypatch): """apply() with 0 tokens must return a zero tensor of shape [0, hidden_size].""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(hidden_size=8) self._create_weights(method, layer) self._patch_bf16_kernel(monkeypatch) @@ -1099,7 +1089,7 @@ def test_apply_empty_batch_returns_zero_tensor(self, fake_ops, monkeypatch): def test_apply_noaux_tc_output_shape(self, fake_ops, monkeypatch): """apply() noaux_tc path: output shape must be [token_num, hidden_size].""" T, H = 4, 8 - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(hidden_size=H) self._create_weights(method, layer) self._patch_bf16_kernel(monkeypatch) @@ -1112,7 +1102,7 @@ def test_apply_noaux_tc_output_shape(self, fake_ops, monkeypatch): def test_apply_noaux_tc_topk_hook_called(self, fake_ops, monkeypatch): """topk_ids_hookfunc must be called with topk_ids kwarg during apply().""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(hidden_size=8) self._create_weights(method, layer) self._patch_bf16_kernel(monkeypatch) @@ -1129,7 +1119,7 @@ def hook(topk_ids): def test_apply_noaux_tc_kernel_called_twice(self, fake_ops, monkeypatch): """fused_moe_kernel_bf16 must be launched twice (GEMM1 + GEMM2) per forward pass.""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(hidden_size=8) self._create_weights(method, layer) kernel = self._patch_bf16_kernel(monkeypatch) @@ -1145,7 +1135,7 @@ def test_apply_noaux_tc_kernel_called_twice(self, fake_ops, monkeypatch): def test_apply_aux_routing_path(self, fake_ops, monkeypatch): """When topk_method != 'noaux_tc', the moe_topk_select path is used.""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(hidden_size=8) layer.topk_method = "aux" self._create_weights(method, layer) @@ -1168,7 +1158,7 @@ def hook(topk_ids): def test_apply_tp_delegates_to_apply(self, fake_ops, monkeypatch): """apply_tp() must produce the same output shape as apply().""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(hidden_size=8) self._create_weights(method, layer) self._patch_bf16_kernel(monkeypatch) @@ -1184,13 +1174,13 @@ def test_apply_tp_delegates_to_apply(self, fake_ops, monkeypatch): # ------------------------------------------------------------------ def test_apply_ep_prefill_raises(self): - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer() with pytest.raises(NotImplementedError): method.apply_ep_prefill(layer, None, None) def test_apply_ep_decode_raises(self): - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer() with pytest.raises(NotImplementedError): method.apply_ep_decode(layer, None, None) @@ -1205,7 +1195,7 @@ def test_naive_block_assignment_triggered(self, fake_ops, monkeypatch): With 256 experts, top_k=8, token_num=1: pairs=8, 8*4=32 <= 256 → naive. Verify that tritonmoe_preprocess_func is NOT called. """ - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(num_experts=256, hidden_size=8, intermediate_size=4, top_k=8) self._create_weights(method, layer) kernel = self._patch_bf16_kernel(monkeypatch) @@ -1233,7 +1223,7 @@ def tracking_preprocess(topk_ids, num_local_experts, block_size): def test_naive_block_assignment_kernel_kwargs(self, fake_ops, monkeypatch): """In naive mode, kernel must be called with naive_block_assignment=True.""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(num_experts=64, hidden_size=8, intermediate_size=4, top_k=2) # pairs = 1*2 = 2, 2*4=8 <= 64 → naive self._create_weights(method, layer) @@ -1252,7 +1242,7 @@ def test_naive_block_assignment_kernel_kwargs(self, fake_ops, monkeypatch): def test_naive_block_assignment_standard_kernel_kwargs(self, fake_ops, monkeypatch): """In standard mode, kernel must be called with naive_block_assignment=False.""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(num_experts=4, hidden_size=8, intermediate_size=4, top_k=2) # pairs = 4*2=8, 8*4=32 > 4 → standard self._create_weights(method, layer) @@ -1270,7 +1260,7 @@ def test_naive_block_assignment_standard_kernel_kwargs(self, fake_ops, monkeypat def test_naive_block_assignment_grid_size(self, fake_ops, monkeypatch): """In naive mode, grid should be much smaller (num_pairs * cdiv(N, BLOCK_N)).""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(num_experts=256, hidden_size=64, intermediate_size=32, top_k=8) self._create_weights(method, layer) kernel = self._patch_bf16_kernel(monkeypatch) @@ -1280,7 +1270,7 @@ def test_naive_block_assignment_grid_size(self, fake_ops, monkeypatch): method.apply(layer, x, gate) # token_num=1, top_k=8 → num_pairs=8 - # cfg for M=8: BLOCK_SIZE_M=16, BLOCK_SIZE_N=64 + # cfg for M=token_num=1: BLOCK_SIZE_M=16, BLOCK_SIZE_N=64 # naive: EM = 8 * 16 = 128, grid_M = cdiv(128,16) = 8 # GEMM1: N=intermediate*2=64, grid_N = cdiv(64,64) = 1 # grid1 = 8 * 1 = 8 @@ -1289,7 +1279,7 @@ def test_naive_block_assignment_grid_size(self, fake_ops, monkeypatch): def test_naive_block_assignment_expert_ids_content(self, fake_ops, monkeypatch): """In naive mode, expert_ids passed to kernel should be topk_ids.flatten().""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(num_experts=64, hidden_size=8, intermediate_size=4, top_k=2) self._create_weights(method, layer) @@ -1318,7 +1308,7 @@ def patched_get_moe_scores(*args, **kwargs): def test_naive_boundary_exact(self, fake_ops, monkeypatch): """Test exact boundary: num_pairs * 4 == num_experts → naive IS triggered.""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() # 32 experts, top_k=2, token_num=4 → pairs=8, 8*4=32 == 32 → naive layer = self._make_layer(num_experts=32, hidden_size=8, intermediate_size=4, top_k=2) self._create_weights(method, layer) @@ -1347,7 +1337,7 @@ def tracking_preprocess(topk_ids, num_local_experts, block_size): def test_naive_boundary_just_above(self, fake_ops, monkeypatch): """Test just above boundary: num_pairs * 4 > num_experts → standard path.""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() # 31 experts, top_k=2, token_num=4 → pairs=8, 8*4=32 > 31 → standard layer = self._make_layer(num_experts=31, hidden_size=8, intermediate_size=4, top_k=2) self._create_weights(method, layer) @@ -1376,7 +1366,7 @@ def tracking_preprocess(topk_ids, num_local_experts, block_size): def test_naive_single_token_output_shape(self, fake_ops, monkeypatch): """Single token decode scenario (common case for naive path).""" - method = backend.TritonBF16MoEMethod() + method = backend.TritonMoEMethod() layer = self._make_layer(num_experts=128, hidden_size=16, intermediate_size=8, top_k=8) self._create_weights(method, layer) self._patch_bf16_kernel(monkeypatch) @@ -1389,13 +1379,13 @@ def test_naive_single_token_output_shape(self, fake_ops, monkeypatch): # =========================================================================== -# Precision tests: TritonBF16MoEMethod vs. CutlassMoEMethod (BF16) +# Precision tests: TritonMoEMethod vs. CutlassMoEMethod (BF16) # =========================================================================== def _make_precision_layer_pair(num_experts, hidden_size, intermediate_size, top_k): """ - Build a DummyLayer with random BF16 weights and a TritonBF16MoEMethod. + Build a DummyLayer with random BF16 weights and a TritonMoEMethod. Weight layout (CUDA non-torch): [E, H, 2N] for up_gate_proj, [E, N, H] for down_proj. Returns (layer, None, triton_method) for compatibility with existing test signatures. @@ -1409,7 +1399,7 @@ def _make_precision_layer_pair(num_experts, hidden_size, intermediate_size, top_ weight_dtype="bfloat16", ) - triton_method = backend.TritonBF16MoEMethod() + triton_method = backend.TritonMoEMethod() # Create weight parameters (CUDA non-torch layout) triton_method.create_weights( @@ -1458,7 +1448,7 @@ def forward(self, x): @pytest.mark.skipif(not paddle.is_compiled_with_cuda(), reason="requires CUDA") # @pytest.mark.skipif(not _triton_ops_available(), reason="triton MoE ops not available (custom ops not compiled)") -class TestTritonBF16MoEPrecision: +class TestTritonMoEPrecision: """ Precision tests: Triton BF16 path vs. Cutlass BF16 path.