#include <glog/logging.h>
#include <torch/torch.h>

+#include "layers/linear/linear.h"
#include "model_parallel/parallel_args.h"
#include "parallel_linear.h"
#include "quantization/quant_args.h"
@@ -19,38 +20,35 @@ MultiColumnParallelLinearImpl::MultiColumnParallelLinearImpl(
    const ParallelArgs& parallel_args,
    const torch::TensorOptions& options) {
  // check if the linear layers can be fused
-  fused_ = quant_args.can_be_fused();
-  if (fused_) {
+  std::shared_ptr<MultiParallelLinearImpl> linear;
+  if (quant_args.can_be_fused()) {
    // fused linear layer
-    fused_linear_ = register_module("fused_linear",
-                                    FusedColumnParallelLinear(in_features,
-                                                              out_features_vec,
-                                                              prefixes,
-                                                              bias,
-                                                              gather_output,
-                                                              parallel_args,
-                                                              options),
-                                    /*selector=*/nullptr);
+    linear = register_module("fused_linear",
+                             FusedColumnParallelLinear(in_features,
+                                                       out_features_vec,
+                                                       prefixes,
+                                                       bias,
+                                                       gather_output,
+                                                       parallel_args,
+                                                       options),
+                             /*selector=*/nullptr);
  } else {
    // non-fused linear layers
-    grouped_linear_ =
-        register_module("grouped_linear",
-                        GroupedColumnParallelLinear(in_features,
-                                                    out_features_vec,
-                                                    prefixes,
-                                                    bias,
-                                                    gather_output,
-                                                    parallel_args,
-                                                    options),
-                        /*selector=*/nullptr);
+    linear = register_module("grouped_linear",
+                             GroupedColumnParallelLinear(in_features,
+                                                         out_features_vec,
+                                                         prefixes,
+                                                         bias,
+                                                         gather_output,
+                                                         parallel_args,
+                                                         options),
+                             /*selector=*/nullptr);
  }
+  linear_ = linear;
}

std::vector<torch::Tensor> MultiColumnParallelLinearImpl::forward(
    torch::Tensor input) {
-  if (fused_) {
-    return fused_linear_(input);
-  }
-  return grouped_linear_(input);
+  return linear_(input);
}
}  // namespace llm