Commit 03ee481

Feature: Support Relu2 in FusedMoE fp8 cutlass path (#27261)

1 parent (5a87076) · commit 03ee481

File tree: 3 files changed, +42 additions, -20 deletions

tests/kernels/moe/test_flashinfer.py

Lines changed: 13 additions & 5 deletions

```diff
@@ -77,10 +77,14 @@ class TestData:
 
     @staticmethod
     def make_moe_tensors_8bit(
-        m: int, k: int, n: int, e: int, reorder: bool
+        m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
     ) -> "TestData":
+        is_gated = activation != "relu2_no_mul"
+
         hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
-        w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
+        w13 = torch.randn(
+            (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
+        )
         w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
 
         # Scale to fp8
@@ -190,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
 @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
 def test_flashinfer_cutlass_moe_fp8_no_graph(
     m: int,
     n: int,
     k: int,
     e: int,
     topk: int,
+    activation: str,
     monkeypatch,
 ):
     current_platform.seed_everything(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
-        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
+        td = TestData.make_moe_tensors_8bit(
+            m, k, n, e, reorder=False, activation=activation
+        )
 
         score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
@@ -233,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
         topk_weights=topk_weights,
         topk_ids=topk_ids,
         inplace=False,
-        activation="silu",
+        activation=activation,
         global_num_experts=e,
         expert_map=None,
         apply_router_weight_on_input=True,
@@ -253,7 +261,7 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
         td.layer,
         topk_weights,
         topk_ids,
-        activation="silu",
+        activation=activation,
         global_num_experts=e,
         expert_map=None,
         apply_router_weight_on_input=True,
```
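
For context on the `is_gated` branch above: with the gated "silu" path, w13 packs the gate and up projections (hence 2 * n output rows), while "relu2_no_mul" applies a squared ReLU with no element-wise multiply and only needs n rows. Below is a minimal reference sketch of the two modes; it is illustrative only, and `expert_mlp_reference` is not a vLLM API.

```python
import torch
import torch.nn.functional as F


def expert_mlp_reference(
    x: torch.Tensor,    # [tokens, k] hidden states routed to one expert
    w13: torch.Tensor,  # [2*n, k] for gated "silu", [n, k] for "relu2_no_mul"
    w2: torch.Tensor,   # [k, n] down projection
    activation: str = "silu",
) -> torch.Tensor:
    h = x @ w13.t()
    if activation == "silu":
        # Gated (SwiGLU-style): split the packed projection into gate/up
        # halves and multiply them -> w13 needs 2*n rows.
        gate, up = h.chunk(2, dim=-1)
        inter = F.silu(gate) * up
    elif activation == "relu2_no_mul":
        # Non-gated squared ReLU: no element-wise multiply -> n rows suffice.
        inter = torch.relu(h) ** 2
    else:
        raise ValueError(f"unsupported activation: {activation}")
    return inter @ w2.t()  # back to [tokens, k]
```

This is why the test only doubles the w13 row count when the activation is gated.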

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -148,8 +148,14 @@ def apply(
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
         apply_router_weight_on_input: bool | None,
     ):
-        assert activation == "silu", (
-            "Only activation silu is supported in FlashInferExperts"
+        from flashinfer.fused_moe.core import ActivationType
+
+        activation_str_to_value_map = {
+            "silu": ActivationType.Swiglu,  # This is the default
+            "relu2_no_mul": ActivationType.Relu2,
+        }
+        assert activation in activation_str_to_value_map, (
+            f"{activation=} missing from {activation_str_to_value_map.keys()=}"
         )
 
         # Select quantization metadata based on FP8 format/path
@@ -215,6 +221,7 @@ def apply(
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
             output=output,
+            activation_type=activation_str_to_value_map[activation],
             # Informs FlashInfer to use the block-scale decoding path when True
             use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
         )
```
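
For readers wiring this up elsewhere, here is a hedged sketch of the lookup the new code performs, factored into a standalone helper. `resolve_activation_type` is hypothetical, but the `ActivationType` import path and enum members match the diff above.

```python
def resolve_activation_type(activation: str):
    """Map vLLM's activation string to FlashInfer's ActivationType enum.

    Mirrors the mapping introduced in this commit and fails early on
    unsupported strings instead of erroring inside the kernel.
    """
    # Imported lazily so the caller stays importable without flashinfer.
    from flashinfer.fused_moe.core import ActivationType

    mapping = {
        "silu": ActivationType.Swiglu,          # gated SiLU (the default)
        "relu2_no_mul": ActivationType.Relu2,   # non-gated squared ReLU
    }
    try:
        return mapping[activation]
    except KeyError as exc:
        raise ValueError(
            f"Unsupported activation {activation!r}; expected one of {sorted(mapping)}"
        ) from exc
```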

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 20 additions & 13 deletions

```diff
@@ -354,12 +354,18 @@ def __init__(
 
         self.cutlass_fp8_supported = cutlass_fp8_supported()
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
-        if (
-            envs.VLLM_USE_FLASHINFER_MOE_FP8
-            and has_flashinfer_moe()
-            and self.moe.is_act_and_mul
-        ):
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            if (
+                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+                and not self.moe.is_act_and_mul
+            ):
+                logger.info_once(
+                    "Non-gated MoE is not supported for min-latency mode,"
+                    "falling back to high-throughput mode"
+                )
+                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+
             logger.info_once(
                 f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
             )
@@ -557,10 +563,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         )
 
         if self.flashinfer_moe_backend is not None:
-            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
-            register_moe_scaling_factors(layer)
+            if self.moe.is_act_and_mul:
+                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
             if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                 rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
+            register_moe_scaling_factors(layer)
 
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
@@ -570,13 +577,13 @@ def get_fused_moe_quant_config(
 
         return fp8_w8a8_moe_quant_config(
             w1_scale=layer.w13_weight_scale,
-            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
+            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
             w2_scale=layer.w2_weight_scale,
-            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
+            g2_alphas=layer.output2_scales_scalar.squeeze(),
             a1_scale=layer.w13_input_scale,
             a1_gscale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            a2_gscale=1.0 / layer.w2_input_scale,
+            a2_gscale=layer.w2_input_scale_inv,
             per_act_token_quant=False,
         )
@@ -642,9 +649,9 @@ def apply(
             )
 
         if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
-            assert not renormalize
-            assert activation == "silu", (
-                f"Expected 'silu' activation but got {activation}"
+            assert activation in ("silu", "relu2_no_mul"), (
+                "Expected activation to be in ('silu', 'relu2_no_mul'),"
+                f"but got {activation}"
             )
             return flashinfer_cutlass_moe_fp8(
                 x,
```
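
The modelopt change above amounts to a fallback rule: the TensorRT-LLM (min-latency) backend assumes a gated activation, so non-gated MoE (e.g. relu2_no_mul) drops back to the CUTLASS (high-throughput) backend. Here is a standalone sketch of that rule, with a simplified stand-in enum and a hypothetical `choose_flashinfer_backend` helper, assuming only the semantics shown in the diff.

```python
from __future__ import annotations

from enum import Enum


class FlashinferMoeBackend(Enum):
    """Simplified stand-in for vLLM's backend enum."""

    TENSORRT_LLM = "TensorRT-LLM"
    CUTLASS = "CUTLASS"


def choose_flashinfer_backend(
    use_flashinfer_fp8: bool,
    has_flashinfer: bool,
    preferred: FlashinferMoeBackend,
    is_act_and_mul: bool,
) -> FlashinferMoeBackend | None:
    """Illustrative version of the selection logic in this commit."""
    if not (use_flashinfer_fp8 and has_flashinfer):
        return None  # FlashInfer fp8 MoE kernels are not used at all
    backend = preferred
    # TRTLLM (min-latency) kernels assume a gated activation; non-gated
    # models fall back to the CUTLASS (high-throughput) path instead.
    if backend == FlashinferMoeBackend.TENSORRT_LLM and not is_act_and_mul:
        backend = FlashinferMoeBackend.CUTLASS
    return backend


# Example: a non-gated model that asked for TRTLLM falls back to CUTLASS.
assert (
    choose_flashinfer_backend(
        use_flashinfer_fp8=True,
        has_flashinfer=True,
        preferred=FlashinferMoeBackend.TENSORRT_LLM,
        is_act_and_mul=False,
    )
    == FlashinferMoeBackend.CUTLASS
)
```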
