2323"""Inference-only Mixtral model."""
2424from typing import List , Optional , Tuple
2525
26- import numpy as np
27-
2826import torch
2927import torch .nn .functional as F
3028
3331
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               ReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class MixtralMLP(nn.Module):
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each
+    expert across all ranks.
+
+    Each rank holds a shard of every expert's weights, runs the fused MoE
+    kernel on its local shards, and finally all-reduces the partial outputs
+    across ranks.
+    """
 
     def __init__(
         self,
         num_experts: int,
+        top_k: int,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+        params_dtype: Optional[torch.dtype] = None,
+    ):
         super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w2 = ReplicatedLinear(self.ffn_dim,
-                                   self.hidden_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w3 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-
-        # TODO: Use vllm's SiluAndMul
-        self.act_fn = nn.SiLU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
-
+        tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // tp_size
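+        # Each rank keeps a 1 / tp_size slice of every expert's
+        # intermediate dimension.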
 
-class MixtralMoE(nn.Module):
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
 
-    def __init__(
-        self,
-        config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
-        super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}.")
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(range(
-            self.num_total_experts), self.tp_size)[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(
-                f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList([
-            MixtralMLP(self.num_total_experts,
-                       config.hidden_size,
-                       config.intermediate_size,
-                       linear_method=linear_method)
-            if idx in self.expert_indicies else None
-            for idx in range(self.num_total_experts)
-        ])
-        self.gate = ReplicatedLinear(config.hidden_size,
+        self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
+                                     params_dtype=self.params_dtype,
                                      linear_method=None)
 
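+        # ws packs each expert's w1 (gate) and w3 (up) projections into a
+        # single tensor so the fused kernel can read both together; w2s
+        # holds the down projections. Both are sharded on intermediate_size.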
+        self.ws = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+        self.w2s = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
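+        # For example, Mixtral-8x7B (8 experts, hidden_size=4096,
+        # intermediate_size=14336) at tp_size=4 gives per-rank shapes
+        # ws: (8, 2 * 3584, 4096) and w2s: (8, 4096, 3584).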
+
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
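+        # Copy this rank's shard of one expert's checkpoint weight into the
+        # fused parameter: w1 fills the first half of the ws rows, w3 the
+        # second half, and w2 fills w2s (sharded along its input dimension).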
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
@@ -142,22 +134,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                                        dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
-        final_hidden_states = None
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            expert_mask = (selected_experts == expert_idx)
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
-                                                                 keepdim=True)
-
-            current_hidden_states = expert_layer(hidden_states).mul_(
-                expert_weights)
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
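+        # fused_moe computes, per token, the routing-weighted sum over the
+        # top-k experts of w2(SiLU(w1(x)) * w3(x)) using the local shards;
+        # inplace=True lets the kernel reuse the hidden_states buffer.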
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.ws,
+                                        self.w2s,
+                                        routing_weights,
+                                        selected_experts,
+                                        inplace=True)
+
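+        # Each rank now holds only a partial sum (from its weight shards),
+        # so an all-reduce is needed to form the full MoE output everywhere.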
+        final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states)
 
-        return tensor_model_parallel_all_reduce(final_hidden_states).view(
-            batch_size, sequence_length, hidden_dim)
+        return final_hidden_states.view(batch_size, sequence_length,
+                                        hidden_size)
 
 
 class MixtralAttention(nn.Module):
@@ -257,8 +245,11 @@ def __init__(
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
             linear_method=linear_method)
-        self.block_sparse_moe = MixtralMoE(config=config,
-                                           linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -378,6 +369,14 @@ def load_weights(self,
             ("qkv_proj", "v_proj", "v"),
         ]
 
+        expert_params_mapping = [
+            # (param_name, weight_name, expert_id)
+            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
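+        # e.g. ("ws", "experts.0.w1.weight", 0): expert 0's w1 checkpoint
+        # weight is routed into the fused "ws" parameter.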
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
@@ -387,6 +386,7 @@ def load_weights(self,
                 fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -399,14 +399,22 @@ def load_weights(self,
                     weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if ("block_sparse_moe.experts." in name
-                        and name not in params_dict):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)