@@ -38,18 +38,6 @@ namespace {
3838 parallel_args, \
3939 options);
4040
41- #define MAKE_ROW_PARALLEL_LINEAR (LinearlImplClass ) \
42- std::make_shared<LinearlImplClass>(in_features, \
43- out_features, \
44- bias, \
45- input_is_parallelized, \
46- parallel_args, \
47- options);
48-
49- #define MAKE_COLUMN_PARALLEL_LINEAR (LinearlImplClass ) \
50- std::make_shared<LinearlImplClass>( \
51- in_features, out_features, bias, gather_output, parallel_args, options);
52-
5341std::shared_ptr<ParallelLinearImpl> create_column_parallel_qlinear_by_impl (
5442 int64_t in_features,
5543 int64_t out_features,
@@ -139,6 +127,7 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_qlinear(
139127 }
140128 // not supported quant method
141129 LOG (FATAL) << " Unsupported quant method: " << quant_args.quant_method ();
130+ return nullptr ;
142131}
143132
144133std::shared_ptr<ParallelLinearImpl> create_row_parallel_qlinear (
@@ -170,6 +159,7 @@ std::shared_ptr<ParallelLinearImpl> create_row_parallel_qlinear(
170159 }
171160 // not supported quant method
172161 LOG (FATAL) << " Unsupported quant method: " << quant_args.quant_method ();
162+ return nullptr ;
173163}
174164
175165std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear (
@@ -179,7 +169,8 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
179169 bool gather_output,
180170 const QuantArgs& quant_args,
181171 const ParallelArgs& parallel_args,
182- const torch::TensorOptions& options) {
172+ const torch::TensorOptions& options,
173+ const std::string& prefix) {
183174 if (!quant_args.quant_method ().empty ()) {
184175 return create_column_parallel_qlinear (in_features,
185176 out_features,
@@ -189,7 +180,40 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
189180 parallel_args,
190181 options);
191182 }
192- return MAKE_COLUMN_PARALLEL_LINEAR (ColumnParallelLinearImpl);
183+ return std ::make_shared<ColumnParallelLinearImpl>(in_features,
184+ out_features,
185+ bias,
186+ gather_output,
187+ parallel_args,
188+ options,
189+ prefix);
190+ }
191+
192+ std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear (
193+ int64_t in_features,
194+ const std::vector<int64_t >& out_features,
195+ const std::vector<std::string>& prefixes,
196+ bool bias,
197+ bool gather_output,
198+ const QuantArgs& quant_args,
199+ const ParallelArgs& parallel_args,
200+ const torch::TensorOptions& options) {
201+ // if (!quant_args.quant_method().empty()) {
202+ // return create_column_parallel_qlinear(in_features,
203+ // out_features,
204+ // bias,
205+ // gather_output,
206+ // quant_args,
207+ // parallel_args,
208+ // options);
209+ // }
210+ return std ::make_shared<FColumnParallelLinearImpl>(in_features,
211+ out_features,
212+ prefixes,
213+ bias,
214+ gather_output,
215+ parallel_args,
216+ options);
193217}
194218
195219std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear (
@@ -209,7 +233,13 @@ std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear(
209233 parallel_args,
210234 options);
211235 }
212- return MAKE_ROW_PARALLEL_LINEAR (RowParallelLinearImpl);
236+ return std ::make_shared<RowParallelLinearImpl>(in_features,
237+ out_features,
238+ bias,
239+ input_is_parallelized,
240+ parallel_args,
241+ options);
242+ ;
213243}
214244} // namespace
215245
@@ -221,9 +251,29 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features,
221251 bool gather_output,
222252 const QuantArgs& quant_args,
223253 const ParallelArgs& parallel_args,
224- const torch::TensorOptions& options)
254+ const torch::TensorOptions& options,
255+ const std::string& prefix)
256+ : ModuleHolder(create_column_parallel_linear(in_features,
257+ out_features,
258+ bias,
259+ gather_output,
260+ quant_args,
261+ parallel_args,
262+ options,
263+ prefix)) {}
264+
265+ ColumnParallelLinear::ColumnParallelLinear (
266+ int64_t in_features,
267+ const std::vector<int64_t >& out_features,
268+ const std::vector<std::string>& prefixes,
269+ bool bias,
270+ bool gather_output,
271+ const QuantArgs& quant_args,
272+ const ParallelArgs& parallel_args,
273+ const torch::TensorOptions& options)
225274 : ModuleHolder(create_column_parallel_linear(in_features,
226275 out_features,
276+ prefixes,
227277 bias,
228278 gather_output,
229279 quant_args,
@@ -242,7 +292,8 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features,
242292 gather_output,
243293 {}, /* quant_args*/
244294 parallel_args,
245- options)) {}
295+ options,
296+ " " )) {}
246297
247298// construct a rotary positional embedding.
248299// chose right implementation based on the args.
0 commit comments