@@ -505,30 +505,47 @@ def tensor_parallel(self):
505
505
Apply the model's tensor parallelization plan.
506
506
Currently only supports linear layers.
507
507
"""
508
- tp_plan = getattr (self .model .config , "base_model_tp_plan" , None ) or {}
508
+ # Look for tp plans in all of the PreTrainedModels found in self.model
509
+ is_pretrained_model = lambda m : isinstance (m , PreTrainedModel )
510
+ supports_tp_plan = lambda m : m .config .base_model_tp_plan is not None
511
+ pretrained_models = filter (is_pretrained_model , self .model .modules ())
512
+ models_with_tp_plan = filter (supports_tp_plan , pretrained_models )
509
513
510
- if not tp_plan and self .tp_size > 1 :
514
+ if not any ( models_with_tp_plan ) and self .tp_size > 1 :
511
515
raise ValueError (
512
516
f"{ type (self .model )} does not support tensor parallel yet!" )
513
517
514
- # Some weight loaders expect linear layers to inherit from vLLM's
515
- # LinearBase class, so we set a default style which causes any
516
- # unspecified linear layers to be replaced with ReplicatedLinear
517
- tp_plan [".*" ] = "replicate"
518
-
519
- def _tensor_parallel (module : nn .Module , prefix : str = "" ):
518
+ def _tensor_parallel (module : nn .Module ,
519
+ prefix : str = "" ,
520
+ tp_plan = None ):
521
+ tp_plan = tp_plan or {}
522
+
523
+ # If the current module is a PreTrainedModel, set the tp_plan for
524
+ # all of its children
525
+ if isinstance (module , PreTrainedModel ):
526
+ tp_plan = module .config .base_model_tp_plan or {}
527
+ tp_plan = {
528
+ maybe_prefix (prefix , k ): v
529
+ for k , v in tp_plan .items ()
530
+ }
531
+
532
+ # Some weight loaders expect linear layers to inherit from vLLM's
533
+ # LinearBase class, so any linear layer whose qualified name matches no
534
+ # pattern in tp_plan falls back to the "replicate" style below
520
535
for child_name , child_module in module .named_children ():
521
536
qual_name = maybe_prefix (prefix , child_name )
522
- for pattern , style in tp_plan . items ( ):
523
- if re .match (pattern , qual_name ) and isinstance (
524
- child_module , nn . Linear ):
525
- new_module = replace_linear_class (
526
- child_module , style , self . quant_config )
527
- setattr ( module , child_name , new_module )
528
- log_replacement ( qual_name , child_module , new_module )
529
- break
537
+ if isinstance ( child_module , nn . Linear ):
538
+ generator = ( p for p in tp_plan if re .match (p , qual_name ))
539
+ pattern = next ( generator , None )
540
+ style = tp_plan . get ( pattern , "replicate" )
541
+ new_module = replace_linear_class ( child_module , style ,
542
+ self . quant_config )
543
+ setattr ( module , child_name , new_module )
544
+ log_replacement ( qual_name , child_module , new_module )
530
545
else :
531
- _tensor_parallel (child_module , prefix = qual_name )
546
+ _tensor_parallel (child_module ,
547
+ prefix = qual_name ,
548
+ tp_plan = tp_plan )
532
549
533
550
_tensor_parallel (self .model )
534
551
0 commit comments