Skip to content
147 changes: 147 additions & 0 deletions tests/kernels/moe/test_mxfp4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from packaging import version

from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
Expand All @@ -19,6 +20,10 @@
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)

HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
and current_platform.is_device_capability(90)
and has_flashinfer())

if TRTLLM_GEN_MXFP4_AVAILABLE:
from flashinfer import (fp4_quantize, mxfp8_quantize,
next_positive_power_of_2,
Expand Down Expand Up @@ -473,3 +478,145 @@ def test_trtllm_gen_mxfp4_fused_moe(
limit=limit)
# relatively loose check since the mxfp4 quantization is less accurate
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)


def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
"""Interleave scales on the last dimension by groups of 4, matching
the transformation in mxfp4.py's BF16 (Hopper) path."""
s = scales.to(torch.uint8)
s_shape = s.shape
assert s_shape[-1] % 4 == 0
s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
# Move the 4-group dimension before the row dimension
permuted = s.permute(0, 2, 1, 3)
# Merge the row dim with the 4-group dim
return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.skipif(
not HOPPER_MXFP4_BF16_AVAILABLE,
reason="nvidia gpu sm90 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_fused_moe(
topk: int,
num_experts: int,
num_tokens: int,
intermediate_size: int,
hidden_size: int,
alpha: float,
beta: float,
limit: Optional[float],
):
torch.manual_seed(42)
device = "cuda:0"

# Inputs
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=torch.bfloat16)
# Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
w13_q = torch.randint(
0,
256, (num_experts, 2 * intermediate_size, hidden_size // 2),
device=device,
dtype=torch.uint8)
w13_scale = torch.randint(
118,
123, (num_experts, 2 * intermediate_size, hidden_size // 32),
device=device,
dtype=torch.uint8)

w2_q = torch.randint(0,
256,
(num_experts, hidden_size, intermediate_size // 2),
device=device,
dtype=torch.uint8)
w2_scale = torch.randint(
118,
123, (num_experts, hidden_size, intermediate_size // 32),
device=device,
dtype=torch.uint8)
# Bias contiguous [b1; b3]
bias13 = (torch.randn(num_experts,
2 * intermediate_size,
device=device,
dtype=torch.bfloat16) * 10)
bias2 = (torch.randn(
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
router_logits = torch.rand(num_tokens,
num_experts,
dtype=torch.float32,
device=device)

w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
num_experts, 2 * intermediate_size, hidden_size)
w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
num_experts, hidden_size, intermediate_size)
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
hidden_states.to(torch.float32), w13_ref,
bias13.to(torch.float32), w2_ref,
bias2.to(torch.float32), alpha, beta, limit, 'bf16')

from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
w13_s = torch.cat([w3_s, w1_s], dim=1)
w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)

routing_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
token_final_scales, token_selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
token_final_scales = (token_final_scales /
token_final_scales.sum(dim=-1, keepdim=True))
token_selected_experts = token_selected_experts.to(torch.int).contiguous()

out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
if alpha is not None:
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
if beta is not None:
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
if limit is not None:
limit = torch.full((num_experts, ), limit, device=hidden_states.device)

_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
fc1_expert_weights=w13_q_swapped,
fc2_expert_weights=w2_q,
output_dtype=torch.bfloat16,
output=out,
quant_scales=[w13_s_inter.to(torch.uint8),
w2_s_inter.to(torch.uint8)],
fc1_expert_biases=w13_b,
fc2_expert_biases=bias2.to(torch.bfloat16),
swiglu_alpha=alpha,
swiglu_beta=beta,
swiglu_limit=limit,
tp_size=1,
tp_rank=0,
ep_size=1,
ep_rank=0,
use_w4_group_scaling=True,
)

# Allow some mismatch due to MXFP4 quantization
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
9 changes: 7 additions & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,9 +791,14 @@ def __init__(
# we padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501
should_use_flashinfer_mxfp4)
if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
should_use_flashinfer_mxfp4, should_use_flashinfer_mxfp4_bf16)
if current_platform.is_rocm() or (
should_use_flashinfer_mxfp4()
and current_platform.is_device_capability(100)):
hidden_size = round_up(hidden_size, 256)
elif should_use_flashinfer_mxfp4_bf16(
) and current_platform.is_device_capability(90):
hidden_size = round_up(hidden_size, 128)

# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config
Expand Down
Loading