Skip to content

Commit 458e74e

Browse files
authored
Support more parallel styles in Transformers backend TP (#22651)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 65abe11 commit 458e74e

File tree

1 file changed: 13 additions and 5 deletions

vllm/model_executor/models/transformers.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -107,17 +107,25 @@ def replace_linear_class(
107107
raise ValueError(
108108
f"Unsupported parallel style type {type(style)}, expected str")
109109

110-
vllm_linear_cls = {
111-
"colwise": ColumnParallelLinear,
112-
"rowwise": RowParallelLinear,
113-
}.get(style, ReplicatedLinear)
110+
vllm_linear_cls, vllm_linear_kwargs = {
111+
"colwise": (ColumnParallelLinear, {}),
112+
"colwise_rep": (ColumnParallelLinear, {
113+
"gather_output": True
114+
}),
115+
"rowwise": (RowParallelLinear, {}),
116+
"rowwise_rep": (RowParallelLinear, {
117+
"input_is_parallel": False
118+
}),
119+
"replicate": (ReplicatedLinear, {}),
120+
}.get(style, (ReplicatedLinear, {}))
114121

115122
return vllm_linear_cls(
116123
input_size=linear.in_features,
117124
output_size=linear.out_features,
118125
bias=linear.bias is not None,
119126
quant_config=quant_config,
120127
return_bias=False,
128+
**vllm_linear_kwargs,
121129
)
122130

123131

@@ -506,7 +514,7 @@ def tensor_parallel(self):
506514
# Some weight loaders expect linear layers to inherit from vLLM's
507515
# LinearBase class, so we set a default style which causes any
508516
# unspecified linear layers to be replaced with ReplicatedLinear
509-
tp_plan[".*"] = "replicated"
517+
tp_plan[".*"] = "replicate"
510518

511519
def _tensor_parallel(module: nn.Module, prefix: str = ""):
512520
for child_name, child_module in module.named_children():

0 commit comments

Comments
 (0)