From 4e10390a10ec43fab6ff718295a3f8e1143b167e Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 21 Aug 2025 12:37:28 -0700 Subject: [PATCH 01/10] feat: fused cutlass moe for mxfp4 on hopper Signed-off-by: Duncan Moss --- .../layers/quantization/mxfp4.py | 145 +++++++++++++++++- 1 file changed, 141 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 6a190ebbc063..678f0d494e47 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -39,10 +39,10 @@ def _should_use_flashinfer_mxfp4_bf16(): return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 # Enable by default on SM100 if MXFP8 is not explicitly enabled - if (current_platform.is_device_capability(100) and has_flashinfer() + if (current_platform.is_device_capability(100) or current_platform.is_device_capability(90) and has_flashinfer() and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): logger.info_once( - "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. " + "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell and Hopper. " "For faster performance, consider setting " "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " "though this may impact accuracy.") @@ -113,6 +113,7 @@ def __init__(self, moe: FusedMoEConfig): self.topk_indices_dtype = None self.moe = moe self.use_marlin = self._should_use_marlin() + self.flashinfer_autotune = True if current_platform.is_device_capability(100) and not has_flashinfer(): logger.warning_once( @@ -171,13 +172,16 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.hidden_size = hidden_size layer.intermediate_size_per_partition = \ intermediate_size_per_partition_after_pad - elif should_use_flashinfer_mxfp4(): + elif should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100): # pad the intermediate size to be a multiple of 2 * mxfp4_block # for to hold non-uniform sharded tensor as well as swizzling # other padding to increase performance intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 256) hidden_size = round_up(hidden_size, 256) + elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128) elif current_platform.is_rocm(): intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 128) @@ -384,6 +388,96 @@ def swap_every_two_rows(x, axis=-1): layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( self.num_experts, -1), requires_grad=False) + elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90): + assert layer.w13_weight.dtype == torch.uint8, f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, expected: {torch.uint8}" + assert layer.w2_weight.dtype == torch.uint8, f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, expected: {torch.uint8}" + assert layer.w13_weight_scale.dtype == torch.uint8, f"layer.w13_weight_scale.dtype: {layer.w13_weight_scale.dtype}, expected: {torch.uint8}" + assert layer.w2_weight_scale.dtype == torch.uint8, f"layer.w2_weight_scale.dtype: {layer.w2_weight_scale.dtype}, expected: {torch.uint8}" + assert layer.w13_bias.dtype == torch.bfloat16, f"layer.w13_bias.dtype: {layer.w13_bias.dtype}, expected: {torch.bfloat16}" + assert layer.w2_bias.dtype == torch.bfloat16, f"layer.w2_bias.dtype: {layer.w2_bias.dtype}, expected: 
{torch.bfloat16}" + + layer.gemm1_alpha = Parameter(torch.tensor( + [1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_beta = Parameter(torch.tensor( + [1.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_clamp_limit = Parameter(torch.tensor( + [7.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + sf_block_size = 32 # mxfp4 block size + + assert (layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2) + assert (layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] + == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] + == self.hidden_size // sf_block_size) + assert (layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size and + layer.w2_weight.shape[2] == self.intermediate_size // 2) + assert (layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size) + assert (layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2) + assert (layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size) + + + + # De-interleave weights, scales, and biases for gate and up projections + w13_weight_data = layer.w13_weight.data + gate_w, up_w = w13_weight_data[:, ::2, :], w13_weight_data[:, 1::2, :] + deinterleaved_w13_weight = torch.cat([gate_w, up_w], dim=1) + w1_weight, w3_weight = torch.chunk(deinterleaved_w13_weight, 2, dim=1) + layer.w13_weight = torch.nn.Parameter(torch.cat([w3_weight, w1_weight], dim=1).cuda(), requires_grad=False) + + w13_bias_data = layer.w13_bias.data.to(torch.float32) + gate_b, up_b = w13_bias_data[:, ::2], w13_bias_data[:, 1::2] + deinterleaved_w13_bias = torch.cat([gate_b, up_b], dim=1) + b1, b3 = torch.chunk(deinterleaved_w13_bias, 2, dim=-1) + b = torch.cat([b3, b1], dim=-1) + layer.w13_bias = torch.nn.Parameter(b.to(torch.bfloat16).cuda(), requires_grad=False) + + # Scale + w13_scale_data = layer.w13_weight_scale.data + gate_s, up_s = w13_scale_data[:, ::2, :], w13_scale_data[:, 1::2, :] + deinterleaved_w13_scale = torch.cat([gate_s, up_s], dim=1) + w1_weight_scale, w3_weight_scale = torch.chunk(deinterleaved_w13_scale, 2, dim=1) + all_w31_scales = torch.cat([w3_weight_scale, w1_weight_scale], dim=1) + + w31_scales = all_w31_scales.to(torch.uint8).view(torch.uint8) + w31_s_shape = w31_scales.shape + w31_scales_interleaved = w31_scales.reshape( + w31_s_shape[0], w31_s_shape[1], + (w31_s_shape[2] // 4), 4) + w31_scales_interleaved = w31_scales_interleaved.permute(0, 2, 1, 3) + w31_scales_interleaved = w31_scales_interleaved.reshape( + w31_s_shape[0], w31_s_shape[2] // 4, w31_s_shape[1] * 4) + + layer.w13_weight_scale = torch.nn.Parameter(w31_scales_interleaved.cuda(), requires_grad=False) + + w2_weight_scale = layer.w2_weight_scale.data + w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8) + w2_s_shape = w2_scales.shape + w2_scales_interleaved = w2_scales.reshape( + w2_s_shape[0], w2_s_shape[1], + (w2_s_shape[2] // 4), 4) + w2_scales_interleaved = w2_scales_interleaved.permute(0, 2, 1, 3) + 
w2_scales_interleaved = w2_scales_interleaved.reshape( + w2_s_shape[0], w2_s_shape[2] // 4, w2_s_shape[1] * 4) + + layer.w2_weight_scale = torch.nn.Parameter(w2_scales_interleaved.cuda(), requires_grad=False) + else: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -510,7 +604,7 @@ def apply( logical_replica_count), ( "MXFP4 are not supported with this configuration.") - if should_use_flashinfer_mxfp4(): + if should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100): from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe assert not self.moe.use_ep, ( "EP is not supported for flashinfer mxfp4 moe backend yet.") @@ -551,6 +645,49 @@ def apply( True, # do finalize )[0] return trtllm_gen_output + elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90): + + assert x.dtype == torch.bfloat16 + + quant_scales = [ + layer.w13_weight_scale, + layer.w2_weight_scale, + ] + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + output = torch.zeros_like(x) + + with torch.inference_mode(), autotune(self.flashinfer_autotune): + _ = cutlass_fused_moe( + input=x, + token_selected_experts=topk_ids, + token_final_scales=topk_weights, + fc1_expert_weights=layer.w13_weight, + fc2_expert_weights=layer.w2_weight, + output_dtype=torch.bfloat16, + quant_scales=quant_scales, + fc1_expert_biases=layer.w13_bias, + fc2_expert_biases=layer.w2_bias, + swiglu_alpha=layer.gemm1_alpha, + swiglu_beta=layer.gemm1_beta, + swiglu_limit=layer.gemm1_clamp_limit, + use_w4_group_scaling=True, + output=output, + ) + self.flashinfer_autotune = False + return output else: return triton_kernel_moe_forward( hidden_states=x, From 7d26ada1cb5a9bed9727b6d5d7e02019812d2ecf Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 21 Aug 2025 13:40:03 -0700 Subject: [PATCH 02/10] tp and ep fixes Signed-off-by: Duncan Moss --- vllm/model_executor/layers/quantization/mxfp4.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 678f0d494e47..4f5222a79617 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -266,7 +266,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, def process_weights_after_loading(self, layer): if self.use_marlin: prepare_moe_fp4_layer_for_marlin(layer) - elif should_use_flashinfer_mxfp4(): + elif should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100): from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a layer.gemm1_alpha = Parameter(torch.tensor( [1.702] * self.num_experts, dtype=torch.float32).cuda(), @@ -646,6 +646,8 @@ def apply( )[0] return trtllm_gen_output elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90): + from flashinfer import cutlass_fused_moe + from flashinfer.autotuner import autotune assert x.dtype == torch.bfloat16 @@ -683,6 +685,10 @@ def apply( swiglu_alpha=layer.gemm1_alpha, swiglu_beta=layer.gemm1_beta, swiglu_limit=layer.gemm1_clamp_limit, + tp_size=self.moe.tp_size, + tp_rank=self.moe.tp_rank, + ep_size=self.moe.ep_size, + 
ep_rank=self.moe.ep_rank, use_w4_group_scaling=True, output=output, ) From d57f6e7e0b61f8cef1ecabb96cf70516b7881d02 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 21 Aug 2025 18:52:58 -0700 Subject: [PATCH 03/10] pre-commit fixes Signed-off-by: Duncan Moss --- .../layers/quantization/mxfp4.py | 78 +++++++++++-------- 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 4f5222a79617..4fedb2abaa11 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -39,7 +39,8 @@ def _should_use_flashinfer_mxfp4_bf16(): return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 # Enable by default on SM100 if MXFP8 is not explicitly enabled - if (current_platform.is_device_capability(100) or current_platform.is_device_capability(90) and has_flashinfer() + if (current_platform.is_device_capability(100) + or current_platform.is_device_capability(90) and has_flashinfer() and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): logger.info_once( "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell and Hopper. " @@ -172,17 +173,17 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.hidden_size = hidden_size layer.intermediate_size_per_partition = \ intermediate_size_per_partition_after_pad - elif should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100): + elif should_use_flashinfer_mxfp4( + ) and current_platform.is_device_capability(100): # pad the intermediate size to be a multiple of 2 * mxfp4_block # for to hold non-uniform sharded tensor as well as swizzling # other padding to increase performance intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 256) hidden_size = round_up(hidden_size, 256) - elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90): - intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 128) - elif current_platform.is_rocm(): + elif _should_use_flashinfer_mxfp4_bf16( + ) and current_platform.is_device_capability( + 90) or current_platform.is_rocm(): intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 128) else: @@ -266,7 +267,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, def process_weights_after_loading(self, layer): if self.use_marlin: prepare_moe_fp4_layer_for_marlin(layer) - elif should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100): + elif should_use_flashinfer_mxfp4( + ) and current_platform.is_device_capability(100): from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a layer.gemm1_alpha = Parameter(torch.tensor( [1.702] * self.num_experts, dtype=torch.float32).cuda(), @@ -388,14 +390,15 @@ def swap_every_two_rows(x, axis=-1): layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( self.num_experts, -1), requires_grad=False) - elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90): + elif _should_use_flashinfer_mxfp4_bf16( + ) and current_platform.is_device_capability(90): assert layer.w13_weight.dtype == torch.uint8, f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, expected: {torch.uint8}" assert layer.w2_weight.dtype == torch.uint8, f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, expected: {torch.uint8}" assert layer.w13_weight_scale.dtype == torch.uint8, f"layer.w13_weight_scale.dtype: 
{layer.w13_weight_scale.dtype}, expected: {torch.uint8}" assert layer.w2_weight_scale.dtype == torch.uint8, f"layer.w2_weight_scale.dtype: {layer.w2_weight_scale.dtype}, expected: {torch.uint8}" assert layer.w13_bias.dtype == torch.bfloat16, f"layer.w13_bias.dtype: {layer.w13_bias.dtype}, expected: {torch.bfloat16}" assert layer.w2_bias.dtype == torch.bfloat16, f"layer.w2_bias.dtype: {layer.w2_bias.dtype}, expected: {torch.bfloat16}" - + layer.gemm1_alpha = Parameter(torch.tensor( [1.702] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False) @@ -432,51 +435,59 @@ def swap_every_two_rows(x, axis=-1): and layer.w2_bias.shape[0] == self.num_experts and layer.w2_bias.shape[1] == self.hidden_size) - - # De-interleave weights, scales, and biases for gate and up projections w13_weight_data = layer.w13_weight.data - gate_w, up_w = w13_weight_data[:, ::2, :], w13_weight_data[:, 1::2, :] + gate_w, up_w = w13_weight_data[:, ::2, :], w13_weight_data[:, + 1::2, :] deinterleaved_w13_weight = torch.cat([gate_w, up_w], dim=1) - w1_weight, w3_weight = torch.chunk(deinterleaved_w13_weight, 2, dim=1) - layer.w13_weight = torch.nn.Parameter(torch.cat([w3_weight, w1_weight], dim=1).cuda(), requires_grad=False) + w1_weight, w3_weight = torch.chunk(deinterleaved_w13_weight, + 2, + dim=1) + layer.w13_weight = torch.nn.Parameter(torch.cat( + [w3_weight, w1_weight], dim=1).cuda(), + requires_grad=False) w13_bias_data = layer.w13_bias.data.to(torch.float32) gate_b, up_b = w13_bias_data[:, ::2], w13_bias_data[:, 1::2] deinterleaved_w13_bias = torch.cat([gate_b, up_b], dim=1) b1, b3 = torch.chunk(deinterleaved_w13_bias, 2, dim=-1) b = torch.cat([b3, b1], dim=-1) - layer.w13_bias = torch.nn.Parameter(b.to(torch.bfloat16).cuda(), requires_grad=False) + layer.w13_bias = torch.nn.Parameter(b.to(torch.bfloat16).cuda(), + requires_grad=False) # Scale w13_scale_data = layer.w13_weight_scale.data - gate_s, up_s = w13_scale_data[:, ::2, :], w13_scale_data[:, 1::2, :] + gate_s, up_s = w13_scale_data[:, ::2, :], w13_scale_data[:, + 1::2, :] deinterleaved_w13_scale = torch.cat([gate_s, up_s], dim=1) - w1_weight_scale, w3_weight_scale = torch.chunk(deinterleaved_w13_scale, 2, dim=1) - all_w31_scales = torch.cat([w3_weight_scale, w1_weight_scale], dim=1) + w1_weight_scale, w3_weight_scale = torch.chunk( + deinterleaved_w13_scale, 2, dim=1) + all_w31_scales = torch.cat([w3_weight_scale, w1_weight_scale], + dim=1) w31_scales = all_w31_scales.to(torch.uint8).view(torch.uint8) w31_s_shape = w31_scales.shape w31_scales_interleaved = w31_scales.reshape( - w31_s_shape[0], w31_s_shape[1], - (w31_s_shape[2] // 4), 4) + w31_s_shape[0], w31_s_shape[1], (w31_s_shape[2] // 4), 4) w31_scales_interleaved = w31_scales_interleaved.permute(0, 2, 1, 3) w31_scales_interleaved = w31_scales_interleaved.reshape( w31_s_shape[0], w31_s_shape[2] // 4, w31_s_shape[1] * 4) - layer.w13_weight_scale = torch.nn.Parameter(w31_scales_interleaved.cuda(), requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w31_scales_interleaved.cuda(), requires_grad=False) w2_weight_scale = layer.w2_weight_scale.data w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8) w2_s_shape = w2_scales.shape - w2_scales_interleaved = w2_scales.reshape( - w2_s_shape[0], w2_s_shape[1], - (w2_s_shape[2] // 4), 4) + w2_scales_interleaved = w2_scales.reshape(w2_s_shape[0], + w2_s_shape[1], + (w2_s_shape[2] // 4), 4) w2_scales_interleaved = w2_scales_interleaved.permute(0, 2, 1, 3) w2_scales_interleaved = w2_scales_interleaved.reshape( w2_s_shape[0], 
w2_s_shape[2] // 4, w2_s_shape[1] * 4) - layer.w2_weight_scale = torch.nn.Parameter(w2_scales_interleaved.cuda(), requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_scales_interleaved.cuda(), requires_grad=False) else: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -604,7 +615,8 @@ def apply( logical_replica_count), ( "MXFP4 are not supported with this configuration.") - if should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100): + if should_use_flashinfer_mxfp4( + ) and current_platform.is_device_capability(100): from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe assert not self.moe.use_ep, ( "EP is not supported for flashinfer mxfp4 moe backend yet.") @@ -645,9 +657,10 @@ def apply( True, # do finalize )[0] return trtllm_gen_output - elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90): - from flashinfer import cutlass_fused_moe - from flashinfer.autotuner import autotune + elif _should_use_flashinfer_mxfp4_bf16( + ) and current_platform.is_device_capability(90): + from vllm.utils.flashinfer import (autotune, + flashinfer_cutlass_fused_moe) assert x.dtype == torch.bfloat16 @@ -669,10 +682,8 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) - output = torch.zeros_like(x) - with torch.inference_mode(), autotune(self.flashinfer_autotune): - _ = cutlass_fused_moe( + output = flashinfer_cutlass_fused_moe( input=x, token_selected_experts=topk_ids, token_final_scales=topk_weights, @@ -690,8 +701,7 @@ def apply( ep_size=self.moe.ep_size, ep_rank=self.moe.ep_rank, use_w4_group_scaling=True, - output=output, - ) + )[0] self.flashinfer_autotune = False return output else: From 0fea8cc7395d3af1b2a1d3f8416b558b176eb32e Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 21 Aug 2025 19:24:46 -0700 Subject: [PATCH 04/10] more precomiit Signed-off-by: Duncan Moss --- .../layers/quantization/mxfp4.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 4fedb2abaa11..6cd91cf41ddd 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -43,8 +43,9 @@ def _should_use_flashinfer_mxfp4_bf16(): or current_platform.is_device_capability(90) and has_flashinfer() and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): logger.info_once( - "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell and Hopper. " - "For faster performance, consider setting " + "Enabling FlashInfer MXFP4 BF16 backend by " + "default for Blackwell and Hopper. 
" + "For faster performance on Blackwell, consider setting " "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " "though this may impact accuracy.") return True @@ -392,12 +393,24 @@ def swap_every_two_rows(x, axis=-1): requires_grad=False) elif _should_use_flashinfer_mxfp4_bf16( ) and current_platform.is_device_capability(90): - assert layer.w13_weight.dtype == torch.uint8, f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, expected: {torch.uint8}" - assert layer.w2_weight.dtype == torch.uint8, f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, expected: {torch.uint8}" - assert layer.w13_weight_scale.dtype == torch.uint8, f"layer.w13_weight_scale.dtype: {layer.w13_weight_scale.dtype}, expected: {torch.uint8}" - assert layer.w2_weight_scale.dtype == torch.uint8, f"layer.w2_weight_scale.dtype: {layer.w2_weight_scale.dtype}, expected: {torch.uint8}" - assert layer.w13_bias.dtype == torch.bfloat16, f"layer.w13_bias.dtype: {layer.w13_bias.dtype}, expected: {torch.bfloat16}" - assert layer.w2_bias.dtype == torch.bfloat16, f"layer.w2_bias.dtype: {layer.w2_bias.dtype}, expected: {torch.bfloat16}" + assert layer.w13_weight.dtype == torch.uint8, ( + f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, " + f"expected: {torch.uint8}") + assert layer.w2_weight.dtype == torch.uint8, ( + f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, " + f"expected: {torch.uint8}") + assert layer.w13_weight_scale.dtype == torch.uint8, ( + f"layer.w13_weight_scale.dtype: {layer.w13_weight_scale.dtype}, " # noqa: E501 + f"expected: {torch.uint8}") + assert layer.w2_weight_scale.dtype == torch.uint8, ( + f"layer.w2_weight_scale.dtype: {layer.w2_weight_scale.dtype}, " # noqa: E501 + f"expected: {torch.uint8}") + assert layer.w13_bias.dtype == torch.bfloat16, ( + f"layer.w13_bias.dtype: {layer.w13_bias.dtype}, " + f"expected: {torch.bfloat16}") + assert layer.w2_bias.dtype == torch.bfloat16, ( + f"layer.w2_bias.dtype: {layer.w2_bias.dtype}, " + f"expected: {torch.bfloat16}") layer.gemm1_alpha = Parameter(torch.tensor( [1.702] * self.num_experts, dtype=torch.float32).cuda(), @@ -435,7 +448,7 @@ def swap_every_two_rows(x, axis=-1): and layer.w2_bias.shape[0] == self.num_experts and layer.w2_bias.shape[1] == self.hidden_size) - # De-interleave weights, scales, and biases for gate and up projections + # De-interleave weights, scales, and biases w13_weight_data = layer.w13_weight.data gate_w, up_w = w13_weight_data[:, ::2, :], w13_weight_data[:, 1::2, :] From 2834eb1b062eb029f76c04fd1178075b5e25a604 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Sun, 24 Aug 2025 09:23:57 -0700 Subject: [PATCH 05/10] final updates Signed-off-by: Duncan Moss --- vllm/model_executor/layers/fused_moe/layer.py | 6 +++-- .../layers/quantization/mxfp4.py | 27 ++++++++++--------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index fcc6987d26bb..b2c9097e39cb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -791,9 +791,11 @@ def __init__( # we padding globally so EP buffer allocation works if quant_config and quant_config.get_name() == "mxfp4": from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501 - should_use_flashinfer_mxfp4) - if current_platform.is_rocm() or should_use_flashinfer_mxfp4(): + should_use_flashinfer_mxfp4, should_use_flashinfer_mxfp4_bf16) + if current_platform.is_rocm() or (should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100)): 
hidden_size = round_up(hidden_size, 256) + elif should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90): + hidden_size = round_up(hidden_size, 128) # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 6cd91cf41ddd..fc8a991299e6 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -32,7 +32,7 @@ logger = init_logger(__name__) -def _should_use_flashinfer_mxfp4_bf16(): +def should_use_flashinfer_mxfp4_bf16(): """Determine if FlashInfer MXFP4 BF16 should be used.""" # If explicitly set, respect the setting if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): @@ -60,7 +60,7 @@ def _should_use_flashinfer_mxfp4_mxfp8(): def should_use_flashinfer_mxfp4(): return (_should_use_flashinfer_mxfp4_mxfp8() - or _should_use_flashinfer_mxfp4_bf16()) + or should_use_flashinfer_mxfp4_bf16()) class Mxfp4Config(QuantizationConfig): @@ -182,11 +182,12 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 256) hidden_size = round_up(hidden_size, 256) - elif _should_use_flashinfer_mxfp4_bf16( + elif should_use_flashinfer_mxfp4_bf16( ) and current_platform.is_device_capability( 90) or current_platform.is_rocm(): intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 128) + hidden_size = round_up(hidden_size, 128) else: intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 64) @@ -391,7 +392,7 @@ def swap_every_two_rows(x, axis=-1): layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( self.num_experts, -1), requires_grad=False) - elif _should_use_flashinfer_mxfp4_bf16( + elif should_use_flashinfer_mxfp4_bf16( ) and current_platform.is_device_capability(90): assert layer.w13_weight.dtype == torch.uint8, ( f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, " @@ -501,7 +502,6 @@ def swap_every_two_rows(x, axis=-1): layer.w2_weight_scale = torch.nn.Parameter( w2_scales_interleaved.cuda(), requires_grad=False) - else: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -633,7 +633,7 @@ def apply( from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe assert not self.moe.use_ep, ( "EP is not supported for flashinfer mxfp4 moe backend yet.") - if _should_use_flashinfer_mxfp4_bf16(): + if should_use_flashinfer_mxfp4_bf16(): assert x.dtype == torch.bfloat16 x_quant = x x_scale = None @@ -670,7 +670,7 @@ def apply( True, # do finalize )[0] return trtllm_gen_output - elif _should_use_flashinfer_mxfp4_bf16( + elif should_use_flashinfer_mxfp4_bf16( ) and current_platform.is_device_capability(90): from vllm.utils.flashinfer import (autotune, flashinfer_cutlass_fused_moe) @@ -695,14 +695,16 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) - with torch.inference_mode(), autotune(self.flashinfer_autotune): - output = flashinfer_cutlass_fused_moe( + output = torch.empty_like(x, dtype=torch.bfloat16) + with autotune(self.flashinfer_autotune): + _ = flashinfer_cutlass_fused_moe( input=x, - token_selected_experts=topk_ids, + token_selected_experts=topk_ids.to(torch.int).contiguous(), token_final_scales=topk_weights, fc1_expert_weights=layer.w13_weight, fc2_expert_weights=layer.w2_weight, output_dtype=torch.bfloat16, + output=output, 
quant_scales=quant_scales, fc1_expert_biases=layer.w13_bias, fc2_expert_biases=layer.w2_bias, @@ -714,8 +716,9 @@ def apply( ep_size=self.moe.ep_size, ep_rank=self.moe.ep_rank, use_w4_group_scaling=True, - )[0] - self.flashinfer_autotune = False + ) + + self.flashinfer_autotune = False return output else: return triton_kernel_moe_forward( From 927c179bb577736ecaff3ee26ad5bc8458abb208 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Sun, 24 Aug 2025 09:39:42 -0700 Subject: [PATCH 06/10] pre-commit fixes Signed-off-by: Duncan Moss --- vllm/model_executor/layers/fused_moe/layer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b2c9097e39cb..aef756fb802b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -792,9 +792,12 @@ def __init__( if quant_config and quant_config.get_name() == "mxfp4": from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501 should_use_flashinfer_mxfp4, should_use_flashinfer_mxfp4_bf16) - if current_platform.is_rocm() or (should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100)): + if current_platform.is_rocm() or ( + should_use_flashinfer_mxfp4() + and current_platform.is_device_capability(100)): hidden_size = round_up(hidden_size, 256) - elif should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90): + elif should_use_flashinfer_mxfp4_bf16( + ) and current_platform.is_device_capability(90): hidden_size = round_up(hidden_size, 128) # For smuggling this layer into the fused moe custom op From 8588b56490d1a25a8a26fbea38aca8fd5f0d3954 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Mon, 25 Aug 2025 11:04:02 -0700 Subject: [PATCH 07/10] add unit test Signed-off-by: Duncan Moss --- tests/kernels/moe/test_mxfp4_moe.py | 136 ++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 7bd1ffce58e9..3d4cffc102fd 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -11,6 +11,7 @@ from packaging import version from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( @@ -19,6 +20,10 @@ TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( ) and current_platform.is_device_capability(100) +HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda() + and current_platform.is_device_capability(90) + and has_flashinfer()) + if TRTLLM_GEN_MXFP4_AVAILABLE: from flashinfer import (fp4_quantize, mxfp8_quantize, next_positive_power_of_2, @@ -473,3 +478,134 @@ def test_trtllm_gen_mxfp4_fused_moe( limit=limit) # relatively loose check since the mxfp4 quantization is less accurate check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) + + +def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor: + """Interleave scales on the last dimension by groups of 4, matching + the transformation in mxfp4.py's BF16 (Hopper) path.""" + s = scales.to(torch.uint8) + s_shape = s.shape + assert s_shape[-1] % 4 == 0 + s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4) + # Move the 4-group dimension before the row dimension + permuted = s.permute(0, 2, 1, 3) + # Merge the row dim with the 4-group dim + return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4) + + + + + 
+@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("num_tokens", [1, 128]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), + (1.702, 1.0, 7.0)]) +@pytest.mark.skipif( + not HOPPER_MXFP4_BF16_AVAILABLE, + reason="nvidia gpu sm90 and flashinfer are required for this test", +) +def test_flashinfer_cutlass_mxfp4_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: float, + beta: float, + limit: Optional[float], +): + torch.manual_seed(42) + device = "cuda:0" + + # Inputs + hidden_states = torch.randn(num_tokens, hidden_size, + device=device, + dtype=torch.bfloat16) + # Random MXFP4 weights and scales (uint8), contiguous [w1; w3] + w13_q = torch.randint(0, 256, (num_experts, 2 * intermediate_size, hidden_size // 2), + device=device, dtype=torch.uint8) + w13_scale = torch.randint(118, 123, (num_experts, 2 * intermediate_size, hidden_size // 32), + device=device, dtype=torch.uint8) + + w2_q = torch.randint(0, 256, (num_experts, hidden_size, intermediate_size // 2), + device=device, dtype=torch.uint8) + w2_scale = torch.randint(118, 123, (num_experts, hidden_size, intermediate_size // 32), + device=device, dtype=torch.uint8) + # Bias contiguous [b1; b3] + bias13 = (torch.randn(num_experts, 2 * intermediate_size, device=device, + dtype=torch.bfloat16) * 10) + bias2 = (torch.randn(num_experts, hidden_size, device=device, + dtype=torch.bfloat16) * 10) + router_logits = torch.rand(num_tokens, num_experts, + dtype=torch.float32, + device=device) + + # Reference: dequantize MXFP4 weights back to fp32 using the same packed layout + w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape( + num_experts, 2 * intermediate_size, hidden_size) + w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape( + num_experts, hidden_size, intermediate_size) + ref = reference_moe(router_logits.to(torch.float32), topk, num_experts, + hidden_states.to(torch.float32), w13_ref, + bias13.to(torch.float32), w2_ref, + bias2.to(torch.float32), alpha, beta, limit, 'bf16') + + # FlashInfer (BF16 Hopper) path (inline) + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(w13_q, 2, dim=1) + w13_q_swapped = torch.cat([w3_w, w1_w], dim=1) + + b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1) + w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1) + w13_s = torch.cat([w3_s, w1_s], dim=1) + w13_s_inter = _interleave_scales_lastdim_by4(w13_s) + w2_s_inter = _interleave_scales_lastdim_by4(w2_scale) + + routing_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + token_final_scales, token_selected_experts = torch.topk(routing_weights, + topk, + dim=-1) + token_final_scales = (token_final_scales / + token_final_scales.sum(dim=-1, keepdim=True)) + token_selected_experts = token_selected_experts.to(torch.int).contiguous() + + out = torch.empty_like(hidden_states, dtype=torch.bfloat16) + if alpha is not None: + alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) + if beta is not None: + beta = torch.full((num_experts, ), beta, device=hidden_states.device) + if limit is not None: + limit = torch.full((num_experts, ), limit, device=hidden_states.device) + + _ = flashinfer_cutlass_fused_moe( + 
input=hidden_states, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + fc1_expert_weights=w13_q_swapped, + fc2_expert_weights=w2_q, + output_dtype=torch.bfloat16, + output=out, + quant_scales=[w13_s_inter.to(torch.uint8), + w2_s_inter.to(torch.uint8)], + fc1_expert_biases=w13_b, + fc2_expert_biases=bias2.to(torch.bfloat16), + swiglu_alpha=alpha, + swiglu_beta=beta, + swiglu_limit=limit, + tp_size=1, + tp_rank=0, + ep_size=1, + ep_rank=0, + use_w4_group_scaling=True, + ) + + # Allow some mismatch due to MXFP4 quantization + check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8) From aa9facb4ff01e70109007f125ea2cadbd1bf60d5 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Mon, 25 Aug 2025 13:37:12 -0700 Subject: [PATCH 08/10] lint Signed-off-by: Duncan Moss --- tests/kernels/moe/test_mxfp4_moe.py | 49 ++++++++++++++++++----------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 3d4cffc102fd..a69055cb34db 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -493,9 +493,6 @@ def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor: return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4) - - - @pytest.mark.parametrize("topk", [1, 4]) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("num_tokens", [1, 128]) @@ -520,29 +517,44 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( device = "cuda:0" # Inputs - hidden_states = torch.randn(num_tokens, hidden_size, + hidden_states = torch.randn(num_tokens, + hidden_size, device=device, dtype=torch.bfloat16) # Random MXFP4 weights and scales (uint8), contiguous [w1; w3] - w13_q = torch.randint(0, 256, (num_experts, 2 * intermediate_size, hidden_size // 2), - device=device, dtype=torch.uint8) - w13_scale = torch.randint(118, 123, (num_experts, 2 * intermediate_size, hidden_size // 32), - device=device, dtype=torch.uint8) - - w2_q = torch.randint(0, 256, (num_experts, hidden_size, intermediate_size // 2), - device=device, dtype=torch.uint8) - w2_scale = torch.randint(118, 123, (num_experts, hidden_size, intermediate_size // 32), - device=device, dtype=torch.uint8) + w13_q = torch.randint( + 0, + 256, (num_experts, 2 * intermediate_size, hidden_size // 2), + device=device, + dtype=torch.uint8) + w13_scale = torch.randint( + 118, + 123, (num_experts, 2 * intermediate_size, hidden_size // 32), + device=device, + dtype=torch.uint8) + + w2_q = torch.randint(0, + 256, + (num_experts, hidden_size, intermediate_size // 2), + device=device, + dtype=torch.uint8) + w2_scale = torch.randint( + 118, + 123, (num_experts, hidden_size, intermediate_size // 32), + device=device, + dtype=torch.uint8) # Bias contiguous [b1; b3] - bias13 = (torch.randn(num_experts, 2 * intermediate_size, device=device, + bias13 = (torch.randn(num_experts, + 2 * intermediate_size, + device=device, dtype=torch.bfloat16) * 10) - bias2 = (torch.randn(num_experts, hidden_size, device=device, - dtype=torch.bfloat16) * 10) - router_logits = torch.rand(num_tokens, num_experts, + bias2 = (torch.randn( + num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10) + router_logits = torch.rand(num_tokens, + num_experts, dtype=torch.float32, device=device) - # Reference: dequantize MXFP4 weights back to fp32 using the same packed layout w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape( num_experts, 2 * intermediate_size, hidden_size) w2_ref = 
mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape( @@ -552,7 +564,6 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( bias13.to(torch.float32), w2_ref, bias2.to(torch.float32), alpha, beta, limit, 'bf16') - # FlashInfer (BF16 Hopper) path (inline) from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe # Swap halves to arrange as [w3; w1] (kernel expectation) From 239647d5e8260ed79eb93a544b9f265cac1fa1e9 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Mon, 25 Aug 2025 13:42:52 -0700 Subject: [PATCH 09/10] bump flashinfer version Signed-off-by: Duncan Moss --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ca6e0a8592cc..ffe8ec4e79af 100644 --- a/setup.py +++ b/setup.py @@ -694,7 +694,7 @@ def _read_requirements(filename: str) -> list[str]: "mistral_common[audio]"], # Required for audio processing "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile - "flashinfer": ["flashinfer-python==0.2.12"], + "flashinfer": ["flashinfer-python==0.2.14.post1"], # Optional deps for AMD FP4 quantization support "petit-kernel": ["petit-kernel"], }, From 0d4bc69c641dc6fd9c1f7f02806cf3f6b7f2cc35 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Mon, 25 Aug 2025 14:34:53 -0700 Subject: [PATCH 10/10] udpated docker flashinfer ref Signed-off-by: Duncan Moss --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 839ac501dbaf..2e272cbca841 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -373,7 +373,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer from source ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" # Keep this in sync with "flashinfer" extra in setup.py -ARG FLASHINFER_GIT_REF="v0.2.12" +ARG FLASHINFER_GIT_REF="v0.2.14.post1" # Flag to control whether to compile FlashInfer AOT kernels # Set to "true" to enable AOT compilation: # docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
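
For reference, a minimal standalone sketch (not part of the patches above) of
the weight/scale re-layout that the new SM90 path performs in
Mxfp4MoEMethod.process_weights_after_loading(): de-interleave the gate/up rows
of w13, reorder the halves to [w3; w1], and swizzle the per-block scales in
groups of 4 along the last dimension. It uses random uint8 data and
illustrative names; the helper mirrors _interleave_scales_lastdim_by4 from
tests/kernels/moe/test_mxfp4_moe.py.

import torch

num_experts, intermediate_size, hidden_size = 4, 256, 256
sf_block = 32  # mxfp4 scale block size

# Packed MXFP4 weights (two fp4 values per uint8) and per-block uint8 scales,
# with gate/up rows interleaved as they come from the checkpoint.
w13 = torch.randint(0, 256,
                    (num_experts, 2 * intermediate_size, hidden_size // 2),
                    dtype=torch.uint8)
w13_scale = torch.randint(118, 123,
                          (num_experts, 2 * intermediate_size,
                           hidden_size // sf_block),
                          dtype=torch.uint8)

# 1) De-interleave gate/up rows, then reorder the halves to [w3; w1],
#    the layout the cutlass fused-MoE kernel expects.
gate_w, up_w = w13[:, ::2, :], w13[:, 1::2, :]
w1, w3 = torch.chunk(torch.cat([gate_w, up_w], dim=1), 2, dim=1)
w31 = torch.cat([w3, w1], dim=1)

# 2) Interleave the scales along the last dim in groups of 4
#    (reshape -> permute -> reshape), as in the patched Hopper path.
def interleave_scales_lastdim_by4(s: torch.Tensor) -> torch.Tensor:
    e, rows, cols = s.shape
    s = s.reshape(e, rows, cols // 4, 4).permute(0, 2, 1, 3)
    return s.reshape(e, cols // 4, rows * 4)

gate_s, up_s = w13_scale[:, ::2, :], w13_scale[:, 1::2, :]
w1_s, w3_s = torch.chunk(torch.cat([gate_s, up_s], dim=1), 2, dim=1)
w31_scale = interleave_scales_lastdim_by4(torch.cat([w3_s, w1_s], dim=1))

print(w31.shape)        # torch.Size([4, 512, 128])
print(w31_scale.shape)  # torch.Size([4, 2, 2048]) == (E, cols // 4, rows * 4)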