Skip to content

Commit d979dd6

Browse files
authored
[Feature][EPLB] Add eplb support for Qwen3 (#20815)
Signed-off-by: aladerran <[email protected]>
1 parent b876860 commit d979dd6

File tree

1 file changed

+142
-24
lines changed

1 file changed

+142
-24
lines changed

vllm/model_executor/models/qwen3_moe.py

Lines changed: 142 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
# See the License for the specific language governing permissions and
2323
# limitations under the License.
2424
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
25-
from collections.abc import Iterable
25+
import typing
26+
from collections.abc import Callable, Iterable
2627
from typing import Any, Optional, Union
2728

2829
import torch
@@ -31,8 +32,9 @@
3132

3233
from vllm.attention import Attention
3334
from vllm.compilation.decorators import support_torch_compile
34-
from vllm.config import CacheConfig, VllmConfig
35-
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35+
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
36+
from vllm.distributed import (get_ep_group, get_pp_group,
37+
get_tensor_model_parallel_world_size)
3638
from vllm.logger import init_logger
3739
from vllm.model_executor.layers.activation import SiluAndMul
3840
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -50,8 +52,8 @@
5052
from vllm.model_executor.sampling_metadata import SamplingMetadata
5153
from vllm.sequence import IntermediateTensors
5254

53-
from .interfaces import SupportsLoRA, SupportsPP
54-
from .utils import (AutoWeightsLoader, extract_layer_index,
55+
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
56+
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
5557
is_pp_missing_parameter,
5658
make_empty_intermediate_tensors_factory, make_layers,
5759
maybe_prefix)
@@ -101,23 +103,47 @@ def __init__(
101103
config: PretrainedConfig,
102104
quant_config: Optional[QuantizationConfig] = None,
103105
prefix: str = "",
106+
enable_eplb: bool = False,
104107
):
105108
super().__init__()
106109
self.tp_size = get_tensor_model_parallel_world_size()
107110

111+
self.ep_group = get_ep_group().device_group
112+
self.ep_rank = self.ep_group.rank()
113+
self.ep_size = self.ep_group.size()
114+
self.n_routed_experts = config.num_experts
115+
108116
if self.tp_size > config.num_experts:
109117
raise ValueError(
110118
f"Tensor parallel size {self.tp_size} is greater than "
111119
f"the number of experts {config.num_experts}.")
112120

113-
self.experts = FusedMoE(num_experts=config.num_experts,
121+
# Load balancing settings.
122+
vllm_config = get_current_vllm_config()
123+
parallel_config = vllm_config.parallel_config
124+
self.enable_eplb = enable_eplb
125+
126+
self.n_logical_experts = self.n_routed_experts
127+
self.n_redundant_experts = parallel_config.num_redundant_experts
128+
self.n_physical_experts = (self.n_logical_experts +
129+
self.n_redundant_experts)
130+
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
131+
132+
self.physical_expert_start = (self.ep_rank *
133+
self.n_local_physical_experts)
134+
self.physical_expert_end = (self.physical_expert_start +
135+
self.n_local_physical_experts)
136+
137+
self.experts = FusedMoE(num_experts=self.n_routed_experts,
114138
top_k=config.num_experts_per_tok,
115139
hidden_size=config.hidden_size,
116140
intermediate_size=config.moe_intermediate_size,
117141
reduce_results=False,
118142
renormalize=config.norm_topk_prob,
119143
quant_config=quant_config,
120-
prefix=f"{prefix}.experts")
144+
prefix=f"{prefix}.experts",
145+
enable_eplb=self.enable_eplb,
146+
num_redundant_experts=self.n_redundant_experts)
121147

122148
self.gate = ReplicatedLinear(config.hidden_size,
123149
config.num_experts,
@@ -246,6 +272,7 @@ def __init__(
246272
cache_config: Optional[CacheConfig] = None,
247273
quant_config: Optional[QuantizationConfig] = None,
248274
prefix: str = "",
275+
enable_eplb: bool = False,
249276
) -> None:
250277
super().__init__()
251278
self.hidden_size = config.hidden_size
@@ -277,7 +304,8 @@ def __init__(
277304
(layer_idx + 1) % config.decoder_sparse_step == 0):
278305
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
279306
quant_config=quant_config,
280-
prefix=f"{prefix}.mlp")
307+
prefix=f"{prefix}.mlp",
308+
enable_eplb=enable_eplb)
281309
else:
282310
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
283311
intermediate_size=config.intermediate_size,
@@ -323,6 +351,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
323351
config = vllm_config.model_config.hf_config
324352
cache_config = vllm_config.cache_config
325353
quant_config = vllm_config.quant_config
354+
parallel_config = vllm_config.parallel_config
355+
enable_eplb = parallel_config.enable_eplb
356+
self.num_redundant_experts = parallel_config.num_redundant_experts
326357

327358
self.padding_idx = config.pad_token_id
328359
self.vocab_size = config.vocab_size
@@ -336,7 +367,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
336367
lambda prefix: Qwen3MoeDecoderLayer(config=config,
337368
cache_config=cache_config,
338369
quant_config=quant_config,
339-
prefix=prefix),
370+
prefix=prefix,
371+
enable_eplb=enable_eplb),
340372
prefix=f"{prefix}.layers",
341373
)
342374
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -382,7 +414,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
382414
ckpt_gate_proj_name="gate_proj",
383415
ckpt_down_proj_name="down_proj",
384416
ckpt_up_proj_name="up_proj",
385-
num_experts=self.config.num_experts)
417+
num_experts=self.config.num_experts,
418+
num_redundant_experts=self.num_redundant_experts)
386419

