
Commit 4ffe8db

fix qwen and qwen2
1 parent 045658c commit 4ffe8db

5 files changed: +43 -175 lines changed


src/layers/linear_impl.cpp

Lines changed: 6 additions & 2 deletions
@@ -35,8 +35,12 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(
       torch::empty({out_features_per_partition, in_features}, options));

   if (bias) {
-    bias_ = register_parameter(
-        "bias", torch::empty({out_features_per_partition}, options));
+    bias_ = register_sharded_parameter(
+        "bias",
+        /*dim=*/0,
+        rank,
+        world_size,
+        torch::empty({out_features_per_partition}, options));
   }
 }
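Note on the bias change above: the column-parallel layer shards its weight along the output dimension, so the bias has to be split the same way. Registering it with the split dimension, rank, and world size gives the loader enough information to take each rank's slice of the full checkpoint tensor. A minimal sketch of that slicing, using a hypothetical shard_for_rank helper rather than the repository's actual register_sharded_parameter API:

    #include <torch/torch.h>

    // Hypothetical helper (illustration only): the slice of a full checkpoint
    // tensor that a given tensor-parallel rank owns when the tensor is split
    // evenly along `dim`.
    torch::Tensor shard_for_rank(const torch::Tensor& full,
                                 int64_t dim,
                                 int64_t rank,
                                 int64_t world_size) {
      const int64_t shard_size = full.size(dim) / world_size;
      return full.narrow(dim, rank * shard_size, shard_size).contiguous();
    }

    // e.g. a full [out_features] bias split along dim 0 across 2 ranks:
    //   shard_for_rank(bias, /*dim=*/0, /*rank=*/1, /*world_size=*/2)
    // returns the second half, matching out_features_per_partition.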

src/models/llama.h

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@
 #include "layers/embedding.h"
 #include "layers/fused_linear.h"
 #include "layers/linear.h"
-#include "layers/linear_impl.h"
 #include "layers/normalization.h"
 #include "layers/qkv_linear.h"
 #include "memory/kv_cache.h"

src/models/models.h

Lines changed: 2 additions & 2 deletions
@@ -15,5 +15,5 @@
 // #include "mistral.h"  // IWYU pragma: keep
 // #include "mpt.h"  // IWYU pragma: keep
 // #include "phi.h"  // IWYU pragma: keep
-// #include "qwen.h"  // IWYU pragma: keep
-// #include "qwen2.h"  // IWYU pragma: keep
+#include "qwen.h"  // IWYU pragma: keep
+#include "qwen2.h"  // IWYU pragma: keep

src/models/qwen.h

Lines changed: 18 additions & 85 deletions
@@ -10,6 +10,7 @@
 #include "layers/attention/attention.h"
 #include "layers/attention/handler.h"
 #include "layers/embedding.h"
+#include "layers/fused_linear.h"
 #include "layers/linear.h"
 #include "layers/normalization.h"
 #include "memory/kv_cache.h"
@@ -20,7 +21,7 @@
 #include "module/module_holder.h"
 #include "module/module_list.h"
 // QWen model compatible with huggingface weights
-// adopted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
+// Adapted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
 namespace llm::hf {

 class QWenMLPImpl : public Module {
@@ -38,14 +39,18 @@ class QWenMLPImpl : public Module {
     const int64_t intermediate_size = args.intermediate_size() / 2;

     // register the weight parameter
-    w1_w2_proj_ = register_module("gate_up_proj",
-                                  ColumnParallelLinear(hidden_size,
-                                                       intermediate_size * 2,
-                                                       /*bias=*/false,
-                                                       /*gather_output=*/false,
-                                                       quant_args,
-                                                       parallel_args,
-                                                       options));
+    gate_up_proj_ = register_module(
+        "gate_up_proj",
+        FusedColumnParallelLinear(
+            hidden_size,
+            std::vector<int64_t>{intermediate_size, intermediate_size},
+            std::vector<std::string>{"w1.", "w2."},
+            /*bias=*/false,
+            /*gather_output=*/false,
+            quant_args,
+            parallel_args,
+            options),
+        /*selector=*/nullptr);
     c_proj_ = register_module("c_proj",
                               RowParallelLinear(intermediate_size,
                                                 hidden_size,
@@ -57,26 +62,13 @@ class QWenMLPImpl : public Module {
   }

   torch::Tensor forward(torch::Tensor x) {
-    auto gate_up_proj = w1_w2_proj_(x);
-    auto chunks = gate_up_proj.chunk(/*chunks=*/2, /*dim=*/-1);
-    return c_proj_(chunks[0] * act_(chunks[1]));
-  }
-
-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    // call each submodule's load_state_dict function
-    w1_w2_proj_->load_state_dict(state_dict, {"w1.", "w2."});
-    c_proj_->load_state_dict(state_dict.select("c_proj."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    w1_w2_proj_->verify_loaded_weights(prefix + "[w1,w2].");
-    c_proj_->verify_loaded_weights(prefix + "c_proj.");
+    const auto gate_up = gate_up_proj_(x);
+    return c_proj_(gate_up[0] * act_(gate_up[1]));
   }

  private:
   // parameter members, must be registered
-  ColumnParallelLinear w1_w2_proj_{nullptr};
+  FusedColumnParallelLinear gate_up_proj_{nullptr};
   RowParallelLinear c_proj_{nullptr};

   ActFunc act_{nullptr};
@@ -133,18 +125,6 @@ class QWenAttentionImpl : public Module {
     return c_proj_(output);
   }

-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    // call each submodule's load_state_dict function
-    c_attn_->load_state_dict(state_dict.select("c_attn."));
-    c_proj_->load_state_dict(state_dict.select("c_proj."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    c_attn_->verify_loaded_weights(prefix + "c_attn.");
-    c_proj_->verify_loaded_weights(prefix + "c_proj.");
-  }
-
  private:
   // parameter members, must be registered
   ColumnParallelLinear c_attn_{nullptr};
@@ -183,22 +163,6 @@ class QWenBlockImpl : public Module {
     return h + mlp_(ln_2_(h));
   }

-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    // call each submodule's load_state_dict function
-    attn_->load_state_dict(state_dict.select("attn."));
-    mlp_->load_state_dict(state_dict.select("mlp."));
-    ln_1_->load_state_dict(state_dict.select("ln_1."));
-    ln_2_->load_state_dict(state_dict.select("ln_2."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    attn_->verify_loaded_weights(prefix + "attn.");
-    mlp_->verify_loaded_weights(prefix + "mlp.");
-    ln_1_->verify_loaded_weights(prefix + "ln_1.");
-    ln_2_->verify_loaded_weights(prefix + "ln_2.");
-  }
-
  private:
   // parameter members, must be registered
   QWenAttention attn_{nullptr};
@@ -226,7 +190,7 @@ class QWenModelImpl : public Module {
     handler_ = AttentionHandler::create_handler_with_rope(
         args, /*interleaved=*/false, options);

-    blocks_ = register_module("layers", ModuleList());
+    blocks_ = register_module("h", ModuleList());
     layers_.reserve(args.n_layers());
     for (int32_t i = 0; i < args.n_layers(); i++) {
       auto block =
@@ -254,26 +218,6 @@ class QWenModelImpl : public Module {
     return ln_f_(h);
   }

-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    wte_->load_state_dict(state_dict.select("wte."));
-    // call each layer's load_state_dict function
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->load_state_dict(
-          state_dict.select("h." + std::to_string(i) + "."));
-    }
-    ln_f_->load_state_dict(state_dict.select("ln_f."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    wte_->verify_loaded_weights(prefix + "wte.");
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->verify_loaded_weights(prefix + "h." + std::to_string(i) +
-                                        ".");
-    }
-    ln_f_->verify_loaded_weights(prefix + "ln_f.");
-  }
-
  private:
   // parameter members, must be registered
   ParallelEmbedding wte_{nullptr};
@@ -331,17 +275,6 @@ class QWenForCausalLMImpl : public Module {
     return lm_head_(h);
   }

-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    transformer_->load_state_dict(state_dict.select("transformer."));
-    lm_head_->load_state_dict(state_dict.select("lm_head."));
-  }
-
-  void verify_loaded_weights() const {
-    transformer_->verify_loaded_weights("transformer.");
-    lm_head_->verify_loaded_weights("lm_head.");
-  }
-
  private:
   // parameter members, must be registered
   QWenModel transformer_{nullptr};
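Note on the QWenMLP change: FusedColumnParallelLinear takes the per-projection sizes and the checkpoint prefixes ("w1.", "w2.") in its constructor and returns the two projections already split, so forward no longer chunks a concatenated output and the hand-written load_state_dict / verify_loaded_weights overrides go away. For a single rank the two formulations are numerically identical, as this self-contained sketch shows (silu stands in for whatever activation act_ resolves to; the tensors are placeholders, not the model's weights):

    #include <iostream>
    #include <torch/torch.h>

    int main() {
      // Stand-ins for the outputs of the two separate projections (w1 and w2).
      auto w1_out = torch::randn({4, 8});
      auto w2_out = torch::randn({4, 8});

      // Old path: one concatenated output, then chunk it back into two halves.
      auto fused = torch::cat({w1_out, w2_out}, /*dim=*/-1);
      auto chunks = fused.chunk(/*chunks=*/2, /*dim=*/-1);
      auto old_style = chunks[0] * torch::silu(chunks[1]);

      // New path: keep the two projection outputs separate from the start.
      auto new_style = w1_out * torch::silu(w2_out);

      std::cout << torch::allclose(old_style, new_style) << "\n";  // prints 1
      return 0;
    }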

src/models/qwen2.h

Lines changed: 17 additions & 85 deletions
@@ -43,11 +43,13 @@ class QWen2MLPImpl : public Module {
         FusedColumnParallelLinear(
             hidden_size,
             std::vector<int64_t>{intermediate_size, intermediate_size},
+            std::vector<std::string>{"gate_proj.", "up_proj."},
             /*bias=*/false,
             /*gather_output=*/false,
             quant_args,
             parallel_args,
-            options));
+            options),
+        /*selector=*/nullptr);
     down_proj_ =
         register_module("down_proj",
                         RowParallelLinear(intermediate_size,
@@ -64,18 +66,6 @@ class QWen2MLPImpl : public Module {
     return down_proj_(act_func_(gate_up[0]) * gate_up[1]);
   }

-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    // call each submodule's load_state_dict function
-    gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."});
-    down_proj_->load_state_dict(state_dict.select("down_proj."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj].");
-    down_proj_->verify_loaded_weights(prefix + "down_proj.");
-  }
-
  private:
   // parameter members, must be registered
   FusedColumnParallelLinear gate_up_proj_{nullptr};
@@ -104,16 +94,20 @@ class QWen2AttentionImpl : public Module {
         std::max<int64_t>(1, n_kv_heads / world_size);

     // register submodules
-    qkv_proj_ = register_module("qkv_proj",
-                                QKVColumnParallelLinear(hidden_size,
-                                                        n_heads,
-                                                        n_kv_heads,
-                                                        head_dim,
-                                                        /*bias=*/true,
-                                                        /*gather_output=*/false,
-                                                        quant_args,
-                                                        parallel_args,
-                                                        options));
+    qkv_proj_ = register_module(
+        "qkv_proj",
+        QKVColumnParallelLinear(
+            hidden_size,
+            n_heads,
+            n_kv_heads,
+            head_dim,
+            std::vector<std::string>{"q_proj.", "k_proj.", "v_proj."},
+            /*bias=*/true,
+            /*gather_output=*/false,
+            quant_args,
+            parallel_args,
+            options),
+        /*selector=*/nullptr);

     o_proj_ = register_module("o_proj",
                               RowParallelLinear(hidden_size,
@@ -146,19 +140,6 @@ class QWen2AttentionImpl : public Module {
     return o_proj_(output);
   }

-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    // call each submodule's load_state_dict function
-    qkv_proj_->load_state_dict(
-        state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."});
-    o_proj_->load_state_dict(state_dict.select("o_proj."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj].");
-    o_proj_->verify_loaded_weights(prefix + "o_proj.");
-  }
-
  private:
   // parameter members, must be registered
   QKVColumnParallelLinear qkv_proj_{nullptr};
@@ -208,24 +189,6 @@ class QWen2DecoderLayerImpl : public Module {
     return hidden_states;
   }

-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    // call each submodule's load_state_dict function
-    self_attn_->load_state_dict(state_dict.select("self_attn."));
-    mlp_->load_state_dict(state_dict.select("mlp."));
-    input_layernorm_->load_state_dict(state_dict.select("input_layernorm."));
-    post_attention_layernorm_->load_state_dict(
-        state_dict.select("post_attention_layernorm."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    self_attn_->verify_loaded_weights(prefix + "self_attn.");
-    mlp_->verify_loaded_weights(prefix + "mlp.");
-    input_layernorm_->verify_loaded_weights(prefix + "input_layernorm.");
-    post_attention_layernorm_->verify_loaded_weights(
-        prefix + "post_attention_layernorm.");
-  }
-
  private:
   // parameter members, must be registered
   QWen2Attention self_attn_{nullptr};
@@ -291,26 +254,6 @@ class QWen2ModelImpl : public Module {
     return norm_(h, residual);
   }

-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    embed_tokens_->load_state_dict(state_dict.select("embed_tokens."));
-    // call each layer's load_state_dict function
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->load_state_dict(
-          state_dict.select("layers." + std::to_string(i) + "."));
-    }
-    norm_->load_state_dict(state_dict.select("norm."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
-                                        ".");
-    }
-    norm_->verify_loaded_weights(prefix + "norm.");
-  }
-
  private:
   // parameter members, must be registered
   ParallelEmbedding embed_tokens_{nullptr};
@@ -368,17 +311,6 @@ class QWen2ForCausalLMImpl : public Module {
     return lm_head_(h);
   }

-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    model_->load_state_dict(state_dict.select("model."));
-    lm_head_->load_state_dict(state_dict.select("lm_head."));
-  }
-
-  void verify_loaded_weights() const {
-    model_->verify_loaded_weights("model.");
-    lm_head_->verify_loaded_weights("lm_head.");
-  }
-
  private:
   // parameter members, must be registered
   QWen2Model model_{nullptr};
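qwen2.h follows the same pattern as qwen.h: the checkpoint prefixes ("gate_proj."/"up_proj.", "q_proj."/"k_proj."/"v_proj.") move into the fused and QKV linear constructors, and every hand-written load_state_dict / verify_loaded_weights override is deleted. That only works if registered module names line up with checkpoint keys, which is presumably also why qwen.h renames the block container from "layers" to "h" (the old loader selected "h." + i, matching QWen checkpoint keys such as transformer.h.0.attn.c_attn.weight). A rough sketch of the kind of generic, name-driven loading this enables, written against torch::nn::Module purely for illustration (the repository's own Module and StateDict types will differ):

    #include <map>
    #include <string>
    #include <torch/torch.h>

    // Illustrative only: once every submodule is registered under the name
    // used in the checkpoint, a single recursive walk over named parameters
    // can replace per-class load_state_dict overrides.
    void load_by_registered_names(
        torch::nn::Module& root,
        const std::map<std::string, torch::Tensor>& checkpoint) {
      torch::NoGradGuard no_grad;
      for (auto& item : root.named_parameters(/*recurse=*/true)) {
        // e.g. item.key() == "transformer.h.0.ln_1.weight"
        auto it = checkpoint.find(item.key());
        if (it != checkpoint.end()) {
          item.value().copy_(it->second);
        }
      }
    }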
