66#include < boost/algorithm/string.hpp>
77#include < memory>
88
9- #include " linear_impl .h"
9+ #include " parallel_linear .h"
1010#include " quantization/qlinear_awq_impl.h"
1111#include " quantization/qlinear_awq_marlin_impl.h"
1212#include " quantization/qlinear_exllamav2_impl.h"
@@ -189,33 +189,6 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
189189 prefix);
190190}
191191
192- std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear (
193- int64_t in_features,
194- const std::vector<int64_t >& out_features,
195- const std::vector<std::string>& prefixes,
196- bool bias,
197- bool gather_output,
198- const QuantArgs& quant_args,
199- const ParallelArgs& parallel_args,
200- const torch::TensorOptions& options) {
201- // if (!quant_args.quant_method().empty()) {
202- // return create_column_parallel_qlinear(in_features,
203- // out_features,
204- // bias,
205- // gather_output,
206- // quant_args,
207- // parallel_args,
208- // options);
209- // }
210- return std ::make_shared<FColumnParallelLinearImpl>(in_features,
211- out_features,
212- prefixes,
213- bias,
214- gather_output,
215- parallel_args,
216- options);
217- }
218-
219192std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear (
220193 int64_t in_features,
221194 int64_t out_features,
@@ -239,8 +212,38 @@ std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear(
239212 input_is_parallelized,
240213 parallel_args,
241214 options);
242- ;
243215}
216+
217+ // std::shared_ptr<MultiParallelLinearImpl> create_multi_column_parallel_linear(
218+ // int64_t in_features,
219+ // const std::vector<int64_t>& out_features,
220+ // const std::vector<std::string>& prefixes,
221+ // bool bias,
222+ // bool gather_output,
223+ // const QuantArgs& quant_args,
224+ // const ParallelArgs& parallel_args,
225+ // const torch::TensorOptions& options) {
226+ // // check if the linear layers can be fused
227+ // const bool fused = quant_args.can_be_fused();
228+ // std::shared_ptr<MultiParallelLinearImpl> impl;
229+ // if (fused) {
230+ // return std::make_shared<FusedColumnParallelLinearImpl>(in_features,
231+ // out_features,
232+ // prefixes,
233+ // bias,
234+ // gather_output,
235+ // parallel_args,
236+ // options);
237+ // }
238+
239+ // return std::make_shared<GroupedColumnParallelLinearImpl>(in_features,
240+ // out_features,
241+ // prefixes,
242+ // bias,
243+ // gather_output,
244+ // parallel_args,
245+ // options);
246+ // }
244247} // namespace
245248
246249// construct a ColumnParallelLinear.
@@ -262,24 +265,6 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features,
262265 options,
263266 prefix)) {}
264267
265- ColumnParallelLinear::ColumnParallelLinear (
266- int64_t in_features,
267- const std::vector<int64_t >& out_features,
268- const std::vector<std::string>& prefixes,
269- bool bias,
270- bool gather_output,
271- const QuantArgs& quant_args,
272- const ParallelArgs& parallel_args,
273- const torch::TensorOptions& options)
274- : ModuleHolder(create_column_parallel_linear(in_features,
275- out_features,
276- prefixes,
277- bias,
278- gather_output,
279- quant_args,
280- parallel_args,
281- options)) {}
282-
283268ColumnParallelLinear::ColumnParallelLinear (int64_t in_features,
284269 int64_t out_features,
285270 bool bias,
0 commit comments