
Commit 074854b

djmmoss, robertgshaw2-redhat, and mgoin authored
[Kernel][B200] mxfp4 fused cutlass moe (#23696)
Signed-off-by: Duncan Moss <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent 79ac59f commit 074854b
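As a usage aside (not part of the commit): the two new CUTLASS tests added in tests/kernels/moe/test_mxfp4_moe.py below can be selected by keyword. The -k filter is illustrative; running them requires a matching GPU (SM90 or SM100) and FlashInfer installed.

import pytest

# Hypothetical invocation: both new test names match this keyword.
pytest.main([
    "tests/kernels/moe/test_mxfp4_moe.py",
    "-k", "flashinfer_cutlass_mxfp4",
])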

File tree

5 files changed, +626 -64 lines changed


tests/kernels/moe/test_mxfp4_moe.py

Lines changed: 319 additions & 0 deletions
@@ -11,6 +11,7 @@
 from packaging import version
 
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer
 
 QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
     "quark") is not None and version.parse(
@@ -19,6 +20,10 @@
 TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
 ) and current_platform.is_device_capability(100)
 
+HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
+                               and current_platform.is_device_capability(90)
+                               and has_flashinfer())
+
 if TRTLLM_GEN_MXFP4_AVAILABLE:
     from flashinfer import (fp4_quantize, mxfp8_quantize,
                             next_positive_power_of_2,
@@ -542,3 +547,317 @@ def test_trtllm_gen_mxfp4_fused_moe(
                             transpose_optimized=transpose_optimized)
     # relatively loose check since the mxfp4 quantization is less accurate
     check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
+
+
+def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
+    """Interleave scales on the last dimension by groups of 4, matching
+    the transformation in mxfp4.py's BF16 (Hopper) path."""
+    s = scales.to(torch.uint8)
+    s_shape = s.shape
+    assert s_shape[-1] % 4 == 0
+    s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
+    # Move the 4-group dimension before the row dimension
+    permuted = s.permute(0, 2, 1, 3)
+    # Merge the row dim with the 4-group dim
+    return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)
+
+
+@pytest.mark.parametrize("topk", [1, 4])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("num_tokens", [1, 128])
+@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
+@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
+                                              (1.702, 1.0, 7.0)])
+@pytest.mark.skipif(
+    not HOPPER_MXFP4_BF16_AVAILABLE,
+    reason="nvidia gpu sm90 and flashinfer are required for this test",
+)
+def test_flashinfer_cutlass_mxfp4_fused_moe(
+    topk: int,
+    num_experts: int,
+    num_tokens: int,
+    intermediate_size: int,
+    hidden_size: int,
+    alpha: float,
+    beta: float,
+    limit: Optional[float],
+):
+    torch.manual_seed(42)
+    device = "cuda:0"
+
+    # Inputs
+    hidden_states = torch.randn(num_tokens,
+                                hidden_size,
+                                device=device,
+                                dtype=torch.bfloat16)
+    # Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
+    w13_q = torch.randint(
+        0,
+        256, (num_experts, 2 * intermediate_size, hidden_size // 2),
+        device=device,
+        dtype=torch.uint8)
+    w13_scale = torch.randint(
+        118,
+        123, (num_experts, 2 * intermediate_size, hidden_size // 32),
+        device=device,
+        dtype=torch.uint8)
+
+    w2_q = torch.randint(0,
+                         256,
+                         (num_experts, hidden_size, intermediate_size // 2),
+                         device=device,
+                         dtype=torch.uint8)
+    w2_scale = torch.randint(
+        118,
+        123, (num_experts, hidden_size, intermediate_size // 32),
+        device=device,
+        dtype=torch.uint8)
+    # Bias contiguous [b1; b3]
+    bias13 = (torch.randn(num_experts,
+                          2 * intermediate_size,
+                          device=device,
+                          dtype=torch.bfloat16) * 10)
+    bias2 = (torch.randn(
+        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
+    router_logits = torch.rand(num_tokens,
+                               num_experts,
+                               dtype=torch.float32,
+                               device=device)
+
+    w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
+        num_experts, 2 * intermediate_size, hidden_size)
+    w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
+        num_experts, hidden_size, intermediate_size)
+    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
+                        hidden_states.to(torch.float32), w13_ref,
+                        bias13.to(torch.float32), w2_ref,
+                        bias2.to(torch.float32), alpha, beta, limit, 'bf16')
+
+    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
+
+    # Swap halves to arrange as [w3; w1] (kernel expectation)
+    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
+    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
+
+    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
+    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
+
+    w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
+    w13_s = torch.cat([w3_s, w1_s], dim=1)
+    w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
+    w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
+
+    routing_weights = torch.nn.functional.softmax(router_logits,
+                                                  dim=1,
+                                                  dtype=torch.float32)
+    token_final_scales, token_selected_experts = torch.topk(routing_weights,
+                                                            topk,
+                                                            dim=-1)
+    token_final_scales = (token_final_scales /
+                          token_final_scales.sum(dim=-1, keepdim=True))
+    token_selected_experts = token_selected_experts.to(torch.int).contiguous()
+
+    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
+    if alpha is not None:
+        alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
+    if beta is not None:
+        beta = torch.full((num_experts, ), beta, device=hidden_states.device)
+    if limit is not None:
+        limit = torch.full((num_experts, ), limit, device=hidden_states.device)
+
+    _ = flashinfer_cutlass_fused_moe(
+        input=hidden_states,
+        token_selected_experts=token_selected_experts,
+        token_final_scales=token_final_scales,
+        fc1_expert_weights=w13_q_swapped,
+        fc2_expert_weights=w2_q,
+        output_dtype=torch.bfloat16,
+        output=out,
+        quant_scales=[w13_s_inter.to(torch.uint8),
+                      w2_s_inter.to(torch.uint8)],
+        fc1_expert_biases=w13_b,
+        fc2_expert_biases=bias2.to(torch.bfloat16),
+        swiglu_alpha=alpha,
+        swiglu_beta=beta,
+        swiglu_limit=limit,
+        tp_size=1,
+        tp_rank=0,
+        ep_size=1,
+        ep_rank=0,
+        use_w4_group_scaling=True,
+    )
+
+    # Allow some mismatch due to MXFP4 quantization
+    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
+
+
+@pytest.mark.parametrize("topk", [1, 4])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("num_tokens", [1, 128])
+@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
+@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
+                                              (1.702, 1.0, 7.0)])
+@pytest.mark.skipif(
+    not (current_platform.is_cuda()
+         and current_platform.is_device_capability(100) and has_flashinfer()),
+    reason="NVIDIA GPU sm100 and flashinfer are required for this test",
+)
+def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
+    topk: int,
+    num_experts: int,
+    num_tokens: int,
+    intermediate_size: int,
+    hidden_size: int,
+    alpha: Optional[float],
+    beta: Optional[float],
+    limit: Optional[float],
+):
+    torch.manual_seed(42)
+    device = "cuda:0"
+
+    # Inputs
+    hidden_states = torch.randn(num_tokens,
+                                hidden_size,
+                                device=device,
+                                dtype=torch.bfloat16)
+    # Float weights in w13 format [w1; w3]
+    w13 = (torch.randn(num_experts,
+                       2 * intermediate_size,
+                       hidden_size,
+                       device=device,
+                       dtype=torch.bfloat16) / 10)
+    w2 = (torch.randn(num_experts,
+                      hidden_size,
+                      intermediate_size,
+                      device=device,
+                      dtype=torch.bfloat16) / 10)
+    # Bias contiguous [b1; b3]
+    bias13 = (torch.randn(num_experts,
+                          2 * intermediate_size,
+                          device=device,
+                          dtype=torch.bfloat16) * 10)
+    bias2 = (torch.randn(
+        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
+    router_logits = torch.rand(num_tokens,
+                               num_experts,
+                               dtype=torch.float32,
+                               device=device)
+
+    # Quantize weights to MXFP4 per expert (SM100 path)
+    from flashinfer import mxfp4_quantize
+
+    def quant_mxfp4_batches(a: torch.Tensor, e: int):
+        qs, sfs = [], []
+        for i in range(e):
+            q, sf = mxfp4_quantize(a[i].cuda())
+            qs.append(q)
+            sfs.append(sf)
+        return torch.stack(qs), torch.stack(sfs)
+
+    def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
+                              scale_tensor: torch.Tensor):
+        num_batches = mat_fp4.size(0)
+        scale_tensor = scale_tensor.view(num_batches, -1)
+        from flashinfer import mxfp4_dequantize
+        return torch.stack([
+            mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
+            for b in range(num_batches)
+        ])
+
+    w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
+    w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
+
+    # Reference result using dequantized tensors and reference_moe
+    w13_ref = dequant_mxfp4_batches(
+        w13_q.view(torch.uint8),
+        w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
+            num_experts, 2 * intermediate_size, hidden_size)
+    w2_ref = dequant_mxfp4_batches(
+        w2_q.view(torch.uint8),
+        w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
+            num_experts, hidden_size, intermediate_size)
+
+    # Quantize activations for SM100 path and dequantize for reference
+    hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
+    # Reference uses BF16 input but quantizes intermediate activation to MXFP8
+    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
+                        hidden_states.to(torch.float32), w13_ref,
+                        bias13.to(torch.float32), w2_ref,
+                        bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')
+
+    # Prepare inputs for FlashInfer CUTLASS fused MoE
+    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
+
+    # Swap halves to arrange as [w3; w1] (kernel expectation)
+    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
+    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
+
+    # Swap scales halves to match swapped weights
+    s1, s3 = torch.chunk(w13_scale, 2, dim=1)
+    w13_scale_swapped = torch.cat([s3, s1], dim=1)
+
+    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
+    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
+
+    # Build routing for kernel
+    routing_weights = torch.nn.functional.softmax(router_logits,
+                                                  dim=1,
+                                                  dtype=torch.float32)
+    token_final_scales, token_selected_experts = torch.topk(routing_weights,
+                                                            topk,
+                                                            dim=-1)
+    token_final_scales = (token_final_scales /
+                          token_final_scales.sum(dim=-1, keepdim=True))
+    token_selected_experts = token_selected_experts.to(torch.int).contiguous()
+
+    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
+    if alpha is not None:
+        alpha_t = torch.full((num_experts, ),
+                             alpha,
+                             device=hidden_states.device)
+    else:
+        alpha_t = None
+    if beta is not None:
+        beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
+    else:
+        beta_t = None
+    if limit is not None:
+        limit_t = torch.full((num_experts, ),
+                             limit,
+                             device=hidden_states.device)
+    else:
+        limit_t = None
+
+    # Quant scales for SM100 MXFP8+MXFP4 path
+    fake_input_scale = torch.ones(num_experts, device=device)
+    quant_scales = [
+        w13_scale_swapped.view(torch.int32),
+        fake_input_scale,
+        w2_scale.view(torch.int32),
+        fake_input_scale,
+    ]
+
+    _ = flashinfer_cutlass_fused_moe(
+        input=hidden_states_q,
+        token_selected_experts=token_selected_experts,
+        token_final_scales=token_final_scales,
+        fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
+        fc2_expert_weights=w2_q.contiguous().view(torch.long),
+        output_dtype=torch.bfloat16,
+        output=out,
+        quant_scales=quant_scales,
+        fc1_expert_biases=w13_b,
+        fc2_expert_biases=bias2.to(torch.bfloat16),
+        swiglu_alpha=alpha_t,
+        swiglu_beta=beta_t,
+        swiglu_limit=limit_t,
+        tp_size=1,
+        tp_rank=0,
+        ep_size=1,
+        ep_rank=0,
+        use_mxfp8_act_scaling=True,
+        input_sf=hidden_states_sf,
+    )
+
+    # Allow some mismatch due to MXFP4 quantization
+    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
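For illustration (not part of the commit), a minimal sketch of what the _interleave_scales_lastdim_by4 helper above does to a (num_experts, rows, cols) scale tensor: output row g collects the g-th group of four scale bytes from every input row.

import torch

# Toy input: 1 expert, 2 rows of scales, 8 scale bytes per row.
scales = torch.arange(16, dtype=torch.uint8).reshape(1, 2, 8)
s = scales.reshape(1, 2, 2, 4)   # split the last dim into groups of 4
s = s.permute(0, 2, 1, 3)        # move the group dim ahead of the row dim
out = s.reshape(1, 2, 8)         # merge rows with their 4-wide groups
print(out)
# tensor([[[ 0,  1,  2,  3,  8,  9, 10, 11],
#          [ 4,  5,  6,  7, 12, 13, 14, 15]]], dtype=torch.uint8)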

vllm/envs.py

Lines changed: 12 additions & 1 deletion
@@ -166,7 +166,8 @@
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
-    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
+    VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
+    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
@@ -1004,6 +1005,15 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
 
+    # If set to 1, use the FlashInfer CUTLASS backend for
+    # MXFP8 (activation) x MXFP4 (weight) MoE.
+    # This is separate from the TRTLLMGEN path controlled by
+    # VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8.
+    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS":
+    lambda: bool(int(
+        os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0")
+    )),
+
     # If set to 1, use the FlashInfer
     # BF16 (activation) x MXFP4 (weight) MoE backend.
     "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
@@ -1296,6 +1306,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_FP8",
         "VLLM_USE_FLASHINFER_MOE_FP4",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
+        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
         "VLLM_USE_CUDNN_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",

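For illustration (not part of the commit), one way to flip the new flag from Python, assuming the usual lazy lookup in vllm.envs; the flag defaults to "0" as the diff above shows.

import os

# Must be set before vllm reads it; "1" selects the FlashInfer CUTLASS
# MXFP8 (activation) x MXFP4 (weight) MoE path, "0" leaves it disabled.
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS"] = "1"

import vllm.envs as envs
assert envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS is True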
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 10 additions & 3 deletions
@@ -813,9 +813,16 @@ def __init__(
 
         # we are padding globally so EP buffer allocation works
         if quant_config and quant_config.get_name() == "mxfp4":
-            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
-                should_use_flashinfer_mxfp4)
-            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+            from vllm.model_executor.layers.quantization.mxfp4 import (
+                Mxfp4Backend, get_mxfp4_backend)
+            current_mxfp4_backend = get_mxfp4_backend()
+            if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
+                    or current_mxfp4_backend
+                    == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
+                hidden_size = round_up(hidden_size, 128)
+            elif (current_platform.is_rocm() or current_mxfp4_backend
+                  == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
+                  current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
                 hidden_size = round_up(hidden_size, 256)
 
         # For smuggling this layer into the fused moe custom op
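For intuition (not part of the commit), a small standalone sketch of the padding arithmetic in the branch above; round_up here is a local stand-in for vLLM's helper of the same name. The SM90 BF16 and SM100 MXFP8 CUTLASS backends pad hidden_size to a multiple of 128, while the ROCm and remaining SM100 paths pad to 256.

def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple (stand-in for vllm.utils.round_up).
    return ((x + multiple - 1) // multiple) * multiple

hidden_size = 2880  # arbitrary example value
print(round_up(hidden_size, 128))  # 2944 (SM90 BF16 / SM100 MXFP8 CUTLASS)
print(round_up(hidden_size, 256))  # 3072 (ROCm / TRTLLM-gen SM100 paths)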
