Commit d0a6301

Fix Transformers backend tensor parallel for multimodal models (#22673)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 45c3936 commit d0a6301

File tree

1 file changed: +34 −17 lines changed


vllm/model_executor/models/transformers.py

Lines changed: 34 additions & 17 deletions
```diff
@@ -505,30 +505,47 @@ def tensor_parallel(self):
         Apply the model's tensor parallelization plan.
         Currently only supports linear layers.
         """
-        tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
+        # Look for tp plans in all of the PreTrainedModels found in self.model
+        is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
+        supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
+        pretrained_models = filter(is_pretrained_model, self.model.modules())
+        models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
 
-        if not tp_plan and self.tp_size > 1:
+        if not any(models_with_tp_plan) and self.tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!")
 
-        # Some weight loaders expect linear layers to inherit from vLLM's
-        # LinearBase class, so we set a default style which causes any
-        # unspecified linear layers to be replaced with ReplicatedLinear
-        tp_plan[".*"] = "replicate"
-
-        def _tensor_parallel(module: nn.Module, prefix: str = ""):
+        def _tensor_parallel(module: nn.Module,
+                             prefix: str = "",
+                             tp_plan=None):
+            tp_plan = tp_plan or {}
+
+            # If the current module is a PreTrainedModel, set the tp_plan for
+            # all of its children
+            if isinstance(module, PreTrainedModel):
+                tp_plan = module.config.base_model_tp_plan or {}
+                tp_plan = {
+                    maybe_prefix(prefix, k): v
+                    for k, v in tp_plan.items()
+                }
+
+            # Some weight loaders expect linear layers to inherit from vLLM's
+            # LinearBase class, so we set a default style which causes any
+            # unspecified linear layers to be replaced with ReplicatedLinear
             for child_name, child_module in module.named_children():
                 qual_name = maybe_prefix(prefix, child_name)
-                for pattern, style in tp_plan.items():
-                    if re.match(pattern, qual_name) and isinstance(
-                            child_module, nn.Linear):
-                        new_module = replace_linear_class(
-                            child_module, style, self.quant_config)
-                        setattr(module, child_name, new_module)
-                        log_replacement(qual_name, child_module, new_module)
-                        break
+                if isinstance(child_module, nn.Linear):
+                    generator = (p for p in tp_plan if re.match(p, qual_name))
+                    pattern = next(generator, None)
+                    style = tp_plan.get(pattern, "replicate")
+                    new_module = replace_linear_class(child_module, style,
+                                                      self.quant_config)
+                    setattr(module, child_name, new_module)
+                    log_replacement(qual_name, child_module, new_module)
                 else:
-                    _tensor_parallel(child_module, prefix=qual_name)
+                    _tensor_parallel(child_module,
+                                     prefix=qual_name,
+                                     tp_plan=tp_plan)
 
         _tensor_parallel(self.model)
```
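
The key change: instead of reading a single `base_model_tp_plan` off the top-level config and injecting a catch-all `".*": "replicate"` entry, the traversal now picks up a plan from each nested `PreTrainedModel` it encounters (re-keyed under that submodule's prefix via `maybe_prefix`), and falls back to `"replicate"` per linear layer when no pattern matches. This is what makes tensor parallel work for multimodal models, where the language backbone carrying the plan is nested inside a wrapper model. Below is a minimal, self-contained sketch of those two behaviours; `FakePreTrainedModel`, the plan patterns, and the layer names are hypothetical stand-ins, not vLLM or Transformers APIs:

```python
# Sketch of the commit's two behaviours with plain-Python stand-ins.
import re
from types import SimpleNamespace

class FakePreTrainedModel:
    """Hypothetical stand-in for transformers.PreTrainedModel."""
    def __init__(self, base_model_tp_plan):
        self.config = SimpleNamespace(base_model_tp_plan=base_model_tp_plan)

# 1) Discovery: scan every submodule and keep those that carry a tp plan.
modules = [
    object(),                     # not a PreTrainedModel
    FakePreTrainedModel(None),    # a PreTrainedModel without a plan
    FakePreTrainedModel({r"layers\.\d+\.q_proj": "colwise"}),
]
is_pretrained_model = lambda m: isinstance(m, FakePreTrainedModel)
supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
models_with_tp_plan = filter(supports_tp_plan,
                             filter(is_pretrained_model, modules))
assert any(models_with_tp_plan)  # at least one submodule supplies a plan

# 2) Resolution: first matching pattern wins; otherwise fall back to
#    "replicate" so every linear layer still gets replaced.
tp_plan = {
    r"model\.layers\.\d+\.self_attn\.q_proj": "colwise",
    r"model\.layers\.\d+\.self_attn\.o_proj": "rowwise",
}

def resolve_style(qual_name: str) -> str:
    generator = (p for p in tp_plan if re.match(p, qual_name))
    pattern = next(generator, None)
    return tp_plan.get(pattern, "replicate")  # None key -> default

print(resolve_style("model.layers.0.self_attn.q_proj"))  # colwise
print(resolve_style("model.vision_tower.patch_embed"))   # replicate
```

Note the design choice in step 2: the old code needed the mutable `tp_plan[".*"] = "replicate"` entry to guarantee a match, whereas the new `next(generator, None)` plus a `dict.get` default achieves the same fallback without mutating the plan, so each nested model's plan stays scoped to its own prefix.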
