@@ -15,7 +15,8 @@
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
@@ -36,8 +37,9 @@
 from .granitemoeshared import GraniteMoeSharedMLP
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsQuant, SupportsV0Only)
-from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
-                    make_layers, maybe_prefix)
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)


 class GraniteMoeHybridMambaDecoderLayer(nn.Module):
@@ -220,35 +222,37 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.attention_bias = config.attention_bias
         self.attention_multiplier = config.attention_multiplier
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
-
-        self.q_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_heads * self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.q_proj")
-
-        self.k_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.k_proj")
-
-        self.v_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.v_proj")
-
-        self.o_proj = ReplicatedLinear(self.hidden_size,
-                                       self.hidden_size,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.o_proj")
+        self.total_num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        # Tensor-parallel logic: attention heads are split across TP ranks.
+        tp_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is at least the TP size, so we partition
+            # the KV heads across the tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than the TP size, so we replicate
+            # the KV heads across the tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.qkv_proj = QKVParallelLinear(self.hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_kv_heads,
+                                          bias=self.attention_bias,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")
+
+        self.o_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=self.attention_bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")

         if config.position_embedding_type == "rope":
             self.rotary_emb = get_rope(
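To make the head partitioning above concrete, here is a minimal standalone sketch of the per-rank arithmetic. The config values (hidden size 4096, 32 attention heads, 8 KV heads) are assumptions for illustration only, not taken from any Granite checkpoint:

# Illustrative values only; real values come from the model config.
hidden_size = 4096
total_num_heads = 32
total_num_kv_heads = 8
head_dim = hidden_size // total_num_heads  # 128

for tp_size in (1, 2, 8, 16):
    num_heads = total_num_heads // tp_size
    # Partition KV heads when there are enough of them; otherwise each
    # rank keeps a replicated copy, never fewer than one.
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    per_rank_qkv_width = (num_heads + 2 * num_kv_heads) * head_dim
    print(f"tp={tp_size}: {num_heads} Q heads, {num_kv_heads} KV heads, "
          f"fused QKV output width per rank = {per_rank_qkv_width}")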
@@ -278,9 +282,12 @@ def forward(
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:

-        query = self.q_proj(hidden_states)[0]
-        key = self.k_proj(hidden_states)[0]
-        value = self.v_proj(hidden_states)[0]
+        qkv, _ = self.qkv_proj(hidden_states)
+        query, key, value = qkv.split([
+            self.num_heads * self.head_dim, self.num_key_value_heads *
+            self.head_dim, self.num_key_value_heads * self.head_dim
+        ],
+                                      dim=-1)

         if self.rotary_emb is not None:
             query, key = self.rotary_emb(positions, query, key)
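The fused projection returns one tensor per rank whose last dimension packs Q, K, and V back to back, and split() recovers the three views. A self-contained sketch, assuming made-up per-rank sizes (4 Q heads, 1 KV head, head_dim 128):

import torch

# Assumed per-rank sizes for illustration.
num_heads, num_key_value_heads, head_dim = 4, 1, 128
qkv = torch.randn(2, 10, (num_heads + 2 * num_key_value_heads) * head_dim)

query, key, value = qkv.split([
    num_heads * head_dim,
    num_key_value_heads * head_dim,
    num_key_value_heads * head_dim,
], dim=-1)

assert query.shape[-1] == num_heads * head_dim           # 512
assert key.shape[-1] == num_key_value_heads * head_dim   # 128
assert value.shape[-1] == num_key_value_heads * head_dim # 128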
@@ -401,6 +408,12 @@ def forward(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()

@@ -411,6 +424,15 @@ def _load(n, p):
            weight_loader(param, p)
            loaded_params.add(n)

+        def _load_shard(n, p, shard_id):
+            # Skip layers on other devices.
+            if not is_pp_missing_parameter(n, self):
+                param = params_dict[n]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, p, shard_id)
+                loaded_params.add(n)
+
         def _load_expert(n, p, name, shard_id, expert_id):
             param = params_dict[n]
             weight_loader = getattr(param, "weight_loader",
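For intuition about what weight_loader(param, p, shard_id) accomplishes for a fused parameter, here is a toy, self-contained sketch. It is not vLLM's QKVParallelLinear loader (which also handles the tensor-parallel slicing); the names toy_qkv_weight_loader and row_offsets, and all sizes, are made up for illustration:

import torch

# Made-up sizes: 4 Q heads, 1 KV head, head_dim 8, hidden size 32.
num_heads, num_kv_heads, head_dim, hidden_size = 4, 1, 8, 32
q_rows, kv_rows = num_heads * head_dim, num_kv_heads * head_dim

# Fused weight laid out row-wise as [Q | K | V].
fused_qkv_weight = torch.zeros(q_rows + 2 * kv_rows, hidden_size)
row_offsets = {
    "q": (0, q_rows),
    "k": (q_rows, kv_rows),
    "v": (q_rows + kv_rows, kv_rows),
}

def toy_qkv_weight_loader(fused_weight, checkpoint_shard, shard_id):
    # Copy the per-projection checkpoint tensor into its row block.
    start, length = row_offsets[shard_id]
    fused_weight[start:start + length].copy_(checkpoint_shard)

# e.g. place a k_proj checkpoint weight into the fused parameter.
toy_qkv_weight_loader(fused_qkv_weight, torch.randn(kv_rows, hidden_size), "k")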
@@ -465,15 +487,29 @@ def _load_expert(n, p, name, shard_id, expert_id):
                                      ".block_sparse_moe.gate.weight")
                 _load(gate_name, p)
             else:
-                _load(n, p)
+                loaded = False
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name in n:
+                        _load_shard(n.replace(weight_name, param_name),
+                                    p,
+                                    shard_id=shard_id)
+                        loaded = True
+                if not loaded:
+                    _load(n, p)

         return loaded_params


 class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
                                   SupportsPP, IsHybrid, SupportsV0Only,
                                   SupportsQuant):
-    packed_modules_mapping = {}
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
     embedding_modules = {
         "embed_tokens": "input_embeddings",
         "lm_head": "output_embeddings",