387420
def load_weights(self, weights: Iterable[tuple[str,
388421
torch.Tensor]]) -> set[str]:
@@ -433,27 +466,51 @@ def load_weights(self, weights: Iterable[tuple[str,
433466
weight_loader(param, loaded_weight, shard_id)
434467
break
435468
else:
469+
is_expert_weight = False
436470
for mapping in expert_params_mapping:
437471
param_name, weight_name, expert_id, shard_id = mapping
438472
if weight_name not in name:
439473
continue
440-
name = name.replace(weight_name, param_name)
441-
# Skip layers on other devices.
442-
if is_pp_missing_parameter(name, self):
474+
475+
# Anyway, this is an expert weight and should not be
476+
# attempted to load as other weights later
477+
is_expert_weight = True
478+
479+
# Do not modify `name` since the loop may continue here
480+
# Instead, create a new variable
481+
name_mapped = name.replace(weight_name, param_name)
482+
483+
if is_pp_missing_parameter(name_mapped, self):
443484
continue
485+
444486
# Skip loading extra parameters for GPTQ/modelopt models.
445-
if name.endswith(
446-
ignore_suffixes) and name not in params_dict:
487+
if name_mapped.endswith(
488+
ignore_suffixes
489+
) and name_mapped not in params_dict:
447490
continue
448-
param = params_dict[name]
449-
weight_loader = param.weight_loader
450-
weight_loader(param,
451-
loaded_weight,
452-
name,
453-
shard_id=shard_id,
454-
expert_id=expert_id)
455-
break
491+
492+
param = params_dict[name_mapped]
493+
# We should ask the weight loader to return success or not
494+
# here since otherwise we may skip experts with other
495+
# available replicas.
496+
weight_loader = typing.cast(Callable[..., bool],
497+
param.weight_loader)
498+
success = weight_loader(param,
499+
loaded_weight,
500+
name_mapped,
501+
shard_id=shard_id,
502+
expert_id=expert_id,
503+
return_success=True)
504+
if success:
505+
name = name_mapped
506+
break
456507
else:
508+
if is_expert_weight:
509+
# We've checked that this is an expert weight
510+
# However it's not mapped locally to this rank
511+
# So we simply skip it
512+
continue
513+
457514
# Skip loading extra parameters for GPTQ/modelopt models.
458515
if name.endswith(
459516
ignore_suffixes) and name not in params_dict:
@@ -482,7 +539,8 @@ def load_weights(self, weights: Iterable[tuple[str,
482539
return loaded_params
483540

484541

485-
class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
542+
class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
543+
MixtureOfExperts):
486544
packed_modules_mapping = {
487545
"qkv_proj": [
488546
"q_proj",
@@ -514,6 +572,66 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
514572
self.make_empty_intermediate_tensors = (
515573
self.model.make_empty_intermediate_tensors)
516574

575+
# Set MoE hyperparameters
576+
self.expert_weights = []
577+
578+
self.moe_layers: list[FusedMoE] = []
579+
example_layer = None
580+
for layer in self.model.layers:
581+
if isinstance(layer, PPMissingLayer):
582+
continue
583+
584+
assert isinstance(layer, Qwen3MoeDecoderLayer)
585+
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
586+
example_layer = layer.mlp
587+
self.moe_layers.append(layer.mlp.experts)
588+
589+
if example_layer is None:
590+
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
591+
592+
self.num_moe_layers = len(self.moe_layers)
593+
self.num_expert_groups = 1
594+
self.num_shared_experts = 0
595+
self.num_logical_experts = example_layer.n_logical_experts
596+
self.num_physical_experts = example_layer.n_physical_experts
597+
self.num_local_physical_experts = example_layer.n_local_physical_experts
598+
self.num_routed_experts = example_layer.n_routed_experts
599+
self.num_redundant_experts = example_layer.n_redundant_experts
600+
601+
def set_eplb_state(
602+
self,
603+
expert_load_view: torch.Tensor,
604+
logical_to_physical_map: torch.Tensor,
605+
logical_replica_count: torch.Tensor,
606+
) -> None:
607+
for layer_idx, layer in enumerate(self.moe_layers):
608+
# Register the expert weights.
609+
self.expert_weights.append(layer.get_expert_weights())
610+
layer.set_eplb_state(
611+
moe_layer_idx=layer_idx,
612+
expert_load_view=expert_load_view,
613+
logical_to_physical_map=logical_to_physical_map,
614+
logical_replica_count=logical_replica_count,
615+
)
616+
617+
def update_physical_experts_metadata(
618+
self,
619+
num_physical_experts: int,
620+
num_local_physical_experts: int,
621+
) -> None:
622+
assert self.num_local_physical_experts == num_local_physical_experts
623+
self.num_physical_experts = num_physical_experts
624+
self.num_local_physical_experts = num_local_physical_experts
625+
self.num_redundant_experts = (num_physical_experts -
626+
self.num_logical_experts)
627+
for layer in self.model.layers:
628+
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
629+
moe = layer.mlp
630+
moe.n_local_physical_experts = num_local_physical_experts
631+
moe.n_physical_experts = num_physical_experts
632+
moe.n_redundant_experts = self.num_redundant_experts
633+
moe.experts.update_expert_map()
634+
517635
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
518636
return self.model.get_input_embeddings(input_ids)
519637

0 commit comments

Comments
 (0)