Commit 2972034

refactor: clean up legacy load_state_dict for linear layers
1 parent 94e04bc commit 2972034

File tree

9 files changed: +19 -38 lines

src/layers/linear.h

Lines changed: 0 additions & 6 deletions

@@ -26,12 +26,6 @@ class ParallelLinearImpl : public Module {
 
   virtual void verify_loaded_weights(const std::string& prefix = "") const = 0;
 
-  // load state dict with a transform function
-  virtual void load_state_dict(const StateDict& /*state_dict*/,
-                               TensorTransform /*transform_func*/) {
-    LOG(FATAL) << "not implemented";
-  }
-
   // special load_state_dict for fused cases
   virtual void load_state_dict(const StateDict& /*state_dict*/,
                                const std::vector<std::string>& /*prefixes*/) {
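The overload that survives this cleanup is the fused-case one, which takes a list of sub-module prefixes so that separate query/key/value checkpoint tensors can be loaded into a single fused parallel weight (see the LoadFusedWeight test below). A minimal call sketch; the prefix strings are illustrative placeholders, not taken from this diff:

  // Sketch only: load q/k/v checkpoint tensors into one fused QKV weight.
  // The prefix strings are placeholders for the model's checkpoint layout.
  qkv_proj->load_state_dict(state_dict, {"q_proj.", "k_proj.", "v_proj."});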

src/layers/linear_impl.cpp

Lines changed: 2 additions & 10 deletions

@@ -55,23 +55,15 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) {
 
 // load the weight from the checkpoint
 void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
-  // call load_state_dict with identity transform
-  load_state_dict(state_dict,
-                  [](const torch::Tensor& tensor) { return tensor; });
-}
-
-void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict,
-                                               TensorTransform transform_func) {
-  CHECK(transform_func != nullptr) << "transform_func must be provided";
   const auto rank = parallel_args_.rank();
   const auto world_size = parallel_args_.world_size();
 
   // load sharded weights on dim 0
-  LOAD_SHARDED_WEIGHT_WITH_TRANSFORM(weight, 0);
+  LOAD_SHARDED_WEIGHT(weight, 0);
 
   if (bias_.defined()) {
     // load sharded bias on dim 0
-    LOAD_SHARDED_WEIGHT_WITH_TRANSFORM(bias, 0);
+    LOAD_SHARDED_WEIGHT(bias, 0);
   }
 }
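With the transform-based overload gone, the non-fused path is just the plain load_state_dict(const StateDict&) call shown above. A minimal usage sketch; load_checkpoint is a hypothetical helper standing in for however the project builds a StateDict from disk:

  // Sketch: populate a ColumnParallelLinear layer from a checkpoint and
  // verify nothing was left unloaded. `load_checkpoint` is hypothetical.
  const StateDict state_dict = load_checkpoint("model.safetensors");
  linear->load_state_dict(state_dict);
  linear->verify_loaded_weights();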

src/layers/linear_impl.h

Lines changed: 0 additions & 4 deletions

@@ -26,10 +26,6 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl {
   // load the weight from the checkpoint
   void load_state_dict(const StateDict& state_dict) override;
 
-  // load state dict with a transform function
-  void load_state_dict(const StateDict& state_dict,
-                       TensorTransform transform_func) override;
-
   // special load_state_dict for fused cases
   void load_state_dict(const StateDict& state_dict,
                        const std::vector<std::string>& prefixes) override;

src/layers/qkv_linear.h

Lines changed: 5 additions & 2 deletions

@@ -27,8 +27,11 @@ class QKVColumnParallelLinearImpl : public Module {
                               const ParallelArgs& parallel_args,
                               const torch::TensorOptions& options);
 
-  std::vector<torch::Tensor> forward(torch::Tensor input) {
-    return parallel_linear_->forward(input);
+  // returns (query, key, value)
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> forward(
+      torch::Tensor input) {
+    const auto qkv = parallel_linear_->forward(input);
+    return {qkv[0], qkv[1], qkv[2]};
   }
 
  private:
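Because forward now returns a std::tuple rather than a std::vector, call sites can unpack the three projections with structured bindings, which is exactly what the test and model files below switch to. The input tensor name here is an arbitrary placeholder:

  // Unpack the fused QKV projection directly at the call site.
  const auto [q, k, v] = qkv_proj_(hidden_states);  // query, key, value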

src/layers/qkv_linear_test.cpp

Lines changed: 4 additions & 4 deletions

@@ -67,19 +67,19 @@ TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) {
 
   // generate random input and compare with the output
   auto input = torch::randn({n_tokens, hidden_size}, options);
-  auto qkv = linear.forward(input);
+  const auto [q, k, v] = linear.forward(input);
 
   const int64_t kv_shard_id =
       n_kv_heads >= n_shards ? shard_id : n_kv_heads * shard_id / n_shards;
 
   auto query = input.matmul(query_chunks[shard_id].t());
-  EXPECT_TRUE(torch::allclose(qkv[0], query, /*rtol=*/1e-5, /*atol=*/1e-5));
+  EXPECT_TRUE(torch::allclose(q, query, /*rtol=*/1e-5, /*atol=*/1e-5));
 
   auto key = input.matmul(key_chunks[kv_shard_id].t());
-  EXPECT_TRUE(torch::allclose(qkv[1], key, /*rtol=*/1e-5, /*atol=*/1e-5));
+  EXPECT_TRUE(torch::allclose(k, key, /*rtol=*/1e-5, /*atol=*/1e-5));
 
   auto value = input.matmul(value_chunks[kv_shard_id].t());
-  EXPECT_TRUE(torch::allclose(qkv[2], value, /*rtol=*/1e-5, /*atol=*/1e-5));
+  EXPECT_TRUE(torch::allclose(v, value, /*rtol=*/1e-5, /*atol=*/1e-5));
 }
 }

src/models/alibaba/qwen2.h

Lines changed: 2 additions & 3 deletions

@@ -133,10 +133,9 @@ class QWen2AttentionImpl : public Module {
                         const InputParameters& input_params) {
     // (num_tokens, dim) x (dim, n_local_heads * head_dim)
     // => (num_tokens, n_local_heads * head_dim)
-    const auto qkv = qkv_proj_(x);
+    const auto [q, k, v] = qkv_proj_(x);
     // calculate attention, output: (num_tokens, n_local_heads * head_dim)
-    const auto output =
-        atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params);
+    const auto output = atten_(q, k, v, positions, kv_cache, input_params);
     return o_proj_(output);
   }

src/models/google/gemma.h

Lines changed: 2 additions & 3 deletions

@@ -128,11 +128,10 @@ class GemmaAttentionImpl : public Module {
                         const InputParameters& input_params) {
     // (num_tokens, dim) x (dim, n_local_heads * head_dim)
     // => (num_tokens, n_local_heads * head_dim)
-    const auto qkv = qkv_proj_(x);
+    const auto [q, k, v] = qkv_proj_(x);
     // calculate attention,
     // output: (num_tokens, n_local_heads*head_dim)
-    const auto output =
-        atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params);
+    const auto output = atten_(q, k, v, positions, kv_cache, input_params);
     return o_proj_(output);
   }

src/models/google/gemma2.h

Lines changed: 2 additions & 3 deletions

@@ -132,11 +132,10 @@ class Gemma2AttentionImpl : public Module {
                         const InputParameters& input_params) {
     // (num_tokens, dim) x (dim, n_local_heads * head_dim)
     // => (num_tokens, n_local_heads * head_dim)
-    const auto qkv = qkv_proj_(x);
+    const auto [q, k, v] = qkv_proj_(x);
     // calculate attention,
     // output: (num_tokens, n_local_heads*head_dim)
-    const auto output =
-        atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params);
+    const auto output = atten_(q, k, v, positions, kv_cache, input_params);
     return o_proj_(output);
   }

src/models/meta/llama.h

Lines changed: 2 additions & 3 deletions

@@ -127,10 +127,9 @@ class LlamaAttentionImpl : public Module {
                         const InputParameters& input_params) {
     // (num_tokens, dim) x (dim, n_local_heads * head_dim)
     // => (num_tokens, n_local_heads * head_dim)
-    const auto qkv = qkv_proj_(x);
+    const auto [q, k, v] = qkv_proj_(x);
     // calculate attention, output: (num_tokens, n_local_heads * head_dim)
-    const auto output =
-        atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params);
+    const auto output = atten_(q, k, v, positions, kv_cache, input_params);
     return o_proj_(output);
   }
