Skip to content

Commit 31afc66

Browse files
committed
clean up load_state_dict for linear layers
1 parent b4cc8f2 commit 31afc66

File tree

8 files changed

+230
-155
lines changed

8 files changed

+230
-155
lines changed

src/layers/fused_linear.cpp

Lines changed: 27 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,21 @@ FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl(
1919
const QuantArgs& quant_args,
2020
const ParallelArgs& parallel_args,
2121
const torch::TensorOptions& options) {
22-
prefixes_ = prefixes;
2322
// check if the linear layers can be fused
2423
fused_ = quant_args.can_be_fused();
2524
if (fused_) {
2625
// fused linear layer
27-
const int64_t out_features = std::accumulate(
28-
out_features_vec.begin(), out_features_vec.end(), int64_t(0));
29-
fused_linear_ = ColumnParallelLinear(in_features,
30-
out_features,
31-
bias,
32-
gather_output,
33-
quant_args,
34-
parallel_args,
35-
options);
26+
fused_linear_ = register_module("fused_linear",
27+
ColumnParallelLinear(in_features,
28+
out_features_vec,
29+
prefixes,
30+
bias,
31+
gather_output,
32+
quant_args,
33+
parallel_args,
34+
options),
35+
/*selector=*/nullptr);
36+
// TODO: clean up the following code for calculating split sizes
3637
// calculate split sizes
3738
split_sizes_.reserve(out_features_vec.size());
3839
const auto world_size = parallel_args.world_size();
@@ -45,14 +46,22 @@ FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl(
4546
} else {
4647
// non-fused linear layers
4748
parallel_linears_.reserve(out_features_vec.size());
48-
for (const auto& out_features : out_features_vec) {
49-
parallel_linears_.emplace_back(in_features,
50-
out_features,
51-
bias,
52-
gather_output,
53-
quant_args,
54-
parallel_args,
55-
options);
49+
for (size_t i = 0; i < out_features_vec.size(); ++i) {
50+
const auto& prefix = prefixes[i];
51+
const auto out_features = out_features_vec[i];
52+
53+
const auto linear = register_module("linear",
54+
ColumnParallelLinear(in_features,
55+
out_features,
56+
bias,
57+
gather_output,
58+
quant_args,
59+
parallel_args,
60+
options,
61+
prefix),
62+
/*selector=*/nullptr);
63+
64+
parallel_linears_.emplace_back(linear);
5665
}
5766
}
5867
}
@@ -73,30 +82,4 @@ std::vector<torch::Tensor> FusedColumnParallelLinearImpl::forward(
7382
}
7483
return outputs;
7584
}
76-
77-
size_t FusedColumnParallelLinearImpl::load(const StateDict& state_dict,
78-
const std::string&) {
79-
if (fused_) {
80-
fused_linear_->load_state_dict(state_dict, prefixes_);
81-
} else {
82-
CHECK_EQ(parallel_linears_.size(), prefixes_.size());
83-
for (size_t i = 0; i < parallel_linears_.size(); ++i) {
84-
parallel_linears_[i]->load_state_dict(state_dict.select(prefixes_[i]));
85-
}
86-
}
87-
return 0;
88-
}
89-
90-
bool FusedColumnParallelLinearImpl::verify(
91-
const std::string& name_prefix) const {
92-
if (fused_) {
93-
fused_linear_->verify_loaded_weights(name_prefix);
94-
} else {
95-
for (const auto& parallel_linear : parallel_linears_) {
96-
parallel_linear->verify_loaded_weights(name_prefix);
97-
}
98-
}
99-
return true;
100-
}
101-
10285
} // namespace llm

