
Commit 045658c

fix llama
1 parent f3a9e81 commit 045658c

7 files changed: 32 additions & 52 deletions


src/layers/qkv_linear.cpp

Lines changed: 1 addition & 1 deletion
@@ -10,9 +10,9 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl(
     int64_t n_heads,
     int64_t n_kv_heads,
     int64_t head_dim,
+    const std::vector<std::string>& prefixes,
     bool bias,
     bool gather_output,
-    const std::vector<std::string>& prefixes,
     const QuantArgs& quant_args,
     const ParallelArgs& parallel_args,
     const torch::TensorOptions& options) {

src/layers/qkv_linear.h

Lines changed: 1 addition & 1 deletion
@@ -20,9 +20,9 @@ class QKVColumnParallelLinearImpl : public Module {
       int64_t n_heads,
       int64_t n_kv_heads,
       int64_t head_dim,
+      const std::vector<std::string>& prefixes,
       bool bias,
       bool gather_output,
-      const std::vector<std::string>& prefixes,
       const QuantArgs& quant_args,
       const ParallelArgs& parallel_args,
       const torch::TensorOptions& options);
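
For orientation, a hypothetical call site with the reordered signature is sketched below: prefixes now sits between head_dim and bias. The leading hidden_size argument and the trailing quant_args / parallel_args / options objects follow the usage shown in the model headers later in this commit; the concrete sizes are illustrative only.

// Hypothetical call site (not part of this commit), showing the new argument
// order from qkv_linear.h: prefixes now precedes bias and gather_output.
// The numeric sizes are illustrative; quant_args, parallel_args, and options
// are assumed to be in scope with the types declared above.
QKVColumnParallelLinear qkv_proj(
    /*hidden_size=*/4096,
    /*n_heads=*/32,
    /*n_kv_heads=*/8,
    /*head_dim=*/128,
    /*prefixes=*/std::vector<std::string>{"q_proj.", "k_proj.", "v_proj."},
    /*bias=*/false,
    /*gather_output=*/false,
    quant_args,
    parallel_args,
    options);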

src/layers/qkv_linear_test.cpp

Lines changed: 1 addition & 1 deletion
@@ -56,9 +56,9 @@ TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) {
      n_heads,
      n_kv_heads,
      head_dim,
+     std::vector<std::string>{"query.", "key.", "value."},
      /*bias=*/false,
      /*gather_output=*/false,
-     std::vector<std::string>{"query.", "key.", "value."},
      quant_args,
      parallel_args,
      options);

src/models/gemma.h

Lines changed: 1 addition & 1 deletion
@@ -100,9 +100,9 @@ class GemmaAttentionImpl : public Module {
          n_heads,
          n_kv_heads,
          head_dim,
+         std::vector<std::string>{"q_proj.", "k_proj.", "v_proj."},
          /*bias=*/false,
          /*gather_output=*/false,
-         std::vector<std::string>{"q_proj.", "k_proj.", "v_proj."},
          quant_args,
          parallel_args,
          options),

src/models/gemma2.h

Lines changed: 1 addition & 1 deletion
@@ -100,9 +100,9 @@ class Gemma2AttentionImpl : public Module {
          n_heads,
          n_kv_heads,
          head_dim,
+         std::vector<std::string>{"q_proj.", "k_proj.", "v_proj."},
          args.attn_bias(),
          /*gather_output=*/false,
-         std::vector<std::string>{"q_proj.", "k_proj.", "v_proj."},
          quant_args,
          parallel_args,
          options),

src/models/llama.h

Lines changed: 26 additions & 46 deletions
@@ -36,16 +36,18 @@ class LlamaMLPImpl : public Module {
     const int64_t intermediate_size = args.intermediate_size();
 
     // register the weight parameter
-    // gate_up_proj_ = register_module(
-    //     "gate_up_proj",
-    //     FusedColumnParallelLinear(
-    //         hidden_size,
-    //         std::vector<int64_t>{intermediate_size, intermediate_size},
-    //         /*bias=*/false,
-    //         /*gather_output=*/false,
-    //         quant_args,
-    //         parallel_args,
-    //         options));
+    gate_up_proj_ = register_module(
+        "gate_up_proj",
+        FusedColumnParallelLinear(
+            hidden_size,
+            std::vector<int64_t>{intermediate_size, intermediate_size},
+            std::vector<std::string>{"gate_proj.", "up_proj."},
+            /*bias=*/false,
+            /*gather_output=*/false,
+            quant_args,
+            parallel_args,
+            options),
+        /*selector=*/nullptr);
 
     down_proj_ =
         register_module("down_proj",
@@ -63,18 +65,6 @@ class LlamaMLPImpl : public Module {
     return down_proj_(act_func_(gate_up[0]) * gate_up[1]);
   }
 
-  // // load the weight from the checkpoint
-  // void load_state_dict(const StateDict& state_dict) {
-  //   // call each submodule's load_state_dict function
-  //   gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."});
-  //   down_proj_->load_state_dict(state_dict.select("down_proj."));
-  // }
-
-  // void verify_loaded_weights(const std::string& prefix) const {
-  //   gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj].");
-  //   down_proj_->verify_loaded_weights(prefix + "down_proj.");
-  // }
-
  private:
   // parameter members, must be registered
   FusedColumnParallelLinear gate_up_proj_{nullptr};
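
The forward path kept above, down_proj_(act_func_(gate_up[0]) * gate_up[1]), is the usual SwiGLU-style MLP computed from the fused gate/up projection. As a reference only, here is a minimal unfused sketch in plain libtorch, assuming the activation resolves to SiLU for llama-family configs:

// Reference-only sketch (not ScaleLLM code): the unfused equivalent of the
// fused gate_up_proj_ / down_proj_ path, with SiLU assumed as act_func_.
#include <torch/torch.h>

torch::Tensor llama_mlp_reference(const torch::Tensor& x,        // [n_tokens, hidden_size]
                                  const torch::Tensor& w_gate,   // [intermediate_size, hidden_size]
                                  const torch::Tensor& w_up,     // [intermediate_size, hidden_size]
                                  const torch::Tensor& w_down) { // [hidden_size, intermediate_size]
  // gate_up_proj_ fuses these two matmuls into one column-parallel GEMM and
  // returns the two halves as gate_up[0] (gate) and gate_up[1] (up).
  const auto gate = torch::linear(x, w_gate);
  const auto up = torch::linear(x, w_up);
  return torch::linear(torch::silu(gate) * up, w_down);
}
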
@@ -102,16 +92,20 @@ class LlamaAttentionImpl : public Module {
         std::max<int64_t>(1, n_kv_heads / world_size);
 
     // register submodules
-    // qkv_proj_ = register_module("qkv_proj",
-    //     QKVColumnParallelLinear(hidden_size,
-    //         n_heads,
-    //         n_kv_heads,
-    //         head_dim,
-    //         /*bias=*/false,
-    //         /*gather_output=*/false,
-    //         quant_args,
-    //         parallel_args,
-    //         options));
+    qkv_proj_ = register_module(
+        "qkv_proj",
+        QKVColumnParallelLinear(
+            hidden_size,
+            n_heads,
+            n_kv_heads,
+            head_dim,
+            std::vector<std::string>{"q_proj.", "k_proj.", "v_proj."},
+            /*bias=*/false,
+            /*gather_output=*/false,
+            quant_args,
+            parallel_args,
+            options),
+        /*selector=*/nullptr);
 
     o_proj_ = register_module("o_proj",
         RowParallelLinear(hidden_size,
@@ -141,20 +135,6 @@ class LlamaAttentionImpl : public Module {
     return o_proj_(output);
   }
 
-  // // load the weight from the checkpoint
-  // void load_state_dict(const StateDict& state_dict) {
-  //   // call each submodule's load_state_dict function
-  //   qkv_proj_->load_state_dict(
-  //       state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.",
-  //       "v_proj."});
-  //   o_proj_->load_state_dict(state_dict.select("o_proj."));
-  // }
-
-  // void verify_loaded_weights(const std::string& prefix) const {
-  //   qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj].");
-  //   o_proj_->verify_loaded_weights(prefix + "o_proj.");
-  // }
-
  private:
   // parameter members, must be registered
   QKVColumnParallelLinear qkv_proj_{nullptr};
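
The qkv_proj_ registered above is sized for grouped-query attention: the query uses n_heads heads while key and value each use n_kv_heads heads. Purely as an illustration of that bookkeeping (not how ScaleLLM splits the tensor internally):

// Illustration only: splitting a fused QKV activation into query/key/value
// slices, using the head counts from LlamaAttentionImpl above.
#include <torch/torch.h>
#include <vector>

std::vector<torch::Tensor> split_qkv(const torch::Tensor& qkv,  // [n_tokens, (n_heads + 2 * n_kv_heads) * head_dim]
                                     int64_t n_heads,
                                     int64_t n_kv_heads,
                                     int64_t head_dim) {
  // Query slice first, then key and value with the (smaller) kv head count.
  return qkv.split_with_sizes({n_heads * head_dim,
                               n_kv_heads * head_dim,
                               n_kv_heads * head_dim},
                              /*dim=*/-1);
}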

src/models/models.h

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 // #include "gpt_j.h"  // IWYU pragma: keep
 // #include "gpt_neox.h"  // IWYU pragma: keep
 // #include "internlm.h"  // IWYU pragma: keep
-// #include "llama.h"  // IWYU pragma: keep
+#include "llama.h"  // IWYU pragma: keep
 // #include "mistral.h"  // IWYU pragma: keep
 // #include "mpt.h"  // IWYU pragma: keep
 // #include "phi.h"  // IWYU pragma: keep
