@@ -29,8 +29,7 @@

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import (get_pp_group, get_pp_indices,
-                              get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -51,6 +50,7 @@
 from vllm.utils import is_hip, print_warning_once

 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers


 class LlamaMLP(nn.Module):
@@ -262,20 +262,11 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.start_layer, self.end_layer = get_pp_indices(
+        self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            get_pp_group().rank_in_group,
-            get_pp_group().world_size)
-        self.layers = nn.ModuleList(
-            [nn.Identity() for _ in range(self.start_layer)] + [
-                LlamaDecoderLayer(config=config,
-                                  cache_config=cache_config,
-                                  quant_config=quant_config)
-                for _ in range(self.start_layer, self.end_layer)
-            ] + [
-                nn.Identity()
-                for _ in range(self.end_layer, config.num_hidden_layers)
-            ])
+            lambda: LlamaDecoderLayer(config=config,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config))
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
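The inline partitioning that the deleted lines performed is now the job of the `make_layers` helper imported from `.utils` (vllm/model_executor/models/utils.py), so other models gaining pipeline-parallel support can reuse it instead of repeating this block. What follows is only a rough sketch of what that helper presumably does, keeping the same contiguous layer split and `nn.Identity` placeholders as the code it replaces; the real helper takes just the layer count and a zero-argument factory and looks up the pipeline-parallel rank and world size itself, whereas here they are explicit parameters and the name `make_layers_sketch` is made up for illustration:

from typing import Callable, Tuple

import torch.nn as nn


def make_layers_sketch(
        num_hidden_layers: int,
        layer_fn: Callable[[], nn.Module],
        pp_rank: int,
        pp_world_size: int) -> Tuple[int, int, nn.ModuleList]:
    # Contiguous partition: each pipeline stage owns a consecutive slice of
    # decoder layers, with the last stage absorbing any remainder.
    per_rank = num_hidden_layers // pp_world_size
    start_layer = pp_rank * per_rank
    end_layer = (num_hidden_layers if pp_rank == pp_world_size - 1
                 else start_layer + per_rank)
    # Layers outside this stage's slice stay as cheap placeholders so that
    # layer indices in the module tree remain aligned with the checkpoint.
    layers = nn.ModuleList(
        [nn.Identity() for _ in range(start_layer)] +
        [layer_fn() for _ in range(start_layer, end_layer)] +
        [nn.Identity() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, layers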
@@ -455,12 +446,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                try:
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                except KeyError:
-                    pass
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
@@ -479,13 +472,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         continue
                     else:
                         name = remapped_kv_scale_name
-                try:
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-                except KeyError:
-                    pass
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
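In both branches of load_weights, the old try/except KeyError silently dropped any weight that had no matching parameter, which also masked genuine loading bugs; the explicit is_pp_missing_parameter check instead skips only weights that belong to layers owned by other pipeline stages. Below is a hypothetical sketch of such a check, assuming the stage bounds computed by make_layers are what decides ownership; the signature, name pattern, and helper name are illustrative, not the actual vllm implementation:

import re


def is_pp_missing_parameter_sketch(name: str, start_layer: int,
                                   end_layer: int) -> bool:
    # Extract the decoder-layer index from a checkpoint name such as
    # "model.layers.17.self_attn.qkv_proj.weight".
    match = re.search(r"\.layers\.(\d+)\.", name)
    if match is None:
        # Parameters outside the decoder stack (embeddings, final norm,
        # lm_head) are treated as present in this simplified sketch.
        return False
    layer_idx = int(match.group(1))
    # Missing on this pipeline stage if the layer falls outside its slice.
    return not (start_layer <= layer_idx < end_layer)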