@@ -22,7 +22,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union
 
 import torch
@@ -31,8 +32,9 @@
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -50,8 +52,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, extract_layer_index,
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -101,23 +103,47 @@ def __init__(
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_experts
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}.")
 
-        self.experts = FusedMoE(num_experts=config.num_experts,
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
+        self.experts = FusedMoE(num_experts=self.n_routed_experts,
                                 top_k=config.num_experts_per_tok,
                                 hidden_size=config.hidden_size,
                                 intermediate_size=config.moe_intermediate_size,
                                 reduce_results=False,
                                 renormalize=config.norm_topk_prob,
                                 quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
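
The hunk above derives each expert-parallel rank's slice of the physical expert space from three quantities: the logical expert count, the number of redundant replicas, and the EP world size. A minimal standalone sketch of that arithmetic, using made-up sizes (128 logical experts, 16 redundant replicas, EP size 8), purely for illustration:

# Sketch of the physical-expert partitioning above; all sizes are illustrative.
n_logical_experts = 128      # config.num_experts
n_redundant_experts = 16     # parallel_config.num_redundant_experts
ep_size = 8                  # expert-parallel world size

n_physical_experts = n_logical_experts + n_redundant_experts
n_local_physical_experts = n_physical_experts // ep_size  # 18 per rank

for ep_rank in range(ep_size):
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    # Each rank owns a contiguous slice of physical expert ids,
    # e.g. rank 0 -> [0, 18), rank 1 -> [18, 36), ...
    print(f"rank {ep_rank}: physical experts [{start}, {end})")
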
@@ -246,6 +272,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -277,7 +304,8 @@ def __init__(
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen3MoeSparseMoeBlock(config=config,
                                               quant_config=quant_config,
-                                              prefix=f"{prefix}.mlp")
+                                              prefix=f"{prefix}.mlp",
+                                              enable_eplb=enable_eplb)
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -323,6 +351,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+        enable_eplb = parallel_config.enable_eplb
+        self.num_redundant_experts = parallel_config.num_redundant_experts
 
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -336,7 +367,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             lambda prefix: Qwen3MoeDecoderLayer(config=config,
                                                 cache_config=cache_config,
                                                 quant_config=quant_config,
-                                                prefix=prefix),
+                                                prefix=prefix,
+                                                enable_eplb=enable_eplb),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -382,7 +414,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
+            num_experts=self.config.num_experts,
+            num_redundant_experts=self.num_redundant_experts)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
@@ -433,27 +466,51 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here
+                    # Instead, create a new variable
+                    name_mapped = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
+
                     # Skip loading extra parameters for GPTQ/modelopt models.
-                    if name.endswith(
-                            ignore_suffixes) and name not in params_dict:
+                    if name_mapped.endswith(
+                            ignore_suffixes
+                    ) and name_mapped not in params_dict:
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
+
                     # Skip loading extra parameters for GPTQ/modelopt models.
                     if name.endswith(
                             ignore_suffixes) and name not in params_dict:
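
The rewritten expert branch above relies on Python's for/else together with the weight loader's success flag (return_success=True): with redundant experts, a checkpoint weight may have no replica on this rank, and `name` must only be rewritten once a load actually lands. A standalone sketch of that control flow, where try_load is a made-up placeholder for the real loader and all names are illustrative:

# Standalone sketch of the for/else + success-flag pattern used above.
def try_load(local_expert_ids: set, expert_id: int) -> bool:
    # Placeholder: succeeds only if this rank holds a replica of the expert.
    return expert_id in local_expert_ids

def resolve_expert_weight(name, expert_params_mapping, local_expert_ids):
    is_expert_weight = False
    for param_name, weight_name, expert_id in expert_params_mapping:
        if weight_name not in name:
            continue
        # It is an expert weight even if no local replica ends up matching.
        is_expert_weight = True
        name_mapped = name.replace(weight_name, param_name)
        if try_load(local_expert_ids, expert_id):
            # Commit the renamed key only after a successful load.
            name = name_mapped
            break
    else:
        if is_expert_weight:
            # Expert weight, but every replica lives on another rank: skip it.
            return None
        raise KeyError(f"non-expert weight, handled elsewhere: {name}")
    return name

# Toy usage with made-up mapping entries (param_name, weight_name, expert_id).
mapping = [("experts.w13_weight", f"experts.{e}.gate_proj.weight", e)
           for e in range(4)]
print(resolve_expert_weight("layers.0.mlp.experts.2.gate_proj.weight",
                            mapping, local_expert_ids={0, 1}))  # -> None
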
@@ -482,7 +539,8 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
+                          MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -514,6 +572,66 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+        # Set MoE hyperparameters
+        self.expert_weights = []
+
+        self.moe_layers: list[FusedMoE] = []
+        example_layer = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, Qwen3MoeDecoderLayer)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                example_layer = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_layer is None:
+            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
+
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1
+        self.num_shared_experts = 0
+        self.num_logical_experts = example_layer.n_logical_experts
+        self.num_physical_experts = example_layer.n_physical_experts
+        self.num_local_physical_experts = example_layer.n_local_physical_experts
+        self.num_routed_experts = example_layer.n_routed_experts
+        self.num_redundant_experts = example_layer.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
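
The tensors handed to set_eplb_state above are produced by vLLM's EPLB machinery and are only registered per MoE layer here. As a toy illustration of the idea behind a logical-to-physical map with replica counts (the shapes and layout below are made up for illustration and omit the per-layer dimension; they are not vLLM's internal format): a hot logical expert gets a redundant physical replica, and the replica count lets its tokens be spread across both copies.

import torch

# Toy example: 4 logical experts, logical expert 1 gets one redundant
# replica, so there are 5 physical expert slots in total.
num_logical, max_replicas = 4, 2
logical_to_physical_map = torch.full((num_logical, max_replicas), -1,
                                     dtype=torch.long)
logical_to_physical_map[:, 0] = torch.arange(num_logical)  # primary copies
logical_to_physical_map[1, 1] = 4                          # redundant copy
logical_replica_count = (logical_to_physical_map >= 0).sum(dim=-1)

def pick_physical(logical_id: int, token_idx: int) -> int:
    # Round-robin over the replicas of one logical expert.
    replica = token_idx % int(logical_replica_count[logical_id])
    return int(logical_to_physical_map[logical_id, replica])

print([pick_physical(1, t) for t in range(4)])  # [1, 4, 1, 4]
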