diff --git a/fastdeploy/model_executor/layers/moe/__init__.py b/fastdeploy/model_executor/layers/moe/__init__.py index 540a0828ae5..7f2ded19cb6 100644 --- a/fastdeploy/model_executor/layers/moe/__init__.py +++ b/fastdeploy/model_executor/layers/moe/__init__.py @@ -17,7 +17,7 @@ CutlassW4AFP8MoEMethod, CutlassWeightOnlyMoEMethod, ) -from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod +from .fused_moe_triton_backend import TritonMoEMethod, TritonWeightOnlyMoEMethod from .moe import FusedMoE __all__ = [ @@ -26,4 +26,5 @@ CutlassW4AFP8MoEMethod, FusedMoE, TritonWeightOnlyMoEMethod, + TritonMoEMethod, ] diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 7d23fc96da0..301e99f307c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -18,10 +18,23 @@ import paddle import paddle.nn.functional as F +import triton.language as tl from paddle import nn import fastdeploy +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores +from fastdeploy.model_executor.layers.moe.triton_moe_kernels import ( + fused_moe_kernel_bf16, + fused_moe_kernel_paddle, +) +from fastdeploy.model_executor.layers.quantization.fp8_utils import ( + fused_stack_transpose_quant, + quant_weight_ue8m0, + transform_scale_ue8m0, +) +from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func from fastdeploy.model_executor.utils import ( TensorTracker, free_tensor, @@ -33,20 +46,7 @@ from fastdeploy.utils import ceil_div, register_custom_python_op from ..quantization.quant_base import QuantMethodBase - -try: - from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func - - from .triton_moe_kernels import fused_moe_kernel_paddle -except ImportError: - pass -from fastdeploy.model_executor.layers.moe.moe import get_moe_scores -from fastdeploy.model_executor.layers.quantization.fp8_utils import ( - fused_stack_transpose_quant, - quant_weight_ue8m0, - transform_scale_ue8m0, -) -from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant +from .fused_moe_backend_base import UnquantizedFusedMoEMethod class TritonWeightOnlyMoEMethod(QuantMethodBase): @@ -780,8 +780,8 @@ def apply( stride_am=x_q.strides[0], stride_ak=x_q.strides[1], stride_be=layer.up_gate_proj_weight.strides[0], - stride_bk=layer.up_gate_proj_weight.strides[2], - stride_bn=layer.up_gate_proj_weight.strides[1], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], stride_cm=up_gate_proj_out.strides[0], stride_cn=up_gate_proj_out.strides[1], # @@ -1885,3 +1885,284 @@ def apply( fc1_latent_proj, fc2_latent_proj, ) + + +class TritonMoEMethod(UnquantizedFusedMoEMethod): + """ + Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE. + + Activated via: export FD_MOE_BACKEND=triton + Weight layout (CUDA path): [E, K, 2N] for up_gate_proj, [E, N, K] for down_proj. + This matches UnquantizedFusedMoEMethod.create_weights layout on CUDA. 
+    """
+
+    def __init__(self, quant_config=None):
+        super().__init__(quant_config)
+
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
+        """Stack individual expert weights into the stacked parameter."""
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
+        layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_proj_weights, axis=0))
+        layer.down_proj_weight.set_value(paddle.stack(down_proj_weights, axis=0))
+
+    def _get_default_config(self, M: int, N: int, K: int, num_experts: int = 64) -> dict:
+        """
+        Heuristic tile config for BF16 MoE, aligned with vLLM's get_default_config logic.
+        M: number of input tokens (the call site passes token_num, not token-expert pairs)
+        N: output dimension of the GEMM
+        K: input dimension of the GEMM
+        num_experts: number of local experts (for GROUP_SIZE_M heuristic)
+        """
+        if M <= 32:
+            block_m, block_n, block_k = 16, 64, 128
+            num_warps, num_stages = 4, 4
+        elif M <= 96:
+            block_m, block_n, block_k = 32, 64, 128
+            num_warps, num_stages = 4, 3
+        elif M <= 512:
+            block_m, block_n, block_k = 64, 128, 64
+            num_warps, num_stages = 8, 3
+        else:
+            block_m, block_n, block_k = 128, 128, 64
+            num_warps, num_stages = 8, 3
+
+        tokens_per_expert = M // max(num_experts, 1)
+        group_m = 16 if tokens_per_expert > 128 else 1
+
+        return {
+            "BLOCK_SIZE_M": block_m,
+            "BLOCK_SIZE_N": block_n,
+            "BLOCK_SIZE_K": block_k,
+            "GROUP_SIZE_M": group_m,
+            "num_warps": num_warps,
+            "num_stages": num_stages,
+        }
+
+    def apply_tp(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
+        fc1_latent_proj: nn.Layer = None,
+        fc2_latent_proj: nn.Layer = None,
+    ) -> paddle.Tensor:
+        """
+        BF16 Triton Fused MoE forward.
+
+        Pipeline:
+        1. Gate + topk routing
+        2. tritonmoe_preprocess -> sorted_token_ids, expert_ids, num_tokens_post_padded
+        3. fused_moe_kernel_bf16 GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N]
+        4. SwiGLU activation
+        5. fused_moe_kernel_bf16 GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K]
+           (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication)
+        6. Reshape + sum over topk dim
+        """
+        token_num = x.shape[0]
+        if token_num == 0:
+            return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype)
+
+        top_k = layer.top_k
+        num_local_experts = layer.num_local_experts
+        moe_intermediate_size = layer.moe_intermediate_size
+        hidden_size = layer.hidden_size
+
+        # --- 1. Routing ---
+        gate_out = gate(x)
+
+        if layer.topk_method == "noaux_tc":
+            use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda()
+            if not use_fused:
+                gate_out = gate_out.cast("float32")
+
+            _, topk_weights, topk_ids = get_moe_scores(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
+                use_fused_cast=use_fused,
+                topk_reduce_func=getattr(layer, "topk_reduce_func", None),
+            )
+        else:
+            gate_out = gate_out.cast("float32")
+            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
+                top_k,
+                True,  # apply_norm_weight
+                False,
+            )
+
+        if topk_ids_hookfunc is not None:
+            topk_ids_hookfunc(topk_ids=topk_ids)
+
+        # # Ensure topk_ids is int64 (noaux_tc may return int32, tritonmoe_preprocess requires int64)
+        # if topk_ids.dtype != paddle.int64:
+        #     topk_ids = topk_ids.cast("int64")
+
+        # --- 2. 
Preprocess: sort tokens by expert assignment --- + # from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func + + num_token_expert_pairs = token_num * top_k + # Use token_num (not pairs) for config selection, matching vLLM's heuristic: + # M represents "how many unique tokens each expert sees on average", which + # determines whether the workload is memory-bound (decode) or compute-bound (prefill). + cfg = self._get_default_config(token_num, moe_intermediate_size * 2, hidden_size, num_local_experts) + + # Use naive_block_assignment when token count is very small (decode scenario). + # Each M-block handles exactly one token-expert pair, skipping the expensive + # preprocess sort kernel. Condition mirrors vLLM: num_pairs * 4 <= num_experts. + _SPARSITY_FACTOR = 4 + use_naive = num_token_expert_pairs * _SPARSITY_FACTOR <= num_local_experts + + if use_naive: + # Skip preprocess: use topk_ids directly as expert_ids (one per pair) + expert_ids = topk_ids.reshape([-1]).cast("int32") + num_tokens_post_padded = paddle.full([1], num_token_expert_pairs * cfg["BLOCK_SIZE_M"], dtype="int32") + max_possible_num_post_padded = num_token_expert_pairs * cfg["BLOCK_SIZE_M"] + # sorted_token_ids is not used in naive mode; pass expert_ids as a valid ptr + sorted_token_ids = expert_ids + else: + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( + topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"] + ) + max_possible_num_post_padded = sorted_token_ids.shape[0] + # Grid clipping: avoid launching blocks that will immediately early-return + if token_num < cfg["BLOCK_SIZE_M"]: + max_possible_num_post_padded = min( + max_possible_num_post_padded, + token_num * top_k * cfg["BLOCK_SIZE_M"], + ) + + # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) --- + # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn + up_gate_proj_out = paddle.empty( + [num_token_expert_pairs, moe_intermediate_size * 2], + dtype=x.dtype, + ) + grid1 = ( + ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"]) + * ceil_div(moe_intermediate_size * 2, cfg["BLOCK_SIZE_N"]), + ) + fused_moe_kernel_bf16[grid1]( + x, + layer.up_gate_proj_weight, + up_gate_proj_out, + None, # topk_weights_ptr (no weight mul on GEMM1) + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N=moe_intermediate_size * 2, + K=hidden_size, + EM=max_possible_num_post_padded, + num_valid_tokens=num_token_expert_pairs, + stride_am=x.strides[0], + stride_ak=x.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], + stride_cm=up_gate_proj_out.strides[0], + stride_cn=up_gate_proj_out.strides[1], + BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], + BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], + BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], + GROUP_SIZE_M=cfg["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=False, + top_k=top_k, + compute_type=tl.bfloat16, + naive_block_assignment=use_naive, + even_Ks=(hidden_size % cfg["BLOCK_SIZE_K"] == 0), + num_warps=cfg["num_warps"], + num_stages=cfg["num_stages"], + ) + + # --- 4. SwiGLU activation --- + down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out) + + # --- 5. GEMM2: inter -> hidden, fuse router weight multiplication --- + # Kernel loads topk_weights with flat offset (topk_weights_ptr + offs_token), + # which assumes contiguous row-major layout (stride[-1] == 1). 
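+        # topk_weights normally arrives as a freshly allocated [token_num, top_k]
+        # tensor from the routing step above, so the copy below should be a no-op;
+        # it only guards against a strided or transposed view being passed in.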
+ if not topk_weights.is_contiguous(): + topk_weights = topk_weights.contiguous() + + # down_proj_weight layout: [E, moe_intermediate_size, hidden_size] => stride_be, stride_bk, stride_bn + down_proj_out = paddle.empty( + (num_token_expert_pairs, hidden_size), + dtype=x.dtype, + ) + # Reuse the same config and preprocess results as GEMM1. + # The preprocess output only depends on BLOCK_SIZE_M (the M-tile alignment), + # which is determined solely by token_num and is identical for both GEMMs. + # This matches vLLM's approach of using one config for both GEMMs. + if use_naive: + max_possible_num_post_padded_2 = num_token_expert_pairs * cfg["BLOCK_SIZE_M"] + num_tokens_post_padded_2 = paddle.full([1], max_possible_num_post_padded_2, dtype="int32") + expert_ids_2 = expert_ids + sorted_token_ids_2 = expert_ids + else: + sorted_token_ids_2 = sorted_token_ids + expert_ids_2 = expert_ids + num_tokens_post_padded_2 = num_tokens_post_padded + max_possible_num_post_padded_2 = max_possible_num_post_padded + # Grid clipping for GEMM2 + if token_num < cfg["BLOCK_SIZE_M"]: + max_possible_num_post_padded_2 = min( + max_possible_num_post_padded_2, + token_num * top_k * cfg["BLOCK_SIZE_M"], + ) + + grid2 = ( + ceil_div(max_possible_num_post_padded_2, cfg["BLOCK_SIZE_M"]) * ceil_div(hidden_size, cfg["BLOCK_SIZE_N"]), + ) + fused_moe_kernel_bf16[grid2]( + down_proj_input, + layer.down_proj_weight, + down_proj_out, + topk_weights, + sorted_token_ids_2, + expert_ids_2, + num_tokens_post_padded_2, + N=hidden_size, + K=moe_intermediate_size, + EM=max_possible_num_post_padded_2, + num_valid_tokens=num_token_expert_pairs, + stride_am=down_proj_input.strides[0], + stride_ak=down_proj_input.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[1], + stride_bn=layer.down_proj_weight.strides[2], + stride_cm=down_proj_out.strides[0], + stride_cn=down_proj_out.strides[1], + BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], + BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], + BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], + GROUP_SIZE_M=cfg["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=True, # fuse router weight * output + top_k=1, + compute_type=tl.bfloat16, + naive_block_assignment=use_naive, + even_Ks=(moe_intermediate_size % cfg["BLOCK_SIZE_K"] == 0), + num_warps=cfg["num_warps"], + num_stages=cfg["num_stages"], + ) + + # --- 6. 
Reduce over topk ---
+        down_proj_out.reshape_([token_num, top_k, hidden_size])
+        out = down_proj_out.sum(axis=1)
+        return out
+
+    def apply_ep_prefill(
+        self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None
+    ):
+        raise NotImplementedError("TritonMoEMethod does not support EP prefill yet.")
+
+    def apply_ep_decode(
+        self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None
+    ):
+        raise NotImplementedError("TritonMoEMethod does not support EP decode yet.")
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index c048248eec4..b65d920e060 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -54,6 +54,11 @@ def get_moe_method(layer=None):
     """
     if current_platform.is_cuda():
+        moe_backend = envs.FD_MOE_BACKEND.lower()
+        if moe_backend == "triton":
+            from .fused_moe_triton_backend import TritonMoEMethod
+
+            return TritonMoEMethod(None)
         from .fused_moe_cutlass_backend import CutlassMoEMethod
 
         return CutlassMoEMethod(None)
diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
index ac5dfa96fcc..66bc3507b32 100644
--- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
+++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
@@ -198,3 +198,139 @@ def fused_moe_kernel_paddle(
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
 
     tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+# ---------------------------------------------------------------------------
+# BF16-native MoE kernel, ported from vLLM fused_moe_kernel (BF16-only path).
+#
+# Key differences from fused_moe_kernel_paddle (the wint8/fp8 kernel above):
+# 1. compute_type is a tl.constexpr parameter (not hardcoded bfloat16).
+# 2. offs_token is cast to int64 to prevent stride-multiplication overflow.
+# 3. b matrix loads use a K-boundary mask unless even_Ks guarantees K % BLOCK_SIZE_K == 0.
+# 4. Router-weight multiplication is done in fp32 before the final cast.
+# 5. No quantization paths (use_fp8/int8 removed for clarity).
+# ---------------------------------------------------------------------------
+@enable_compat_on_triton_kernel
+@triton.jit
+def fused_moe_kernel_bf16(
+    # Pointers
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Dimensions (runtime scalars)
+    N,
+    K,
+    EM,
+    num_valid_tokens,
+    # Strides
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    # Meta-parameters (compile-time constants)
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    naive_block_assignment: tl.constexpr = False,
+    even_Ks: tl.constexpr = False,
+):
+    """
+    BF16 Fused-MoE GEMM kernel, ported from vLLM.
+
+    A: [num_tokens, K] – input activations (bf16)
+    B: [E, K, N] – expert weights (bf16)
+    C: [num_tokens * top_k, N] – output (bf16)
+
+    sorted_token_ids: [EM] flat token-expert pair indices (int32)
+    expert_ids: [EM // BLOCK_SIZE_M] expert index per M-block (int32)
+
+    When naive_block_assignment=True, each M-block processes exactly one
+    token-expert pair (skipping the preprocess/sort step). 
In this mode: + - expert_ids[pid_m] holds the expert index for token-expert pair pid_m + - sorted_token_ids_ptr is unused + - offs_token is constructed as [pid_m, invalid, invalid, ...] + This avoids the preprocess kernel overhead for very small token counts. + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + + offs = tl.arange(0, BLOCK_SIZE_M) + + if not naive_block_assignment: + offs_token_id = pid_m * BLOCK_SIZE_M + offs + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + else: + # Each block handles exactly one token-expert pair: + # row 0 = pid_m (the token-expert pair index), remaining rows are + # set to num_valid_tokens which will fail the < mask check. + offs_token = tl.where(offs == 0, pid_m, num_valid_tokens) + + # Cast to int64 to prevent overflow: stride_cm * offs_token can exceed int32 + offs_token = offs_token.to(tl.int64) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # A pointer: a_ptr[token_idx, :K] where token_idx = offs_token // top_k + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) + + # B pointer: b_ptr[expert, :K, offs_bn] — B layout is [E, K, N] + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if even_Ks: + a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + ) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Router-weight multiplication in fp32 (before precision conversion) + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index 58d9d8be6e4..3d2107bfa2d 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -20,6 +20,7 @@ import sys import types +import numpy as np import paddle import pytest @@ -93,6 +94,10 @@ def __init__( self.renormalize = True self.gate_correction_bias = paddle.zeros([num_local_experts], dtype="float32") self.topk_method = "noaux_tc" + self.with_bias = False + self.ep_size = 1 + self.activation = 
"swiglu" + self.moe_quant_config = types.SimpleNamespace() self.fd_config = DummyFDConfig(load_choices) self.weight_dtype = weight_dtype self.quant_method = DummyQuantMethod(quant_config) @@ -211,10 +216,15 @@ def test_backend_imports_kernel_module(self, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) reloaded = importlib.reload(backend) assert hasattr(reloaded, "fused_moe_kernel_paddle") + # Restore the real module: reload() permanently rebinds module-level names + # (e.g. fused_moe_kernel_bf16) to the fake, and monkeypatch cannot undo that. + # A second reload after monkeypatch restores sys.modules fixes the binding. + monkeypatch.undo() + importlib.reload(backend) def test_triton_weight_only_create_and_apply(self, fake_ops, monkeypatch): quant_config = DummyQuantConfig(is_checkpoint_bf16=False) @@ -323,7 +333,7 @@ def test_wfp8afp8_method_apply_paths(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) @@ -397,7 +407,7 @@ def test_wfp8afp8_apply_noaux_and_empty(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) _ = method.apply( @@ -437,7 +447,7 @@ def test_tensorwise_prequant_and_apply(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) @@ -460,7 +470,7 @@ def test_python_op_fused_moe_kernel_paddle(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr( paddle.static, @@ -805,7 +815,7 @@ def test_python_op_learnable_scaling(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr( paddle.static, @@ -857,3 +867,689 @@ def hook(topk_ids): ) assert "topk" in captured + + +class DummyBF16Kernel: + """ + Simulates fused_moe_kernel_bf16[grid](...). + Writes zeros into the output tensor (3rd positional argument). 
+ """ + + def __init__(self): + self.calls = [] + + def __getitem__(self, grid): + def _runner(*args, **kwargs): + # output tensor is the 3rd positional argument (index 2) + if len(args) > 2 and isinstance(args[2], paddle.Tensor): + args[2].set_value(paddle.zeros_like(args[2])) + self.calls.append({"grid": grid, "kwargs": kwargs}) + + return _runner + + +class DummyTL: + """Minimal stub for triton.language so tests don't need a real Triton install.""" + + bfloat16 = "bfloat16" + float16 = "float16" + + +class TestTritonMoEMethod: + """Unit tests for TritonMoEMethod. + + Pattern mirrors TestFusedMoeTritonBackend: + - DummyLayer / DummyGate / DummyFDConfig (reused from module top) + - fake_ops fixture patches routing + preprocess ops + - DummyBF16Kernel patches fused_moe_kernel_bf16 + - No real GPU kernels are executed; output shapes / attributes are verified + """ + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + + def _make_layer(self, num_experts=2, hidden_size=8, intermediate_size=4, top_k=2): + layer = DummyLayer( + quant_config=None, + num_local_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + top_k=top_k, + weight_dtype="bfloat16", + ) + return layer + + def _create_weights(self, method, layer): + """Call create_weights with the mandatory kwargs that the real MoE layer supplies. + + TritonMoEMethod targets the CUDA non-torch weight layout: + up_gate_proj_weight: [E, hidden_size, inter*2] (K-major) + down_proj_weight: [E, inter, hidden_size] (K-major) + Therefore we must NOT pass model_format="torch"; any non-"torch" value + (or omitting the key) lets UnquantizedFusedMoEMethod take the CUDA branch. + """ + method.create_weights( + layer, + model_format="default", + num_experts=layer.num_local_experts, + hidden_size=layer.hidden_size, + moe_intermediate_size=layer.moe_intermediate_size, + ) + + def _patch_bf16_kernel(self, monkeypatch): + kernel = DummyBF16Kernel() + monkeypatch.setattr(backend, "fused_moe_kernel_bf16", kernel, raising=False) + # Patch tl so that `compute_type=tl.bfloat16` inside apply() does not + # raise NameError when triton is not installed in the test environment. 
+ monkeypatch.setattr(backend, "tl", DummyTL(), raising=False) + return kernel + + # ------------------------------------------------------------------ + # __init__ / basic construction + # ------------------------------------------------------------------ + + def test_init_sets_weight_attrs(self): + """TritonMoEMethod.__init__ must expose the two weight attr names.""" + method = backend.TritonMoEMethod() + assert "up_gate_proj_weight" in method.added_weight_attrs + assert "down_proj_weight" in method.added_weight_attrs + + def test_init_none_quant_config(self): + method = backend.TritonMoEMethod(quant_config=None) + assert method.quant_config is None + + # ------------------------------------------------------------------ + # create_weights + # ------------------------------------------------------------------ + + def test_create_weights_registers_parameters(self): + """After create_weights the layer should have up_gate_proj_weight and down_proj_weight.""" + method = backend.TritonMoEMethod() + layer = self._make_layer() + self._create_weights(method, layer) + assert hasattr(layer, "up_gate_proj_weight") + assert hasattr(layer, "down_proj_weight") + + def test_create_weights_shapes(self): + """Weight tensors must have the correct [E, K, N] / [E, N, K] layout.""" + E, H, N = 3, 8, 4 + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) + self._create_weights(method, layer) + # up_gate: [E, hidden_size, intermediate*2] + assert list(layer.up_gate_proj_weight.shape) == [E, H, N * 2] + # down: [E, intermediate, hidden_size] + assert list(layer.down_proj_weight.shape) == [E, N, H] + + # ------------------------------------------------------------------ + # process_loaded_weights + # ------------------------------------------------------------------ + + def test_process_loaded_weights_stacks_experts(self): + """process_loaded_weights must stack per-expert tensors into the stacked param.""" + E, H, N = 2, 8, 4 + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) + self._create_weights(method, layer) + + # Provide per-expert tensors via extract_moe_ffn_weights + up_weights = [paddle.ones([H, N * 2], dtype="bfloat16") * (i + 1) for i in range(E)] + down_weights = [paddle.ones([N, H], dtype="bfloat16") * (i + 1) for i in range(E)] + layer._up_weights = up_weights + layer._down_weights = down_weights + + method.process_loaded_weights(layer, state_dict={}) + + # After stacking, shape should be [E, ...] 
+ assert list(layer.up_gate_proj_weight.shape) == [E, H, N * 2] + assert list(layer.down_proj_weight.shape) == [E, N, H] + # Verify each expert's data is correctly stacked (expert i has value i+1) + for i in range(E): + expected_up = float(i + 1) + expected_down = float(i + 1) + actual_up = float(layer.up_gate_proj_weight[i].cast("float32").mean()) + actual_down = float(layer.down_proj_weight[i].cast("float32").mean()) + assert ( + abs(actual_up - expected_up) < 1e-3 + ), f"Expert {i} up_gate weight mean={actual_up}, expected {expected_up}" + assert ( + abs(actual_down - expected_down) < 1e-3 + ), f"Expert {i} down_proj weight mean={actual_down}, expected {expected_down}" + + # ------------------------------------------------------------------ + # ------------------------------------------------------------------ + # _get_default_config — tile heuristic + # ------------------------------------------------------------------ + + def test_get_default_config_decode(self): + """M<=32 decode path → 16x64x128.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=4, N=128, K=128) + assert cfg["BLOCK_SIZE_M"] == 16 + assert cfg["BLOCK_SIZE_N"] == 64 + assert cfg["BLOCK_SIZE_K"] == 128 + + def test_get_default_config_mid(self): + """96 < M <= 512 mid path → 64x128x64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=128, N=256, K=128) + assert cfg["BLOCK_SIZE_M"] == 64 + assert cfg["BLOCK_SIZE_N"] == 128 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_prefill(self): + """M > 512 prefill path → 128x128x64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=1024, N=256, K=128) + assert cfg["BLOCK_SIZE_M"] == 128 + assert cfg["BLOCK_SIZE_N"] == 128 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_boundary_32(self): + """M==32 is decode (<=32).""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=32, N=64, K=64) + assert cfg["BLOCK_SIZE_M"] == 16 + + def test_get_default_config_boundary_512(self): + """M==512 is mid (<=512) → BLOCK_SIZE_M=64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=512, N=64, K=64) + assert cfg["BLOCK_SIZE_M"] == 64 + + def test_get_default_config_has_group_size_m(self): + """All configs must include GROUP_SIZE_M key.""" + method = backend.TritonMoEMethod() + for M in (1, 64, 1024): + cfg = method._get_default_config(M=M, N=64, K=64) + assert "GROUP_SIZE_M" in cfg + + # ------------------------------------------------------------------ + # apply — empty-batch fast path + # ------------------------------------------------------------------ + + def test_apply_empty_batch_returns_zero_tensor(self, fake_ops, monkeypatch): + """apply() with 0 tokens must return a zero tensor of shape [0, hidden_size].""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.zeros([0, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [0, layer.hidden_size] + + # ------------------------------------------------------------------ + # apply — normal forward (noaux_tc routing path) + # ------------------------------------------------------------------ + + def test_apply_noaux_tc_output_shape(self, fake_ops, monkeypatch): + """apply() noaux_tc path: output shape must be [token_num, hidden_size].""" + T, H = 4, 8 + method = 
backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=H) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([T, H], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [T, H] + + def test_apply_noaux_tc_topk_hook_called(self, fake_ops, monkeypatch): + """topk_ids_hookfunc must be called with topk_ids kwarg during apply().""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + method.apply(layer, x, DummyGate(layer.num_local_experts), topk_ids_hookfunc=hook) + + assert "topk_ids" in captured + + def test_apply_noaux_tc_kernel_called_twice(self, fake_ops, monkeypatch): + """fused_moe_kernel_bf16 must be launched twice (GEMM1 + GEMM2) per forward pass.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + method.apply(layer, x, DummyGate(layer.num_local_experts)) + + assert len(kernel.calls) == 2, f"Expected 2 kernel launches (GEMM1 + GEMM2), got {len(kernel.calls)}" + + # ------------------------------------------------------------------ + # apply — non-noaux routing path (moe_topk_select) + # ------------------------------------------------------------------ + + def test_apply_aux_routing_path(self, fake_ops, monkeypatch): + """When topk_method != 'noaux_tc', the moe_topk_select path is used.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + layer.topk_method = "aux" + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + captured = {} + + def hook(topk_ids): + captured["ids"] = topk_ids + + x = paddle.randn([3, layer.hidden_size], dtype="bfloat16") + out = method.apply(layer, x, DummyGate(layer.num_local_experts), topk_ids_hookfunc=hook) + + assert list(out.shape) == [3, layer.hidden_size] + assert "ids" in captured + + # ------------------------------------------------------------------ + # apply_tp delegates to apply + # ------------------------------------------------------------------ + + def test_apply_tp_delegates_to_apply(self, fake_ops, monkeypatch): + """apply_tp() must produce the same output shape as apply().""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply_tp(layer, x, gate) + + assert list(out.shape) == [2, layer.hidden_size] + + # ------------------------------------------------------------------ + # EP methods raise NotImplementedError + # ------------------------------------------------------------------ + + def test_apply_ep_prefill_raises(self): + method = backend.TritonMoEMethod() + layer = self._make_layer() + with pytest.raises(NotImplementedError): + method.apply_ep_prefill(layer, None, None) + + def test_apply_ep_decode_raises(self): + method = backend.TritonMoEMethod() + layer = self._make_layer() + with pytest.raises(NotImplementedError): + method.apply_ep_decode(layer, None, None) + + # 
------------------------------------------------------------------ + # naive_block_assignment — decode fast path + # ------------------------------------------------------------------ + + def test_naive_block_assignment_triggered(self, fake_ops, monkeypatch): + """When num_token_expert_pairs * 4 <= num_experts, naive path is used. + + With 256 experts, top_k=8, token_num=1: pairs=8, 8*4=32 <= 256 → naive. + Verify that tritonmoe_preprocess_func is NOT called. + """ + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=256, hidden_size=8, intermediate_size=4, top_k=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + preprocess_called = [] + + def tracking_preprocess(topk_ids, num_local_experts, block_size): + preprocess_called.append(True) + token_num = topk_ids.shape[0] + top_k = topk_ids.shape[1] + sorted_token_ids = paddle.arange(token_num * top_k, dtype="int32") + expert_ids = paddle.zeros_like(sorted_token_ids) + num_tokens_post_padded = paddle.to_tensor([token_num * top_k], dtype="int32") + return sorted_token_ids, expert_ids, num_tokens_post_padded + + monkeypatch.setattr(backend, "tritonmoe_preprocess_func", tracking_preprocess, raising=False) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [1, layer.hidden_size] + assert len(preprocess_called) == 0, "tritonmoe_preprocess_func should NOT be called in naive mode" + assert len(kernel.calls) == 2, "Two kernel launches expected (GEMM1 + GEMM2)" + + def test_naive_block_assignment_kernel_kwargs(self, fake_ops, monkeypatch): + """In naive mode, kernel must be called with naive_block_assignment=True.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=64, hidden_size=8, intermediate_size=4, top_k=2) + # pairs = 1*2 = 2, 2*4=8 <= 64 → naive + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + # Both GEMM1 and GEMM2 should have naive_block_assignment=True + for i, call in enumerate(kernel.calls): + assert ( + call["kwargs"].get("naive_block_assignment") is True + ), f"Kernel call {i} should have naive_block_assignment=True" + + def test_naive_block_assignment_standard_kernel_kwargs(self, fake_ops, monkeypatch): + """In standard mode, kernel must be called with naive_block_assignment=False.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=4, hidden_size=8, intermediate_size=4, top_k=2) + # pairs = 4*2=8, 8*4=32 > 4 → standard + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([4, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + for i, call in enumerate(kernel.calls): + assert ( + call["kwargs"].get("naive_block_assignment") is False + ), f"Kernel call {i} should have naive_block_assignment=False" + + def test_naive_block_assignment_grid_size(self, fake_ops, monkeypatch): + """In naive mode, grid should be much smaller (num_pairs * cdiv(N, BLOCK_N)).""" + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=256, hidden_size=64, intermediate_size=32, top_k=8) + self._create_weights(method, layer) + kernel = 
self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + # token_num=1, top_k=8 → num_pairs=8 + # cfg for M=token_num=1: BLOCK_SIZE_M=16, BLOCK_SIZE_N=64 + # naive: EM = 8 * 16 = 128, grid_M = cdiv(128,16) = 8 + # GEMM1: N=intermediate*2=64, grid_N = cdiv(64,64) = 1 + # grid1 = 8 * 1 = 8 + gemm1_grid = kernel.calls[0]["grid"] + assert gemm1_grid == (8,), f"Expected grid (8,) for naive GEMM1, got {gemm1_grid}" + + def test_naive_block_assignment_expert_ids_content(self, fake_ops, monkeypatch): + """In naive mode, expert_ids passed to kernel should be topk_ids.flatten().""" + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=64, hidden_size=8, intermediate_size=4, top_k=2) + self._create_weights(method, layer) + + # Patch get_moe_scores to return specific topk_ids + specific_topk_ids = paddle.to_tensor([[3, 7]], dtype="int64") # 1 token, top2 + specific_topk_weights = paddle.to_tensor([[0.6, 0.4]], dtype="float32") + + def patched_get_moe_scores(*args, **kwargs): + return args[0], specific_topk_weights, specific_topk_ids + + monkeypatch.setattr(backend, "get_moe_scores", patched_get_moe_scores) + + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + # In naive mode, expert_ids (6th positional arg, index 5) should be [3, 7] (int32) + gemm1_args = kernel.calls[0]["kwargs"] + # expert_ids is positional arg — let's check via the recorded calls + # The kernel is called as fused_moe_kernel_bf16[grid](x, weight, out, weights_ptr, + # sorted_token_ids, expert_ids, num_tokens_post_padded, ...) 
+ # But DummyBF16Kernel only records kwargs; let's just verify naive_block_assignment is set + assert gemm1_args.get("naive_block_assignment") is True + + def test_naive_boundary_exact(self, fake_ops, monkeypatch): + """Test exact boundary: num_pairs * 4 == num_experts → naive IS triggered.""" + method = backend.TritonMoEMethod() + # 32 experts, top_k=2, token_num=4 → pairs=8, 8*4=32 == 32 → naive + layer = self._make_layer(num_experts=32, hidden_size=8, intermediate_size=4, top_k=2) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + preprocess_called = [] + + def tracking_preprocess(topk_ids, num_local_experts, block_size): + preprocess_called.append(True) + token_num = topk_ids.shape[0] + top_k = topk_ids.shape[1] + sorted_token_ids = paddle.arange(token_num * top_k, dtype="int32") + expert_ids = paddle.zeros_like(sorted_token_ids) + num_tokens_post_padded = paddle.to_tensor([token_num * top_k], dtype="int32") + return sorted_token_ids, expert_ids, num_tokens_post_padded + + monkeypatch.setattr(backend, "tritonmoe_preprocess_func", tracking_preprocess, raising=False) + + x = paddle.randn([4, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [4, layer.hidden_size] + assert len(preprocess_called) == 0, "Exact boundary should trigger naive (<=)" + assert kernel.calls[0]["kwargs"]["naive_block_assignment"] is True + + def test_naive_boundary_just_above(self, fake_ops, monkeypatch): + """Test just above boundary: num_pairs * 4 > num_experts → standard path.""" + method = backend.TritonMoEMethod() + # 31 experts, top_k=2, token_num=4 → pairs=8, 8*4=32 > 31 → standard + layer = self._make_layer(num_experts=31, hidden_size=8, intermediate_size=4, top_k=2) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + preprocess_called = [] + + def tracking_preprocess(topk_ids, num_local_experts, block_size): + preprocess_called.append(True) + token_num = topk_ids.shape[0] + top_k = topk_ids.shape[1] + sorted_token_ids = paddle.arange(token_num * top_k, dtype="int32") + expert_ids = paddle.zeros_like(sorted_token_ids) + num_tokens_post_padded = paddle.to_tensor([token_num * top_k], dtype="int32") + return sorted_token_ids, expert_ids, num_tokens_post_padded + + monkeypatch.setattr(backend, "tritonmoe_preprocess_func", tracking_preprocess, raising=False) + + x = paddle.randn([4, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [4, layer.hidden_size] + assert len(preprocess_called) > 0, "Just above boundary should use standard path" + assert kernel.calls[0]["kwargs"]["naive_block_assignment"] is False + + def test_naive_single_token_output_shape(self, fake_ops, monkeypatch): + """Single token decode scenario (common case for naive path).""" + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=128, hidden_size=16, intermediate_size=8, top_k=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [1, layer.hidden_size] + + +# =========================================================================== +# Precision tests: TritonMoEMethod vs. 
CutlassMoEMethod (BF16)
+# ===========================================================================
+
+
+def _make_precision_layer_pair(num_experts, hidden_size, intermediate_size, top_k):
+    """
+    Build a DummyLayer with random BF16 weights and a TritonMoEMethod.
+
+    Weight layout (CUDA non-torch): [E, H, 2N] for up_gate_proj, [E, N, H] for down_proj.
+    Returns (layer, None, triton_method) for compatibility with existing test signatures.
+    """
+    layer = DummyLayer(
+        quant_config=None,
+        num_local_experts=num_experts,
+        hidden_size=hidden_size,
+        moe_intermediate_size=intermediate_size,
+        top_k=top_k,
+        weight_dtype="bfloat16",
+    )
+
+    triton_method = backend.TritonMoEMethod()
+
+    # Create weight parameters (CUDA non-torch layout)
+    triton_method.create_weights(
+        layer,
+        model_format="default",
+        num_experts=num_experts,
+        hidden_size=hidden_size,
+        moe_intermediate_size=intermediate_size,
+    )
+
+    # Fill with Xavier-like random BF16 weights to produce meaningful output magnitudes.
+    # W1: [E, H, 2N] — scale by 1/sqrt(H) so GEMM1 output ~O(1)
+    # W2: [E, N, H] — scale by 1/sqrt(N) so GEMM2 output ~O(1)
+    paddle.seed(42)
+    w1_scale = 1.0 / (hidden_size**0.5)
+    w2_scale = 1.0 / (intermediate_size**0.5)
+    layer.up_gate_proj_weight.set_value((paddle.randn(layer.up_gate_proj_weight.shape) * w1_scale).cast("bfloat16"))
+    layer.down_proj_weight.set_value((paddle.randn(layer.down_proj_weight.shape) * w2_scale).cast("bfloat16"))
+    return layer, None, triton_method
+
+
+def _uniform_gate(layer):
+    """Gate that outputs uniform logits so every expert gets equal probability."""
+
+    class _Gate(paddle.nn.Layer):
+        def __init__(self, num_experts):
+            super().__init__()
+            self.num_experts = num_experts
+
+        def forward(self, x):
+            return paddle.ones([x.shape[0], self.num_experts], dtype="float32")
+
+    return _Gate(layer.num_local_experts)
+
+
+# Shapes to exercise: (token_num, hidden_size, intermediate_size, num_experts, top_k)
+# Small/medium sizes to keep test runtime reasonable.
+_PRECISION_SHAPES = [
+    pytest.param(1, 64, 32, 8, 2, id="decode_T1_H64"),
+    pytest.param(16, 64, 32, 8, 2, id="decode_T16_H64"),
+    pytest.param(64, 128, 64, 8, 2, id="mid_T64_H128"),
+    pytest.param(128, 128, 64, 8, 2, id="mid_T128_H128_E8"),
+    pytest.param(256, 256, 128, 8, 4, id="prefill_T256_H256"),
+]
+
+
+@pytest.mark.skipif(not paddle.is_compiled_with_cuda(), reason="requires CUDA")
+# @pytest.mark.skipif(not _triton_ops_available(), reason="triton MoE ops not available (custom ops not compiled)")
+class TestTritonMoEPrecision:
+    """
+    Precision tests: Triton BF16 path vs. Cutlass BF16 path.
+
+    Both paths are activated in production via the FD_MOE_BACKEND env var
+    (triton vs cutlass). This test verifies they produce numerically equivalent
+    results on the same shared BF16 weights and identical inputs.
+
+    All tests run real GPU kernels (no mocking).
+    Tolerance: see ATOL/RTOL below (both kernels use BF16 arithmetic with
+    fp32 accumulation; differences come from tile ordering / rounding).
+    """
+
+    # Tolerance for comparing two independent BF16 GEMM implementations.
+    # BF16 has ~7-bit mantissa (eps ~0.008). After GEMM1 + SwiGLU + GEMM2,
+    # rounding differences accumulate. Use np.allclose style:
+    #     |triton - cutlass| <= ATOL + RTOL * |cutlass|
+    ATOL = 1e-3
+    RTOL = 1e-3
+
+    @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES)
+    def test_triton_vs_cutlass(self, T, H, N, E, K):
+        """Triton BF16 MoE output must agree with CUTLASS BF16 MoE output. 
+ + Both paths use the same weight layout, routing logic, and BF16 arithmetic. + Differences should only come from tile ordering / rounding in GEMM. + """ + from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import ( + CutlassMoEMethod, + ) + + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + + # CUTLASS method shares the same weights (already created by _make_precision_layer_pair) + cutlass_method = CutlassMoEMethod(None) + + paddle.seed(0) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + + # Use a deterministic non-uniform gate to ensure consistent routing + # across multiple calls of noaux_tc (avoids tie-breaking ambiguity) + class _DeterministicGate(paddle.nn.Layer): + def __init__(self, num_experts, T): + super().__init__() + self.num_experts = num_experts + paddle.seed(123) + self._scores = paddle.randn([T, num_experts], dtype="float32") * 2.0 + + def forward(self, x): + return self._scores[: x.shape[0]] + + gate = _DeterministicGate(E, T) + + # --- Run Triton path --- + triton_out = triton_method.apply(layer, x, gate).cast("float32").numpy() + + # --- Run CUTLASS path --- + cutlass_out = cutlass_method.apply(layer, x, gate).cast("float32").numpy() + + # np.allclose style: |a - b| <= atol + rtol * |b| + tol = self.ATOL + self.RTOL * np.abs(cutlass_out) + violations = np.abs(triton_out - cutlass_out) > tol + num_violations = int(violations.sum()) + total_elements = triton_out.size + + assert num_violations == 0, ( + f"[T={T},H={H},N={N},E={E},K={K}] " + f"{num_violations}/{total_elements} elements exceed tolerance " + f"(atol={self.ATOL}, rtol={self.RTOL}). " + f"Max abs diff: {float(np.abs(triton_out - cutlass_out).max()):.2e}, " + f"max |cutlass|: {float(np.abs(cutlass_out).max()):.2e}" + ) + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_output_shape(self, T, H, N, E, K): + """Output shape must always be [T, H] regardless of batch size.""" + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + gate = _uniform_gate(layer) + out = triton_method.apply(layer, x, gate) + assert list(out.shape) == [T, H], f"Expected [{T}, {H}], got {list(out.shape)}" + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_output_dtype_is_bfloat16(self, T, H, N, E, K): + """Output dtype must match input dtype (bfloat16).""" + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + gate = _uniform_gate(layer) + out = triton_method.apply(layer, x, gate) + assert out.dtype == paddle.bfloat16, f"Expected bfloat16, got {out.dtype}" + + def test_zero_input_gives_zero_output(self): + """All-zero input must produce all-zero output.""" + T, H, N, E, K = 8, 64, 32, 8, 2 + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = paddle.zeros([T, H], dtype="bfloat16") + gate = _uniform_gate(layer) + + out = triton_method.apply(layer, x, gate).cast("float32").numpy() + np.testing.assert_allclose( + out, + np.zeros_like(out), + atol=1e-6, + err_msg="triton: zero input should produce zero output", + )
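
A minimal sketch of how the new backend selection can be exercised, assuming a CUDA build where the Triton kernels import cleanly and that fastdeploy.envs resolves FD_MOE_BACKEND at access time:

    # Sketch only: confirms get_moe_method() returns the Triton backend when
    # FD_MOE_BACKEND=triton. The env var must be set before get_moe_method()
    # runs (see the moe.py hunk above); CUDA + Triton availability is assumed.
    import os

    os.environ["FD_MOE_BACKEND"] = "triton"

    from fastdeploy.model_executor.layers.moe import TritonMoEMethod
    from fastdeploy.model_executor.layers.moe.moe import get_moe_method

    assert isinstance(get_moe_method(), TritonMoEMethod)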