src/layers/fused_linear.h

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,6 @@ class FusedColumnParallelLinearImpl : public Module {
2525

2626
std::vector<torch::Tensor> forward(torch::Tensor input);
2727

28-
// load weights from the checkpoint, override this method if necessary
29-
// returns the number of loaded parameters
30-
size_t load(const StateDict& state_dict,
31-
const std::string& name_prefix = std::string()) override;
32-
33-
// verify whether the weights are loaded, override this method if necessary
34-
bool verify(const std::string& name_prefix = std::string()) const override;
35-
3628
// whether the linear layer is fused
3729
bool fused() const { return fused_; }
3830

@@ -43,11 +35,9 @@ class FusedColumnParallelLinearImpl : public Module {
4335
// fused linear layer
4436
ColumnParallelLinear fused_linear_{nullptr};
4537

46-
// sizes for each split
38+
// size for each split
4739
std::vector<int64_t> split_sizes_;
4840

49-
std::vector<std::string> prefixes_;
50-
5141
// whether the linear layer is fused
5242
bool fused_ = false;
5343
};

src/layers/linear.cpp

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,6 @@ namespace {
3838
parallel_args, \
3939
options);
4040

41-
#define MAKE_ROW_PARALLEL_LINEAR(LinearlImplClass) \
42-
std::make_shared<LinearlImplClass>(in_features, \
43-
out_features, \
44-
bias, \
45-
input_is_parallelized, \
46-
parallel_args, \
47-
options);
48-
49-
#define MAKE_COLUMN_PARALLEL_LINEAR(LinearlImplClass) \
50-
std::make_shared<LinearlImplClass>( \
51-
in_features, out_features, bias, gather_output, parallel_args, options);
52-
5341
std::shared_ptr<ParallelLinearImpl> create_column_parallel_qlinear_by_impl(
5442
int64_t in_features,
5543
int64_t out_features,
@@ -139,6 +127,7 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_qlinear(
139127
}
140128
// unsupported quant method
141129
LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method();
130+
return nullptr;
142131
}
143132

144133
std::shared_ptr<ParallelLinearImpl> create_row_parallel_qlinear(
@@ -170,6 +159,7 @@ std::shared_ptr<ParallelLinearImpl> create_row_parallel_qlinear(
170159
}
171160
// unsupported quant method
172161
LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method();
162+
return nullptr;
173163
}
174164

175165
std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
@@ -179,7 +169,8 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
179169
bool gather_output,
180170
const QuantArgs& quant_args,
181171
const ParallelArgs& parallel_args,
182-
const torch::TensorOptions& options) {
172+
const torch::TensorOptions& options,
173+
const std::string& prefix) {
183174
if (!quant_args.quant_method().empty()) {
184175
return create_column_parallel_qlinear(in_features,
185176
out_features,
@@ -189,7 +180,40 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
189180
parallel_args,
190181
options);
191182
}
192-
return MAKE_COLUMN_PARALLEL_LINEAR(ColumnParallelLinearImpl);
183+
return std ::make_shared<ColumnParallelLinearImpl>(in_features,
184+
out_features,
185+
bias,
186+
gather_output,
187+
parallel_args,
188+
options,
189+
prefix);
190+
}
191+
192+
std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
193+
int64_t in_features,
194+
const std::vector<int64_t>& out_features,
195+
const std::vector<std::string>& prefixes,
196+
bool bias,
197+
bool gather_output,
198+
const QuantArgs& quant_args,
199+
const ParallelArgs& parallel_args,
200+
const torch::TensorOptions& options) {
201+
// if (!quant_args.quant_method().empty()) {
202+
// return create_column_parallel_qlinear(in_features,
203+
// out_features,
204+
// bias,
205+
// gather_output,
206+
// quant_args,
207+
// parallel_args,
208+
// options);
209+
// }
210+
return std ::make_shared<FColumnParallelLinearImpl>(in_features,
211+
out_features,
212+
prefixes,
213+
bias,
214+
gather_output,
215+
parallel_args,
216+
options);
193217
}
194218

195219
std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear(
@@ -209,7 +233,13 @@ std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear(
209233
parallel_args,
210234
options);
211235
}
212-
return MAKE_ROW_PARALLEL_LINEAR(RowParallelLinearImpl);
236+
return std ::make_shared<RowParallelLinearImpl>(in_features,
237+
out_features,
238+
bias,
239+
input_is_parallelized,
240+
parallel_args,
241+
options);
242+
;
213243
}
214244
} // namespace
215245

@@ -221,9 +251,29 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features,
221251
bool gather_output,
222252
const QuantArgs& quant_args,
223253
const ParallelArgs& parallel_args,
224-
const torch::TensorOptions& options)
254+
const torch::TensorOptions& options,
255+
const std::string& prefix)
256+
: ModuleHolder(create_column_parallel_linear(in_features,
257+
out_features,
258+
bias,
259+
gather_output,
260+
quant_args,
261+
parallel_args,
262+
options,
263+
prefix)) {}
264+
265+
ColumnParallelLinear::ColumnParallelLinear(
266+
int64_t in_features,
267+
const std::vector<int64_t>& out_features,
268+
const std::vector<std::string>& prefixes,
269+
bool bias,
270+
bool gather_output,
271+
const QuantArgs& quant_args,
272+
const ParallelArgs& parallel_args,
273+
const torch::TensorOptions& options)
225274
: ModuleHolder(create_column_parallel_linear(in_features,
226275
out_features,
276+
prefixes,
227277
bias,
228278
gather_output,
229279
quant_args,
@@ -242,7 +292,8 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features,
242292
gather_output,
243293
{}, /*quant_args*/
244294
parallel_args,
245-
options)) {}
295+
options,
296+
"")) {}
246297

247298
// construct a rotary positional embedding.
248299
// choose the right implementation based on the args.

src/layers/linear.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,14 @@ class ParallelLinearImpl : public Module {
2222

2323
virtual torch::Tensor forward(torch::Tensor input) = 0;
2424

25-
virtual void load_state_dict(const StateDict& state_dict) = 0;
25+
// TODO: clean up the interface of load_state_dict
26+
virtual void load_state_dict(const StateDict& state_dict) {
27+
LOG(FATAL) << "not implemented";
28+
}
2629

27-
virtual void verify_loaded_weights(const std::string& prefix = "") const = 0;
30+
virtual void verify_loaded_weights(const std::string& prefix = "") const {
31+
LOG(FATAL) << "not implemented";
32+
}
2833

2934
// special load_state_dict for fused cases
3035
virtual void load_state_dict(const StateDict& /*state_dict*/,
@@ -46,6 +51,16 @@ class ColumnParallelLinear : public ModuleHolder<ParallelLinearImpl> {
4651
bool gather_output,
4752
const QuantArgs& quant_args,
4853
const ParallelArgs& parallel_args,
54+
const torch::TensorOptions& options,
55+
const std::string& prefix = "");
56+
57+
ColumnParallelLinear(int64_t in_features,
58+
const std::vector<int64_t>& out_features,
59+
const std::vector<std::string>& prefixes,
60+
bool bias,
61+
bool gather_output,
62+
const QuantArgs& quant_args,
63+
const ParallelArgs& parallel_args,
4964
const torch::TensorOptions& options);
5065

5166
ColumnParallelLinear(int64_t in_features,

0 commit comments

Comments
 (0)