
Commit 6108946: [Model] Add MoE support for NemotronH (#25863)
Signed-off-by: Tomer Asida <[email protected]>
1 parent: 88afa11

File tree: 7 files changed, +413 -39 lines

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 2 additions & 0 deletions

@@ -823,6 +823,8 @@ class FusedMoEConfig:
 
     has_bias: bool = False
 
+    is_act_and_mul: bool = True
+
     def __post_init__(self):
        if self.dp_size > 1:
            logger.debug_once(

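The new is_act_and_mul flag records whether each expert MLP is gated (the default SwiGLU/GeGLU style, where w13 fuses gate_proj and up_proj and the activated gate multiplies the up projection) or non-gated as in NemotronH, where w13 is only an up projection and the activation is applied directly. A minimal PyTorch sketch of the two forward paths, with illustrative function names that are not part of vLLM:

import torch
import torch.nn.functional as F

def expert_mlp_gated(x, w13, w2):
    # is_act_and_mul=True: w13 stacks gate_proj and up_proj, so its
    # leading output dim is 2 * intermediate_size and is split here.
    gate, up = (x @ w13.t()).chunk(2, dim=-1)
    return (F.silu(gate) * up) @ w2.t()

def expert_mlp_non_gated(x, w13, w2, act=F.relu):
    # is_act_and_mul=False: w13 is a single up projection of size
    # intermediate_size; the activation is applied with no gate multiply.
    return act(x @ w13.t()) @ w2.t()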
vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 1 deletion

@@ -1647,6 +1647,7 @@ def fused_experts(
 
 SILU_NO_MUL: str = activation_without_mul("silu")
 GELU_NO_MUL: str = activation_without_mul("gelu")
+RELU2_NO_MUL: str = activation_without_mul("relu2")
 
 
 def _get_config_quant_dtype(

@@ -1914,7 +1915,8 @@ def fused_experts_impl(
            intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
        elif activation == GELU_NO_MUL:
            intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
-
+        elif activation == RELU2_NO_MUL:
+            intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N)))
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}.")

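Here relu2 is squared ReLU, the activation NemotronH's MoE experts use; the new branch builds it from stock torch ops rather than a fused kernel. A standalone equivalence check (not vLLM code):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)

# Squared ReLU ("relu2"): zero out negatives, then square. This is what
# the RELU2_NO_MUL branch computes on intermediate_cache1.
out = torch.square(F.relu(x))
assert torch.allclose(out, F.relu(x) ** 2)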
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 30 additions & 5 deletions

@@ -411,11 +411,15 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        if self.moe.is_act_and_mul:
+            w13_up_dim = 2 * intermediate_size_per_partition
+        else:
+            w13_up_dim = intermediate_size_per_partition
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                w13_up_dim,
                 hidden_size,
                 dtype=params_dtype,
             ),

@@ -425,9 +429,7 @@
         set_weight_attrs(w13_weight, extra_weight_attrs)
         if self.moe.has_bias:
             w13_bias = torch.nn.Parameter(
-                torch.zeros(
-                    num_experts, 2 * intermediate_size_per_partition, dtype=params_dtype
-                ),
+                torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
                 requires_grad=False,
             )
             layer.register_parameter("w13_bias", w13_bias)

@@ -1073,6 +1075,7 @@ def __init__(
         e_score_correction_bias: torch.Tensor | None = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         enable_eplb: bool = False,
         num_redundant_experts: int = 0,
         has_bias: bool = False,

@@ -1263,6 +1266,7 @@ def __init__(
             in_dtype=moe_in_dtype,
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
+            is_act_and_mul=is_act_and_mul,
         )
         self.moe_config = moe
         self.moe_quant_config: FusedMoEQuantConfig | None = None

@@ -1283,6 +1287,24 @@ def __init__(
         assert isinstance(quant_method, FusedMoEMethodBase)
         self.quant_method = quant_method
 
+        if not self.moe_config.is_act_and_mul:
+            # Avoid circular import
+            from vllm.model_executor.layers.quantization.modelopt import (
+                ModelOptFp8MoEMethod,
+            )
+
+            if not isinstance(
+                quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
+            ):
+                raise NotImplementedError(
+                    "is_act_and_mul=False is supported only for unquantized "
+                    "and ModelOpt FP8 moe for now"
+                )
+            if not current_platform.is_cuda():
+                raise NotImplementedError(
+                    "is_act_and_mul=False is supported only for CUDA for now"
+                )
+
         if self.enable_eplb:
             from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod

@@ -1531,7 +1553,10 @@ def _load_w13(
     ):
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-        shard_size = expert_data.shape[shard_dim] // 2
+        if self.moe_config.is_act_and_mul:
+            shard_size = expert_data.shape[shard_dim] // 2
+        else:
+            shard_size = expert_data.shape[shard_dim]
         if not load_full:
             loaded_weight = loaded_weight.narrow(
                 shard_dim, shard_size * tp_rank, shard_size

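The _load_w13 change affects tensor-parallel weight loading: for gated experts, the local w13 parameter holds gate and up halves and each checkpoint tensor (w1 or w3) fills one half, so the shard size is half the local dim; for non-gated experts, the whole local parameter is the up projection. A standalone sketch of that slicing, with a simplified helper name and signature of my own:

import torch

def load_w13_shard(expert_param, loaded, tp_rank, is_act_and_mul, shard_dim=0):
    # expert_param: this rank's slice of w13 for one expert.
    if is_act_and_mul:
        # Gated: half of the local rows correspond to the tensor being loaded.
        shard_size = expert_param.shape[shard_dim] // 2
    else:
        # Non-gated: the whole local parameter is the up projection.
        shard_size = expert_param.shape[shard_dim]
    return loaded.narrow(shard_dim, shard_size * tp_rank, shard_size)

# Non-gated example: intermediate 4096, hidden 1024, tp_size 4, rank 1.
expert_param = torch.empty(4096 // 4, 1024)
checkpoint_up = torch.randn(4096, 1024)
shard = load_w13_shard(expert_param, checkpoint_up, tp_rank=1, is_act_and_mul=False)
assert shard.shape == (1024, 1024)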
vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 26 additions & 5 deletions

@@ -354,7 +354,11 @@ def __init__(
 
         self.cutlass_fp8_supported = cutlass_fp8_supported()
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
-        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+        if (
+            envs.VLLM_USE_FLASHINFER_MOE_FP8
+            and has_flashinfer_moe()
+            and self.moe.is_act_and_mul
+        ):
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
                 f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"

@@ -405,10 +409,15 @@ def create_weights(
         )
         weight_loader = extra_weight_attrs.get("weight_loader")
 
+        if self.moe.is_act_and_mul:
+            w13_up_dim = 2 * intermediate_size_per_partition
+        else:
+            w13_up_dim = intermediate_size_per_partition
+
         w13_weight = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                w13_up_dim,
                 hidden_size,
                 dtype=weight_dtype,
             ),

@@ -433,11 +442,16 @@ def create_weights(
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALES - Per-tensor scaling for ModelOpts
-            # Allocate 2 scales for w1 and w3 respectively.
+            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
             # They will be combined to a single scale after weight loading.
+            # For non-gated MoE, allocate 1 scale for w13.
+            if self.moe.is_act_and_mul:
+                w13_weight_scale_shape = (num_experts, 2)
+            else:
+                w13_weight_scale_shape = (num_experts, 1)
             w13_weight_scale = PerTensorScaleParameter(
                 data=torch.full(
-                    (num_experts, 2),
+                    w13_weight_scale_shape,
                     1.0,
                     dtype=torch.float32,
                 ),

@@ -485,7 +499,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
-            if layer.w13_weight_scale.dim() == 2:
+            if (
+                layer.w13_weight_scale.dim() == 2
+                and layer.w13_weight_scale.shape[1] == 2
+            ):
+                assert self.moe.is_act_and_mul, (
+                    "w13_weight_scale should have 2 elements per expert "
+                    "only for gated MoE"
+                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

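The scale-shape change mirrors how ModelOpt FP8 checkpoints are serialized: gated experts carry one per-tensor scale for w1 and one for w3, which process_weights_after_loading collapses into a single per-expert scale (max of the two, then dequantizing and requantizing each half); non-gated experts already carry a single w13 scale, so the shape[1] == 2 path is skipped. A small standalone sketch of just the scale bookkeeping, with made-up sizes:

import torch

num_experts = 4
# Gated checkpoint: one scale each for w1 and w3 per expert -> (E, 2).
w13_scale_gated = torch.rand(num_experts, 2) + 0.5
# Non-gated checkpoint (NemotronH style): one w13 scale per expert -> (E, 1).
w13_scale_plain = torch.rand(num_experts, 1) + 0.5

# Only the gated shape needs combining, matching the shape[1] == 2 guard above:
if w13_scale_gated.dim() == 2 and w13_scale_gated.shape[1] == 2:
    max_w13_scales = w13_scale_gated.max(dim=1).values  # (E,)
    # Each expert's w1/w3 halves would then be dequantized with their own
    # scale and requantized with this shared max scale.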
vllm/model_executor/models/interfaces.py

Lines changed: 3 additions & 1 deletion

@@ -673,7 +673,9 @@ def update_physical_experts_metadata(
 
 
 def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
-    return isinstance(model, MixtureOfExperts)
+    return (
+        isinstance(model, MixtureOfExperts) and getattr(model, "num_moe_layers", 0) > 0
+    )
 
 
 @runtime_checkable

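The extra num_moe_layers check matters for hybrid architectures like NemotronH: a model class can satisfy the MixtureOfExperts interface while a particular configuration ends up with zero MoE layers, and MoE-only machinery such as EPLB should then treat it as a dense model. A hypothetical call-site sketch (the guard is real; the caller below is not vLLM code):

from vllm.model_executor.models.interfaces import is_mixture_of_experts

def maybe_init_expert_rebalancing(model) -> None:
    # Only touch expert-parallel / EPLB state when the model both implements
    # MixtureOfExperts and actually has routed-expert layers, which is what
    # is_mixture_of_experts now checks via num_moe_layers > 0.
    if not is_mixture_of_experts(model):
        return
    ...  # real setup would live elsewhere in vLLM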