Skip to content

Commit b77b20e

Browse files
authored
bugfix: improve Qwen2.5-VL acc when enable tensor parallel.
1 parent 2e76454 commit b77b20e

File tree

2 files changed

+11
-46
lines changed

2 files changed

+11
-46
lines changed

xllm/core/layers/npu/qwen2_5_vision_encoder_layer.cpp

Lines changed: 10 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,8 @@ void Qwen2_5VisionEncoderImpl::pad_qkv_weights() {
158158
auto qkv_proj_weight_reshaped =
159159
qkv_proj_weight.reshape({num_heads_pre_rank, 3, 80, hidden_size});
160160

161-
auto first_half =
162-
qkv_proj_weight_reshaped.index({torch::indexing::Slice(),
163-
torch::indexing::Slice(),
164-
torch::indexing::Slice(0, 40),
165-
torch::indexing::Slice()});
166-
auto second_half = qkv_proj_weight_reshaped.index({torch::indexing::Slice(),
167-
torch::indexing::Slice(),
168-
torch::indexing::Slice(40),
169-
torch::indexing::Slice()});
161+
auto first_half = qkv_proj_weight_reshaped.slice(2, 0, 40);
162+
auto second_half = qkv_proj_weight_reshaped.slice(2, 40, 80);
170163

171164
auto first_half_padded = torch::nn::functional::pad(
172165
first_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24}));
@@ -182,12 +175,9 @@ void Qwen2_5VisionEncoderImpl::pad_qkv_weights() {
182175

183176
auto qkv_proj_bias_reshaped =
184177
qkv_proj_bias.reshape({num_heads_pre_rank, 3, 80});
185-
first_half = qkv_proj_bias_reshaped.index({torch::indexing::Slice(),
186-
torch::indexing::Slice(),
187-
torch::indexing::Slice(0, 40)});
188-
second_half = qkv_proj_bias_reshaped.index({torch::indexing::Slice(),
189-
torch::indexing::Slice(),
190-
torch::indexing::Slice(40)});
178+
first_half = qkv_proj_bias_reshaped.slice(2, 0, 40);
179+
second_half = qkv_proj_bias_reshaped.slice(2, 40, 80);
180+
191181
first_half_padded = torch::nn::functional::pad(
192182
first_half, torch::nn::functional::PadFuncOptions({0, 24}));
193183
second_half_padded = torch::nn::functional::pad(
@@ -202,31 +192,11 @@ void Qwen2_5VisionEncoderImpl::pad_qkv_weights() {
202192

203193
auto out_proj_weight = at_weight_tensors_[IN_WATTENTION_OUT_WEIGHT];
204194

205-
if (encode_param_.worldSize == 1) {
206-
out_proj_weight =
207-
torch::nn::functional::pad(
208-
out_proj_weight.reshape({hidden_size, num_heads_pre_rank * 2, 40}),
209-
torch::nn::functional::PadFuncOptions({0, 24, 0, 0}))
210-
.reshape({hidden_size, num_heads_pre_rank * 128});
211-
} else if (encode_param_.worldSize > 1) {
212-
auto reshaped =
213-
out_proj_weight.reshape({num_heads_pre_rank, 80, hidden_size});
214-
215-
auto first_half = reshaped.slice(1, 0, 40);
216-
auto second_half = reshaped.slice(1, 40, 80);
217-
218-
auto first_half_padded = torch::nn::functional::pad(
219-
first_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24}));
220-
221-
auto second_half_padded = torch::nn::functional::pad(
222-
second_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24}));
223-
224-
auto out_proj_weight_padded =
225-
torch::cat({first_half_padded, second_half_padded}, 1);
226-
227-
out_proj_weight =
228-
out_proj_weight_padded.reshape({num_heads_pre_rank * 128, hidden_size});
229-
}
195+
out_proj_weight =
196+
torch::nn::functional::pad(
197+
out_proj_weight.reshape({hidden_size, num_heads_pre_rank * 2, 40}),
198+
torch::nn::functional::PadFuncOptions({0, 24, 0, 0}))
199+
.reshape({hidden_size, num_heads_pre_rank * 128});
230200
at_weight_tensors_[IN_WATTENTION_OUT_WEIGHT] = out_proj_weight;
231201
}
232202
void Qwen2_5VisionEncoderImpl::merge_loaded_weights() {

xllm/models/qwen2_5_vl.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -695,12 +695,7 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module {
695695
Qwen2_5_VLForConditionalGenerationImpl(const Context& context)
696696
: model_args_(context.get_model_args()),
697697
options_(context.get_tensor_options()) {
698-
Context vision_context(ParallelArgs(0, 1, nullptr));
699-
vision_context.set_model_args(model_args_);
700-
vision_context.set_quant_args(context.get_quant_args());
701-
vision_context.set_tensor_options(options_);
702-
visual_ =
703-
register_module("visual", Qwen2_5_VisionTransformer(vision_context));
698+
visual_ = register_module("visual", Qwen2_5_VisionTransformer(context));
704699

705700
language_model_ =
706701
register_module("language_model", QWen2ForCausalLM(context));

0 commit comments

Comments
 (0)