@@ -10,6 +10,7 @@
 #include "layers/attention/attention.h"
 #include "layers/attention/handler.h"
 #include "layers/embedding.h"
+#include "layers/fused_linear.h"
 #include "layers/linear.h"
 #include "layers/normalization.h"
 #include "memory/kv_cache.h"
@@ -20,7 +21,7 @@
 #include "module/module_holder.h"
 #include "module/module_list.h"
 // QWen model compatible with huggingface weights
-// adopted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
+// Adapted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
 namespace llm::hf {
 
 class QWenMLPImpl : public Module {
@@ -38,14 +39,18 @@ class QWenMLPImpl : public Module {
     const int64_t intermediate_size = args.intermediate_size() / 2;
 
     // register the weight parameter
-    w1_w2_proj_ = register_module("gate_up_proj",
-                                  ColumnParallelLinear(hidden_size,
-                                                       intermediate_size * 2,
-                                                       /*bias=*/false,
-                                                       /*gather_output=*/false,
-                                                       quant_args,
-                                                       parallel_args,
-                                                       options));
+    gate_up_proj_ = register_module(
+        "gate_up_proj",
+        FusedColumnParallelLinear(
+            hidden_size,
+            std::vector<int64_t>{intermediate_size, intermediate_size},
+            std::vector<std::string>{"w1.", "w2."},
+            /*bias=*/false,
+            /*gather_output=*/false,
+            quant_args,
+            parallel_args,
+            options),
+        /*selector=*/nullptr);
     c_proj_ = register_module("c_proj",
                               RowParallelLinear(intermediate_size,
                                                 hidden_size,
@@ -57,26 +62,13 @@ class QWenMLPImpl : public Module {
   }
 
   torch::Tensor forward(torch::Tensor x) {
-    auto gate_up_proj = w1_w2_proj_(x);
-    auto chunks = gate_up_proj.chunk(/*chunks=*/2, /*dim=*/-1);
-    return c_proj_(chunks[0] * act_(chunks[1]));
-  }
-
-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    // call each submodule's load_state_dict function
-    w1_w2_proj_->load_state_dict(state_dict, {"w1.", "w2."});
-    c_proj_->load_state_dict(state_dict.select("c_proj."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    w1_w2_proj_->verify_loaded_weights(prefix + "[w1,w2].");
-    c_proj_->verify_loaded_weights(prefix + "c_proj.");
+    const auto gate_up = gate_up_proj_(x);
+    return c_proj_(gate_up[0] * act_(gate_up[1]));
   }
 
  private:
   // parameter members, must be registered
-  ColumnParallelLinear w1_w2_proj_{nullptr};
+  FusedColumnParallelLinear gate_up_proj_{nullptr};
   RowParallelLinear c_proj_{nullptr};
 
   ActFunc act_{nullptr};
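
Reviewer note on the fusion: stacking w1 and w2 into one GEMM and splitting the
result along the output dimension is numerically identical to running the two
projections separately, which is why forward() can simply index gate_up[0] and
gate_up[1]. A minimal standalone libtorch sketch of that equivalence (plain
matmuls only; the project's FusedColumnParallelLinear additionally handles
sharding and quantization):

#include <torch/torch.h>

#include <cassert>

int main() {
  const int64_t hidden_size = 8, intermediate_size = 16;
  auto x = torch::randn({2, hidden_size});
  auto w1 = torch::randn({intermediate_size, hidden_size});
  auto w2 = torch::randn({intermediate_size, hidden_size});

  // two separate projections
  auto a = torch::matmul(x, w1.t());
  auto b = torch::matmul(x, w2.t());

  // one fused GEMM over the stacked weights, split afterwards
  auto fused = torch::cat({w1, w2}, /*dim=*/0);
  auto chunks = torch::matmul(x, fused.t()).chunk(/*chunks=*/2, /*dim=*/-1);

  assert(torch::allclose(a, chunks[0]));
  assert(torch::allclose(b, chunks[1]));
  return 0;
}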
@@ -133,18 +125,6 @@ class QWenAttentionImpl : public Module {
     return c_proj_(output);
   }
 
-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    // call each submodule's load_state_dict function
-    c_attn_->load_state_dict(state_dict.select("c_attn."));
-    c_proj_->load_state_dict(state_dict.select("c_proj."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    c_attn_->verify_loaded_weights(prefix + "c_attn.");
-    c_proj_->verify_loaded_weights(prefix + "c_proj.");
-  }
-
  private:
   // parameter members, must be registered
   ColumnParallelLinear c_attn_{nullptr};
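
The load_state_dict/verify_loaded_weights overrides deleted here (and in every
class below) suggest weight loading is now driven by the Module base class
walking its registered children by name, rather than by hand-written per-class
loaders. A hedged sketch of the assumed shape of that walk (children_ and the
recursion are assumptions; the base class is not part of this diff):

// Assumed shape of the generic loader, not the project's actual code: each
// registered child loads from the state dict narrowed to its registered name,
// so the names passed to register_module must match checkpoint key prefixes.
void Module::load_state_dict(const StateDict& state_dict) {
  for (auto& [name, child] : children_) {  // children_: assumed name registry
    child->load_state_dict(state_dict.select(name + "."));
  }
}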
@@ -183,22 +163,6 @@ class QWenBlockImpl : public Module {
     return h + mlp_(ln_2_(h));
   }
 
-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    // call each submodule's load_state_dict function
-    attn_->load_state_dict(state_dict.select("attn."));
-    mlp_->load_state_dict(state_dict.select("mlp."));
-    ln_1_->load_state_dict(state_dict.select("ln_1."));
-    ln_2_->load_state_dict(state_dict.select("ln_2."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    attn_->verify_loaded_weights(prefix + "attn.");
-    mlp_->verify_loaded_weights(prefix + "mlp.");
-    ln_1_->verify_loaded_weights(prefix + "ln_1.");
-    ln_2_->verify_loaded_weights(prefix + "ln_2.");
-  }
-
  private:
   // parameter members, must be registered
   QWenAttention attn_{nullptr};
@@ -226,7 +190,7 @@ class QWenModelImpl : public Module {
     handler_ = AttentionHandler::create_handler_with_rope(
         args, /*interleaved=*/false, options);
 
-    blocks_ = register_module("layers", ModuleList());
+    blocks_ = register_module("h", ModuleList());
     layers_.reserve(args.n_layers());
     for (int32_t i = 0; i < args.n_layers(); i++) {
       auto block =
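
The rename from "layers" to "h" matters precisely because loading is now
name-driven: HF Qwen checkpoints key the transformer blocks as
transformer.h.<i>.*, which is exactly the prefix the removed manual loader
below built by hand. A hypothetical illustration of how a registered name
joins into a checkpoint key (block_key is illustrative only, not project
code):

#include <string>

std::string block_key(int i, const std::string& param) {
  // Qwen stores blocks under "transformer.h.<i>.", so the ModuleList must be
  // registered as "h" for the generated keys to match the checkpoint.
  return "transformer.h." + std::to_string(i) + "." + param;
}
// block_key(0, "attn.c_attn.weight") == "transformer.h.0.attn.c_attn.weight"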
@@ -254,26 +218,6 @@ class QWenModelImpl : public Module {
     return ln_f_(h);
   }
 
-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    wte_->load_state_dict(state_dict.select("wte."));
-    // call each layer's load_state_dict function
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->load_state_dict(
-          state_dict.select("h." + std::to_string(i) + "."));
-    }
-    ln_f_->load_state_dict(state_dict.select("ln_f."));
-  }
-
-  void verify_loaded_weights(const std::string& prefix) const {
-    wte_->verify_loaded_weights(prefix + "wte.");
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->verify_loaded_weights(prefix + "h." + std::to_string(i) +
-                                        ".");
-    }
-    ln_f_->verify_loaded_weights(prefix + "ln_f.");
-  }
-
  private:
   // parameter members, must be registered
   ParallelEmbedding wte_{nullptr};
@@ -331,17 +275,6 @@ class QWenForCausalLMImpl : public Module {
     return lm_head_(h);
   }
 
-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    transformer_->load_state_dict(state_dict.select("transformer."));
-    lm_head_->load_state_dict(state_dict.select("lm_head."));
-  }
-
-  void verify_loaded_weights() const {
-    transformer_->verify_loaded_weights("transformer.");
-    lm_head_->verify_loaded_weights("lm_head.");
-  }
-
  private:
   // parameter members, must be registered
   QWenModel transformer_{nullptr};