Skip to content

Commit 0ca92b5

Browse files
committed
refactor: replace the separate fused_linear_/grouped_linear_ members (and the fused_ flag) in MultiColumnParallelLinearImpl with a single polymorphic MultiParallelLinear linear_ handle, so forward() dispatches through one member instead of branching on fused_; drop the now-unused fused() accessor and tidy the header includes.
1 parent 8a102cd commit 0ca92b5

File tree

2 files changed

+25
-38
lines changed

2 files changed

+25
-38
lines changed

src/layers/linear/multi_parallel_linear.cpp

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <glog/logging.h>
44
#include <torch/torch.h>
55

6+
#include "layers/linear/linear.h"
67
#include "model_parallel/parallel_args.h"
78
#include "parallel_linear.h"
89
#include "quantization/quant_args.h"
@@ -19,38 +20,35 @@ MultiColumnParallelLinearImpl::MultiColumnParallelLinearImpl(
1920
const ParallelArgs& parallel_args,
2021
const torch::TensorOptions& options) {
2122
// check if the linear layers can be fused
22-
fused_ = quant_args.can_be_fused();
23-
if (fused_) {
23+
std::shared_ptr<MultiParallelLinearImpl> linear;
24+
if (quant_args.can_be_fused()) {
2425
// fused linear layer
25-
fused_linear_ = register_module("fused_linear",
26-
FusedColumnParallelLinear(in_features,
27-
out_features_vec,
28-
prefixes,
29-
bias,
30-
gather_output,
31-
parallel_args,
32-
options),
33-
/*selector=*/nullptr);
26+
linear = register_module("fused_linear",
27+
FusedColumnParallelLinear(in_features,
28+
out_features_vec,
29+
prefixes,
30+
bias,
31+
gather_output,
32+
parallel_args,
33+
options),
34+
/*selector=*/nullptr);
3435
} else {
3536
// non-fused linear layers
36-
grouped_linear_ =
37-
register_module("grouped_linear",
38-
GroupedColumnParallelLinear(in_features,
39-
out_features_vec,
40-
prefixes,
41-
bias,
42-
gather_output,
43-
parallel_args,
44-
options),
45-
/*selector=*/nullptr);
37+
linear = register_module("grouped_linear",
38+
GroupedColumnParallelLinear(in_features,
39+
out_features_vec,
40+
prefixes,
41+
bias,
42+
gather_output,
43+
parallel_args,
44+
options),
45+
/*selector=*/nullptr);
4646
}
47+
linear_ = linear;
4748
}
4849

4950
std::vector<torch::Tensor> MultiColumnParallelLinearImpl::forward(
5051
torch::Tensor input) {
51-
if (fused_) {
52-
return fused_linear_(input);
53-
}
54-
return grouped_linear_(input);
52+
return linear_(input);
5553
}
5654
} // namespace llm

src/layers/linear/multi_parallel_linear.h

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
#include <glog/logging.h>
44
#include <torch/torch.h>
55

6-
// #include "linear.h"
6+
#include "linear.h"
77
#include "model_parallel/parallel_args.h"
88
#include "module/module.h"
99
#include "module/module_holder.h"
10-
#include "parallel_linear.h"
1110
#include "quantization/quant_args.h"
1211

1312
namespace llm {
@@ -25,18 +24,8 @@ class MultiColumnParallelLinearImpl : public Module {
2524

2625
std::vector<torch::Tensor> forward(torch::Tensor input);
2726

28-
// whether the linear layer is fused
29-
bool fused() const { return fused_; }
30-
3127
private:
32-
// non-fused linear layers
33-
GroupedColumnParallelLinear grouped_linear_{nullptr};
34-
35-
// fused linear layer
36-
FusedColumnParallelLinear fused_linear_{nullptr};
37-
38-
// whether the linear layer is fused
39-
bool fused_ = false;
28+
MultiParallelLinear linear_{nullptr};
4029
};
4130
LLM_MODULE(MultiColumnParallelLinear);
4231

0 commit comments

Comments (0)