Skip to content

Commit 36449ad

Browse files
committed
rename.
1 parent 31afc66 commit 36449ad

22 files changed

+309
-289
lines changed

src/layers/CMakeLists.txt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ cc_library(
66
linear
77
HDRS
88
linear.h
9-
qkv_linear.h
10-
linear_impl.h
11-
fused_linear.h
9+
qkv_parallel_linear.h
10+
parallel_linear.h
11+
multi_parallel_linear.h
1212
weight_utils.h
1313
SRCS
1414
linear.cpp
15-
qkv_linear.cpp
16-
linear_impl.cpp
17-
fused_linear.cpp
15+
qkv_parallel_linear.cpp
16+
parallel_linear.cpp
17+
multi_parallel_linear.cpp
1818
weight_utils.cpp
1919
DEPS
2020
:state_dict
@@ -74,7 +74,7 @@ cc_test(
7474
pos_embedding_test.cpp
7575
normalization_test.cpp
7676
linear_test.cpp
77-
qkv_linear_test.cpp
77+
qkv_parallel_linear_test.cpp
7878
DEPS
7979
:layers
8080
:state_dict

src/layers/fused_linear.cpp

Lines changed: 0 additions & 85 deletions
This file was deleted.

src/layers/linear.cpp

Lines changed: 32 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <boost/algorithm/string.hpp>
77
#include <memory>
88

9-
#include "linear_impl.h"
9+
#include "parallel_linear.h"
1010
#include "quantization/qlinear_awq_impl.h"
1111
#include "quantization/qlinear_awq_marlin_impl.h"
1212
#include "quantization/qlinear_exllamav2_impl.h"
@@ -189,33 +189,6 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
189189
prefix);
190190
}
191191

192-
std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
193-
int64_t in_features,
194-
const std::vector<int64_t>& out_features,
195-
const std::vector<std::string>& prefixes,
196-
bool bias,
197-
bool gather_output,
198-
const QuantArgs& quant_args,
199-
const ParallelArgs& parallel_args,
200-
const torch::TensorOptions& options) {
201-
// if (!quant_args.quant_method().empty()) {
202-
// return create_column_parallel_qlinear(in_features,
203-
// out_features,
204-
// bias,
205-
// gather_output,
206-
// quant_args,
207-
// parallel_args,
208-
// options);
209-
// }
210-
return std ::make_shared<FColumnParallelLinearImpl>(in_features,
211-
out_features,
212-
prefixes,
213-
bias,
214-
gather_output,
215-
parallel_args,
216-
options);
217-
}
218-
219192
std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear(
220193
int64_t in_features,
221194
int64_t out_features,
@@ -239,8 +212,38 @@ std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear(
239212
input_is_parallelized,
240213
parallel_args,
241214
options);
242-
;
243215
}
216+
217+
// std::shared_ptr<MultiParallelLinearImpl> create_multi_column_parallel_linear(
218+
// int64_t in_features,
219+
// const std::vector<int64_t>& out_features,
220+
// const std::vector<std::string>& prefixes,
221+
// bool bias,
222+
// bool gather_output,
223+
// const QuantArgs& quant_args,
224+
// const ParallelArgs& parallel_args,
225+
// const torch::TensorOptions& options) {
226+
// // check if the linear layers can be fused
227+
// const bool fused = quant_args.can_be_fused();
228+
// std::shared_ptr<MultiParallelLinearImpl> impl;
229+
// if (fused) {
230+
// return std::make_shared<FusedColumnParallelLinearImpl>(in_features,
231+
// out_features,
232+
// prefixes,
233+
// bias,
234+
// gather_output,
235+
// parallel_args,
236+
// options);
237+
// }
238+
239+
// return std::make_shared<GroupedColumnParallelLinearImpl>(in_features,
240+
// out_features,
241+
// prefixes,
242+
// bias,
243+
// gather_output,
244+
// parallel_args,
245+
// options);
246+
// }
244247
} // namespace
245248

246249
// construct a ColumnParallelLinear.
@@ -262,24 +265,6 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features,
262265
options,
263266
prefix)) {}
264267

265-
ColumnParallelLinear::ColumnParallelLinear(
266-
int64_t in_features,
267-
const std::vector<int64_t>& out_features,
268-
const std::vector<std::string>& prefixes,
269-
bool bias,
270-
bool gather_output,
271-
const QuantArgs& quant_args,
272-
const ParallelArgs& parallel_args,
273-
const torch::TensorOptions& options)
274-
: ModuleHolder(create_column_parallel_linear(in_features,
275-
out_features,
276-
prefixes,
277-
bias,
278-
gather_output,
279-
quant_args,
280-
parallel_args,
281-
options)) {}
282-
283268
ColumnParallelLinear::ColumnParallelLinear(int64_t in_features,
284269
int64_t out_features,
285270
bool bias,

src/layers/linear.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@ class ParallelLinearImpl : public Module {
3737
LOG(FATAL) << "not implemented";
3838
}
3939
};
40+
LLM_MODULE(ParallelLinear);
41+
42+
class MultiParallelLinearImpl : public Module {
43+
public:
44+
~MultiParallelLinearImpl() override = default;
45+
46+
virtual std::vector<torch::Tensor> forward(torch::Tensor input) = 0;
47+
};
48+
LLM_MODULE(MultiParallelLinear);
4049

4150
class ColumnParallelLinear : public ModuleHolder<ParallelLinearImpl> {
4251
public:
@@ -54,15 +63,6 @@ class ColumnParallelLinear : public ModuleHolder<ParallelLinearImpl> {
5463
const torch::TensorOptions& options,
5564
const std::string& prefix = "");
5665

57-
ColumnParallelLinear(int64_t in_features,
58-
const std::vector<int64_t>& out_features,
59-
const std::vector<std::string>& prefixes,
60-
bool bias,
61-
bool gather_output,
62-
const QuantArgs& quant_args,
63-
const ParallelArgs& parallel_args,
64-
const torch::TensorOptions& options);
65-
6666
ColumnParallelLinear(int64_t in_features,
6767
int64_t out_features,
6868
bool bias,

0 commit comments

Comments
 (0)