@@ -158,15 +158,8 @@ void Qwen2_5VisionEncoderImpl::pad_qkv_weights() {
   auto qkv_proj_weight_reshaped =
       qkv_proj_weight.reshape({num_heads_pre_rank, 3, 80, hidden_size});
 
-  auto first_half =
-      qkv_proj_weight_reshaped.index({torch::indexing::Slice(),
-                                      torch::indexing::Slice(),
-                                      torch::indexing::Slice(0, 40),
-                                      torch::indexing::Slice()});
-  auto second_half = qkv_proj_weight_reshaped.index({torch::indexing::Slice(),
-                                                     torch::indexing::Slice(),
-                                                     torch::indexing::Slice(40),
-                                                     torch::indexing::Slice()});
+  auto first_half = qkv_proj_weight_reshaped.slice(2, 0, 40);
+  auto second_half = qkv_proj_weight_reshaped.slice(2, 40, 80);
 
   auto first_half_padded = torch::nn::functional::pad(
       first_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24}));
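For readers less familiar with the libtorch indexing API: `index({Slice(), ..., Slice(start, end), ...})` and `slice(dim, start, end)` return the same view over the same storage, so the one-liners above are a pure simplification with no copy. A minimal sketch of the equivalence, assuming a standard libtorch build; the shapes are illustrative stand-ins, not the encoder's real dimensions:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Stand-in for the {num_heads_pre_rank, 3, 80, hidden_size} QKV weight;
  // the real sizes come from the model config.
  auto w = torch::randn({4, 3, 80, 16});

  // Verbose form this hunk removes: slicing dim 2 via index().
  auto via_index = w.index({torch::indexing::Slice(),
                            torch::indexing::Slice(),
                            torch::indexing::Slice(0, 40),
                            torch::indexing::Slice()});

  // Equivalent one-liner the hunk keeps.
  auto via_slice = w.slice(/*dim=*/2, /*start=*/0, /*end=*/40);

  // Both are views of the same data with identical contents.
  std::cout << torch::equal(via_index, via_slice) << '\n';  // prints 1
  return 0;
}
```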
@@ -182,12 +175,9 @@ void Qwen2_5VisionEncoderImpl::pad_qkv_weights() {
 
   auto qkv_proj_bias_reshaped =
       qkv_proj_bias.reshape({num_heads_pre_rank, 3, 80});
-  first_half = qkv_proj_bias_reshaped.index({torch::indexing::Slice(),
-                                             torch::indexing::Slice(),
-                                             torch::indexing::Slice(0, 40)});
-  second_half = qkv_proj_bias_reshaped.index({torch::indexing::Slice(),
-                                              torch::indexing::Slice(),
-                                              torch::indexing::Slice(40)});
+  first_half = qkv_proj_bias_reshaped.slice(2, 0, 40);
+  second_half = qkv_proj_bias_reshaped.slice(2, 40, 80);
+
   first_half_padded = torch::nn::functional::pad(
       first_half, torch::nn::functional::PadFuncOptions({0, 24}));
   second_half_padded = torch::nn::functional::pad(
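The pad specs read as (left, right) pairs starting from the last dimension: `{0, 24}` here right-pads each 40-wide bias half to 64, while the weight padding above uses `{0, 0, 0, 24}` to grow the second-to-last dimension instead. A quick sketch under the same libtorch assumption, again with stand-in shapes:

```cpp
#include <torch/torch.h>
#include <iostream>

namespace F = torch::nn::functional;

int main() {
  // Stand-in for one {num_heads_pre_rank, 3, 40} bias half.
  auto half = torch::randn({4, 3, 40});

  // Pairs apply from the last dimension outward:
  // {0, 24} appends 24 zeros on the right of dim -1, so 40 -> 64.
  auto padded = F::pad(half, F::PadFuncOptions({0, 24}));

  std::cout << padded.sizes() << '\n';  // [4, 3, 64]
  return 0;
}
```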
@@ -202,31 +192,11 @@ void Qwen2_5VisionEncoderImpl::pad_qkv_weights() {
 
   auto out_proj_weight = at_weight_tensors_[IN_WATTENTION_OUT_WEIGHT];
 
-  if (encode_param_.worldSize == 1) {
-    out_proj_weight =
-        torch::nn::functional::pad(
-            out_proj_weight.reshape({hidden_size, num_heads_pre_rank * 2, 40}),
-            torch::nn::functional::PadFuncOptions({0, 24, 0, 0}))
-            .reshape({hidden_size, num_heads_pre_rank * 128});
-  } else if (encode_param_.worldSize > 1) {
-    auto reshaped =
-        out_proj_weight.reshape({num_heads_pre_rank, 80, hidden_size});
-
-    auto first_half = reshaped.slice(1, 0, 40);
-    auto second_half = reshaped.slice(1, 40, 80);
-
-    auto first_half_padded = torch::nn::functional::pad(
-        first_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24}));
-
-    auto second_half_padded = torch::nn::functional::pad(
-        second_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24}));
-
-    auto out_proj_weight_padded =
-        torch::cat({first_half_padded, second_half_padded}, 1);
-
-    out_proj_weight =
-        out_proj_weight_padded.reshape({num_heads_pre_rank * 128, hidden_size});
-  }
+  out_proj_weight =
+      torch::nn::functional::pad(
+          out_proj_weight.reshape({hidden_size, num_heads_pre_rank * 2, 40}),
+          torch::nn::functional::PadFuncOptions({0, 24, 0, 0}))
+          .reshape({hidden_size, num_heads_pre_rank * 128});
   at_weight_tensors_[IN_WATTENTION_OUT_WEIGHT] = out_proj_weight;
 }
 void Qwen2_5VisionEncoderImpl::merge_loaded_weights() {
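The surviving branch folds each head's 80 output-projection columns into two 40-wide halves, right-pads both to 64 in a single call (`{0, 24, 0, 0}` touches only the last dimension), and flattens back to 128 columns per head. A sketch of that shape arithmetic with made-up sizes, assuming the same libtorch setup:

```cpp
#include <torch/torch.h>
#include <iostream>

namespace F = torch::nn::functional;

int main() {
  // Illustrative stand-ins; the encoder's real sizes differ.
  const int64_t hidden_size = 16;
  const int64_t num_heads_pre_rank = 4;

  auto out_proj_weight = torch::randn({hidden_size, num_heads_pre_rank * 80});

  // View each head's 80 columns as two 40-wide halves, right-pad the last
  // dim to 64 ({0, 24, 0, 0} leaves dim -2 untouched), then flatten:
  // 2 * 64 = 128 columns per head.
  auto padded =
      F::pad(out_proj_weight.reshape({hidden_size, num_heads_pre_rank * 2, 40}),
             F::PadFuncOptions({0, 24, 0, 0}))
          .reshape({hidden_size, num_heads_pre_rank * 128});

  std::cout << padded.sizes() << '\n';  // [16, 512]
  return 0;
}
```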