@@ -107,17 +107,25 @@ def replace_linear_class(
107
107
raise ValueError (
108
108
f"Unsupported parallel style type { type (style )} , expected str" )
109
109
110
- vllm_linear_cls = {
111
- "colwise" : ColumnParallelLinear ,
112
- "rowwise" : RowParallelLinear ,
113
- }.get (style , ReplicatedLinear )
110
+ vllm_linear_cls , vllm_linear_kwargs = {
111
+ "colwise" : (ColumnParallelLinear , {}),
112
+ "colwise_rep" : (ColumnParallelLinear , {
113
+ "gather_output" : True
114
+ }),
115
+ "rowwise" : (RowParallelLinear , {}),
116
+ "rowwise_rep" : (RowParallelLinear , {
117
+ "input_is_parallel" : False
118
+ }),
119
+ "replicate" : (ReplicatedLinear , {}),
120
+ }.get (style , (ReplicatedLinear , {}))
114
121
115
122
return vllm_linear_cls (
116
123
input_size = linear .in_features ,
117
124
output_size = linear .out_features ,
118
125
bias = linear .bias is not None ,
119
126
quant_config = quant_config ,
120
127
return_bias = False ,
128
+ ** vllm_linear_kwargs ,
121
129
)
122
130
123
131
@@ -506,7 +514,7 @@ def tensor_parallel(self):
506
514
# Some weight loaders expect linear layers to inherit from vLLM's
507
515
# LinearBase class, so we set a default style which causes any
508
516
# unspecified linear layers to be replaced with ReplicatedLinear
509
- tp_plan [".*" ] = "replicated "
517
+ tp_plan [".*" ] = "replicate "
510
518
511
519
def _tensor_parallel (module : nn .Module , prefix : str = "" ):
512
520
for child_name , child_module in module .named_children ():
0 commit comments