@@ -36,16 +36,18 @@ class LlamaMLPImpl : public Module {
3636 const int64_t intermediate_size = args.intermediate_size ();
3737
3838 // register the weight parameter
39- // gate_up_proj_ = register_module(
40- // "gate_up_proj",
41- // FusedColumnParallelLinear(
42- // hidden_size,
43- // std::vector<int64_t>{intermediate_size, intermediate_size},
44- // /*bias=*/false,
45- // /*gather_output=*/false,
46- // quant_args,
47- // parallel_args,
48- // options));
39+ gate_up_proj_ = register_module (
40+ " gate_up_proj" ,
41+ FusedColumnParallelLinear (
42+ hidden_size,
43+ std::vector<int64_t >{intermediate_size, intermediate_size},
44+ std::vector<std::string>{" gate_proj." , " up_proj." },
45+ /* bias=*/ false ,
46+ /* gather_output=*/ false ,
47+ quant_args,
48+ parallel_args,
49+ options),
50+ /* selector=*/ nullptr );
4951
5052 down_proj_ =
5153 register_module (" down_proj" ,
@@ -63,18 +65,6 @@ class LlamaMLPImpl : public Module {
6365 return down_proj_ (act_func_ (gate_up[0 ]) * gate_up[1 ]);
6466 }
6567
66- // // load the weight from the checkpoint
67- // void load_state_dict(const StateDict& state_dict) {
68- // // call each submodule's load_state_dict function
69- // gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."});
70- // down_proj_->load_state_dict(state_dict.select("down_proj."));
71- // }
72-
73- // void verify_loaded_weights(const std::string& prefix) const {
74- // gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj].");
75- // down_proj_->verify_loaded_weights(prefix + "down_proj.");
76- // }
77-
7868 private:
7969 // parameter members, must be registered
8070 FusedColumnParallelLinear gate_up_proj_{nullptr };
@@ -102,16 +92,20 @@ class LlamaAttentionImpl : public Module {
10292 std::max<int64_t >(1 , n_kv_heads / world_size);
10393
10494 // register submodules
105- // qkv_proj_ = register_module("qkv_proj",
106- // QKVColumnParallelLinear(hidden_size,
107- // n_heads,
108- // n_kv_heads,
109- // head_dim,
110- // /*bias=*/false,
111- // /*gather_output=*/false,
112- // quant_args,
113- // parallel_args,
114- // options));
95+ qkv_proj_ = register_module (
96+ " qkv_proj" ,
97+ QKVColumnParallelLinear (
98+ hidden_size,
99+ n_heads,
100+ n_kv_heads,
101+ head_dim,
102+ std::vector<std::string>{" q_proj." , " k_proj." , " v_proj." },
103+ /* bias=*/ false ,
104+ /* gather_output=*/ false ,
105+ quant_args,
106+ parallel_args,
107+ options),
108+ /* selector=*/ nullptr );
115109
116110 o_proj_ = register_module (" o_proj" ,
117111 RowParallelLinear (hidden_size,
@@ -141,20 +135,6 @@ class LlamaAttentionImpl : public Module {
141135 return o_proj_ (output);
142136 }
143137
144- // // load the weight from the checkpoint
145- // void load_state_dict(const StateDict& state_dict) {
146- // // call each submodule's load_state_dict function
147- // qkv_proj_->load_state_dict(
148- // state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.",
149- // "v_proj."});
150- // o_proj_->load_state_dict(state_dict.select("o_proj."));
151- // }
152-
153- // void verify_loaded_weights(const std::string& prefix) const {
154- // qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj].");
155- // o_proj_->verify_loaded_weights(prefix + "o_proj.");
156- // }
157-
158138 private:
159139 // parameter members, must be registered
160140 QKVColumnParallelLinear qkv_proj_{nullptr };
0 commit comments