Commit 0340f45

Support expert parallel load balancing in Transformers backend (#26287)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 19a00eb commit 0340f45
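
For orientation: expert parallel load balancing (EPLB) lets a model carry redundant physical copies of some logical experts and spread them across the expert parallel group. The bookkeeping this commit wires into the Transformers backend reduces to the arithmetic below; a minimal sketch with illustrative numbers, not values taken from this diff:

# Hedged sketch of the EPLB expert bookkeeping set up in transformers_moe.py.
# The concrete numbers are examples only.
num_logical_experts = 64       # routed experts defined by the checkpoint
num_redundant_experts = 8      # extra replicas requested via eplb_config
ep_size = 8                    # expert parallel world size

num_physical_experts = num_logical_experts + num_redundant_experts   # 72
num_local_physical_experts = num_physical_experts // ep_size         # 9 per EP rank

print(num_physical_experts, num_local_physical_experts)  # 72 9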

File tree

2 files changed: +76 -27 lines
  vllm/model_executor/models/transformers.py
  vllm/model_executor/models/transformers_moe.py

vllm/model_executor/models/transformers.py
Lines changed: 13 additions & 11 deletions

@@ -40,7 +40,7 @@
 )
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.config.utils import getattr_iter
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -506,9 +506,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config

         self.pp_group = get_pp_group()
-        self.pp_size = self.pp_group.world_size
-        self.pp_rank = self.pp_group.rank_in_group
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_group = get_tp_group()

         # Weights to skip in `self.load_weights`
         self.skip_prefixes: list[str] = []
@@ -576,7 +574,7 @@ def pipeline_parallel(self):
         """
         Apply the model's pipeline parallelization plan.
         """
-        if self.pp_size <= 1:
+        if self.pp_group.world_size <= 1:
             return

         if not self.model.supports_pp_plan:
@@ -613,7 +611,9 @@ def pipeline_parallel(self):

         # Module list
         start_layer, end_layer = get_pp_indices(
-            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
+            self.text_config.num_hidden_layers,
+            self.pp_group.rank_in_group,
+            self.pp_group.world_size,
         )
         layers_name = pp_plan[module_list_idx]
         layers = getattr(self.model, layers_name)
@@ -638,7 +638,7 @@ def recursive_replace(self):
         """
         tp_plan = self.model.tp_plan

-        if not tp_plan and self.tp_size > 1:
+        if not tp_plan and self.tp_group.world_size > 1:
             tip = get_feature_request_tip(
                 self.model_config.model, self.model_config.trust_remote_code
             )
@@ -687,7 +687,9 @@ def create_attention_instances(
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(
-            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
+            self.text_config.num_hidden_layers,
+            self.pp_group.rank_in_group,
+            self.pp_group.world_size,
         )

         attention_instances = {}
@@ -749,7 +751,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        if not get_pp_group().is_first_rank:
+        if not self.pp_group.is_first_rank:
             assert intermediate_tensors is not None
             input_ids = None
             inputs_embeds = intermediate_tensors["hidden_states"]
@@ -773,7 +775,7 @@ def forward(
             return_dict=False,
         )[0][0, ...]  # we remove batch dimension for now

-        if not get_pp_group().is_last_rank:
+        if not self.pp_group.is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})

         return hidden_states
@@ -811,7 +813,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if self.text_config.tie_word_embeddings:
             self.skip_prefixes.append("lm_head.")

-        if get_pp_group().is_last_rank:
+        if self.pp_group.is_last_rank:
             self.unpadded_vocab_size = self.text_config.vocab_size
             self.lm_head = ParallelLMHead(
                 self.text_config.vocab_size,
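
The changes above store the pipeline parallel group once (self.pp_group = get_pp_group()) and read rank_in_group and world_size from it when slicing layers with get_pp_indices, instead of caching pp_rank and pp_size separately. A minimal sketch of the contiguous layer split this implies; the helper below is an illustrative stand-in, not vLLM's actual get_pp_indices, which also supports custom partitions:

# Illustrative stand-in for vllm.distributed.utils.get_pp_indices:
# split num_layers into contiguous [start, end) ranges, one per pipeline rank.
def pp_indices(num_layers: int, rank: int, world_size: int) -> tuple[int, int]:
    per_rank, remainder = divmod(num_layers, world_size)
    start = rank * per_rank + min(rank, remainder)
    end = start + per_rank + (1 if rank < remainder else 0)
    return start, end

# e.g. 30 hidden layers over 4 pipeline ranks:
print([pp_indices(30, r, 4) for r in range(4)])  # [(0, 8), (8, 16), (16, 23), (23, 30)]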

vllm/model_executor/models/transformers_moe.py
Lines changed: 63 additions & 16 deletions

@@ -30,6 +30,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op

+from .interfaces import MixtureOfExperts
 from .transformers import (
     TransformersBase,
     TransformersForCausalLM,
@@ -116,17 +117,41 @@ def transformers_moe_forward_fake(
     )


-class TransformersMoEBase(TransformersBase):
+class TransformersMoEBase(TransformersBase, MixtureOfExperts):
     def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
+        self.ep_group = get_ep_group()
         super().__init__(vllm_config=vllm_config, prefix=prefix)

-        if self.parallel_config.enable_eplb:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel load "
-                "balancing yet."
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ):
+        for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
+            mlp_layer.experts.set_eplb_state(
+                moe_layer_idx=moe_layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
             )

+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ):
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for mlp in self.mlp_layers:
+            mlp.n_local_physical_experts = num_local_physical_experts
+            mlp.n_physical_experts = num_physical_experts
+            mlp.n_redundant_experts = self.num_redundant_experts
+            mlp.experts.update_expert_map()
+
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         """
         Params for weights, fp8 weight scales, fp8 activation scales
@@ -138,15 +163,17 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
             ("w1", "w2", "w3"),  # Granite, Mixtral, Phi MoE style
             ("linear", "linear_1", "linear_v"),  # Grok1 style
         ]
+        num_experts = self.model_config.get_num_experts()
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
         expert_mapping = []
         for gate_proj, down_proj, up_proj in ckpt_names:
             expert_mapping.extend(
                 FusedMoE.make_expert_params_mapping(
                     ckpt_gate_proj_name=gate_proj,
                     ckpt_down_proj_name=down_proj,
                     ckpt_up_proj_name=up_proj,
-                    num_experts=self.model_config.get_num_experts(),
-                    num_redundant_experts=0,  # TODO: enable EPLB
+                    num_experts=num_experts,
+                    num_redundant_experts=num_redundant_experts,
                 )
             )
         return expert_mapping
@@ -167,12 +194,15 @@ def recursive_replace(self):

         # If there are shared experts, the results are
         # reduced after mlp.forward() not inside FusedMoE
-        num_experts_shared = getattr_iter(
+        num_shared_experts = getattr_iter(
             text_config,
-            ["num_experts_shared", "n_shared_experts", "moe_num_shared_experts"],
+            [
+                "n_shared_experts",  # DeepSeek, Docs, GLM
+                "moe_num_shared_experts",  # Aria, Ernie
+            ],
             0,
         )
-        reduce_results = num_experts_shared == 0
+        reduce_results = num_shared_experts == 0

         def add_all_reduce(mlp: nn.Module):
             """Adds an all-reduce to the output of `mlp.forward()`."""
@@ -207,13 +237,23 @@ def forward(self, *args, **kwargs):
         # Expert mapping for `AutoWeightsLoader`
         expert_mapping = self.get_expert_mapping()

-        # Configs
-        parallel_config = self.parallel_config
-        eplb_config = parallel_config.eplb_config
-
         # Expert parallel load balancing kwargs
-        enable_eplb = parallel_config.enable_eplb
-        num_redundant_experts = eplb_config.num_redundant_experts
+        enable_eplb = self.parallel_config.enable_eplb
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
+
+        # MixtureOfExperts mixin settings
+        ep_size = self.ep_group.world_size
+
+        self.mlp_layers = []  # Used for MixtureOfExperts methods
+        self.expert_weights = []
+        self.num_moe_layers = 0
+        self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
+        self.num_logical_experts = num_experts
+        self.num_physical_experts = num_experts + num_redundant_experts
+        self.num_local_physical_experts = self.num_physical_experts // ep_size
+        self.num_routed_experts = num_experts
+        self.num_shared_experts = num_shared_experts
+        self.num_redundant_experts = num_redundant_experts

         # Recursively fuse MoE layers
         def _recursive_replace(module: nn.Module, prefix: str):
@@ -235,6 +275,9 @@ def _recursive_replace(module: nn.Module, prefix: str):
                    for mlp_param_name, _ in mlp.named_parameters():
                        if "shared_expert" in mlp_param_name:
                            reduce_results = False
+                            # If the config does not specify num_shared_experts, but
+                            # the model has shared experts, we assume there is one.
+                            self.num_shared_experts = 1
                            break
                    # Replace experts module with FusedMoE
                    fused_experts = TransformersFusedMoE(
@@ -258,6 +301,10 @@ def _recursive_replace(module: nn.Module, prefix: str):
                    )
                    mlp.experts = fused_experts
                    log_replacement(qual_name, experts, fused_experts)
+                    # Update MixtureOfExperts mixin state
+                    self.mlp_layers.append(mlp)
+                    self.expert_weights.append(fused_experts.get_expert_weights())
+                    self.num_moe_layers += 1
                    # If results are not all-reduced in FusedMoE, ensure they
                    # are all-reduced at the end of mlp.forward() if tensor
                    # parallel or expert parallel is enabled
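
The new MixtureOfExperts plumbing collects every fused MoE block in self.mlp_layers during recursive_replace and then fans the shared EPLB tensors out to each layer in set_eplb_state. A structural sketch of that fan-out; the Stub* classes below are hypothetical stand-ins for FusedMoE and TransformersMoEBase, and only the call pattern mirrors the diff:

import torch

class StubExperts:
    # Hypothetical stand-in for the fused experts layer that receives EPLB state.
    def set_eplb_state(self, moe_layer_idx, expert_load_view,
                       logical_to_physical_map, logical_replica_count):
        # Each MoE layer keeps a view into the shared, per-layer EPLB tensors.
        self.load = expert_load_view[moe_layer_idx]
        self.l2p = logical_to_physical_map
        self.replicas = logical_replica_count

class StubMLP:
    def __init__(self):
        self.experts = StubExperts()

class StubMoEModel:
    # Hypothetical stand-in for the model class carrying the mixin hook.
    def __init__(self, num_moe_layers: int):
        self.mlp_layers = [StubMLP() for _ in range(num_moe_layers)]

    def set_eplb_state(self, expert_load_view, logical_to_physical_map,
                       logical_replica_count):
        # Same loop as the diff: forward shared state to every fused MoE layer.
        for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
            mlp_layer.experts.set_eplb_state(
                moe_layer_idx=moe_layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

model = StubMoEModel(num_moe_layers=2)
model.set_eplb_state(
    expert_load_view=torch.zeros(2, 72),  # per-layer expert load counters
    logical_to_physical_map=torch.zeros(2, 64, 2, dtype=torch.long),
    logical_replica_count=torch.ones(2, 64, dtype=torch.long),
)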
