Commit cd36cfc

feat: support qwen2_5_vl, qwen3_vl, qwen3_vl_moe on mlu device. (jd-opensource#450)
Co-authored-by: guoxueting <[email protected]>
1 parent 9f4715f commit cd36cfc

36 files changed: 812 additions, 83 deletions

xllm/core/framework/parallel_state/parallel_state.cpp

Lines changed: 39 additions & 8 deletions
@@ -77,12 +77,12 @@ torch::Tensor gather(const torch::Tensor& input,
   if (!process_group) {
     return input;
   }
-  const auto world_size = process_group->world_size();
+  const int32_t world_size = process_group->world_size();
   if (world_size == 1) {
     return input;
   }

-  const auto rank = process_group->rank();
+  const int32_t rank = process_group->rank();
   std::vector<torch::Tensor> tensors(world_size);
   for (int64_t i = 0; i < world_size; ++i) {
     tensors[i] = torch::empty_like(input);
@@ -98,8 +98,8 @@ torch::Tensor gather(const torch::Tensor& input,
   if (!process_group) {
     return input;
   }
-  const auto world_size = process_group->world_size();
-  const auto rank = process_group->rank();
+  const int32_t world_size = process_group->world_size();
+  const int32_t rank = process_group->rank();
   if (world_size == 1) {
     return input;
   }
@@ -131,11 +131,42 @@ torch::Tensor gather(const torch::Tensor& input,
       gathered_input, max_num_tokens, token_num_list);
 }

+torch::Tensor all_gather_interleaved(const torch::Tensor& input,
+                                     ProcessGroup* process_group) {
+  if (!process_group) {
+    return input;
+  }
+  const int32_t world_size = process_group->world_size();
+  const int32_t rank = process_group->rank();
+  if (world_size == 1) {
+    return input;
+  }
+
+  std::vector<torch::Tensor> gathered_tensors(world_size);
+  for (int64_t i = 0; i < world_size; ++i) {
+    gathered_tensors[i] = torch::empty_like(input);
+  }
+  process_group->allgather(input, gathered_tensors);
+
+  int32_t dim = -1;
+  size_t num_chunks = 3;
+  std::vector<torch::Tensor> ordered_tensors;
+  int64_t shard_size = input.size(dim) / num_chunks;
+  for (size_t i = 0; i < num_chunks; ++i) {
+    for (size_t j = 0; j < world_size; ++j) {
+      auto shard_tensor =
+          gathered_tensors[j].slice(dim, shard_size * i, shard_size * (i + 1));
+      ordered_tensors.push_back(shard_tensor);
+    }
+  }
+  return torch::cat(ordered_tensors, dim).contiguous();
+}
+
 torch::Tensor reduce(torch::Tensor& input, ProcessGroup* process_group) {
   if (!process_group) {
     return input;
   }
-  const auto world_size = process_group->world_size();
+  const int32_t world_size = process_group->world_size();
   if (world_size == 1) {
     return input;
   }
@@ -149,20 +180,20 @@ torch::Tensor scatter(torch::Tensor input,
   if (!process_group) {
     return input;
   }
-  const auto world_size = process_group->world_size();
+  const int32_t world_size = process_group->world_size();
   if (world_size == 1) {
     return input;
   }

   // get the size for last dimension
-  const auto dim_size = input.size(dim);
+  const int32_t dim_size = input.size(dim);
   CHECK(dim_size % world_size == 0)
       << "dim_size " << dim_size << " cannot be divided by world_size "
       << world_size;

   // torch::split does not create contiguous tensors by default.
   const auto tensor_list = input.split(dim_size / world_size, dim);
-  const auto rank = process_group->rank();
+  const int32_t rank = process_group->rank();
   return tensor_list[rank];
 }
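
Note on the new helper: a plain allgather of column-parallel shards returns rank-major data ([q0 k0 v0 | q1 k1 v1 | ...] for a fused QKV projection), while downstream consumers expect chunk-major order ([q0 q1 | k0 k1 | v0 v1 | ...]); the double loop above performs exactly that reordering, with num_chunks hard-coded to 3. A stand-alone CPU sketch of the reordering (hypothetical helper name; the allgather result is stubbed with precomputed per-rank tensors):

#include <torch/torch.h>
#include <iostream>
#include <vector>

// Reorders rank-major fused shards into chunk-major order, mirroring the
// slice/cat loop in all_gather_interleaved without a ProcessGroup.
torch::Tensor reorder_interleaved(const std::vector<torch::Tensor>& gathered,
                                  int64_t num_chunks = 3,
                                  int64_t dim = -1) {
  const int64_t world_size = gathered.size();
  const int64_t shard_size = gathered[0].size(dim) / num_chunks;
  std::vector<torch::Tensor> ordered;
  for (int64_t i = 0; i < num_chunks; ++i) {    // chunk index: q, k, v
    for (int64_t j = 0; j < world_size; ++j) {  // source rank
      ordered.push_back(
          gathered[j].slice(dim, shard_size * i, shard_size * (i + 1)));
    }
  }
  return torch::cat(ordered, dim).contiguous();
}

int main() {
  // Rank 0 holds [q0 k0 v0]; rank 1 holds [q1 k1 v1] along the last dim.
  auto r0 = torch::tensor({{0, 10, 20}}, torch::kLong);
  auto r1 = torch::tensor({{1, 11, 21}}, torch::kLong);
  // Prints [[0, 1, 10, 11, 20, 21]], i.e. q0 q1 | k0 k1 | v0 v1.
  std::cout << reorder_interleaved({r0, r1}) << std::endl;
}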

xllm/core/framework/parallel_state/parallel_state.h

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ torch::Tensor gather(const torch::Tensor& input,
                      ProcessGroup* process_group,
                      const std::vector<int32_t>& token_num_list);

+torch::Tensor all_gather_interleaved(const torch::Tensor& input,
+                                     ProcessGroup* process_group);
+
 torch::Tensor reduce(torch::Tensor& input, ProcessGroup* process_group);

 torch::Tensor scatter(torch::Tensor input,

xllm/core/framework/state_dict/utils.cpp

Lines changed: 33 additions & 0 deletions
@@ -243,6 +243,39 @@ void load_moe_fused_weight(const StateDict& state_dict,
   }
 }

+void load_merged_weight(const StateDict& state_dict,
+                        const std::string& name,
+                        int64_t dim,
+                        int32_t rank,
+                        int32_t world_size,
+                        int32_t shard_tensor_count,
+                        int64_t shard_size,
+                        torch::Tensor& weight,
+                        bool& weight_is_loaded) {
+  if (weight_is_loaded) {
+    return;
+  }
+  const auto& tensor = state_dict.get_tensor(name);
+  if (!tensor.defined()) {
+    return;
+  }
+  CHECK_EQ(tensor.size(dim), shard_tensor_count * shard_size * world_size)
+      << name << "[" << dim << "] size mismatch for " << state_dict.prefix()
+      << name;
+  std::vector<torch::Tensor> shard_tensors;
+  for (size_t shard_id = 0; shard_id < shard_tensor_count; shard_id++) {
+    int64_t shard_offset =
+        shard_id * shard_size * world_size + rank * shard_size;
+    shard_tensors.push_back(
+        tensor.slice(dim, shard_offset, shard_offset + shard_size));
+  }
+  auto merged_weight = torch::cat(shard_tensors, dim);
+  CHECK_EQ(weight.sizes(), merged_weight.sizes())
+      << "weight size mismatch for " << state_dict.prefix() << name;
+  weight.copy_(merged_weight);
+  weight_is_loaded = true;
+}
+
 }  // namespace weight

 }  // namespace xllm
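
Worked note on the offset arithmetic: the checkpoint tensor is assumed to hold shard_tensor_count logical parts back to back, each part itself split evenly across world_size ranks, so rank r's slice of part s starts at s * shard_size * world_size + r * shard_size. A small sketch that just prints the slice ranges for hypothetical sizes (world_size = 4, shard_tensor_count = 2, shard_size = 8):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical layout: [part0 | part1], each part split across 4 ranks.
  const int32_t world_size = 4;
  const int32_t shard_tensor_count = 2;
  const int64_t shard_size = 8;  // elements per rank per part along `dim`
  for (int32_t rank = 0; rank < world_size; ++rank) {
    for (int32_t shard_id = 0; shard_id < shard_tensor_count; ++shard_id) {
      const int64_t offset =
          shard_id * shard_size * world_size + rank * shard_size;
      std::printf("rank %d, part %d -> [%lld, %lld)\n", rank, shard_id,
                  (long long)offset, (long long)(offset + shard_size));
    }
  }
  // Rank 0 takes [0, 8) from part0 and [32, 40) from part1; concatenating
  // the two slices along `dim` reproduces its local merged weight.
}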

xllm/core/framework/state_dict/utils.h

Lines changed: 21 additions & 0 deletions
@@ -93,6 +93,17 @@ void load_moe_fused_weight(const StateDict& state_dict,
                            bool& w1_is_loaded,
                            bool& w3_is_loaded,
                            bool& w13_is_loaded);
+
+void load_merged_weight(const StateDict& state_dict,
+                        const std::string& name,
+                        int64_t dim,
+                        int32_t rank,
+                        int32_t world_size,
+                        int32_t shard_tensor_count,
+                        int64_t shard_size,
+                        torch::Tensor& weight,
+                        bool& weight_is_loaded);
+
 }  // namespace weight

 // helper macros for defining and loading weights
@@ -173,4 +184,14 @@ void load_moe_fused_weight(const StateDict& state_dict,
                            w3##_is_loaded_,        \
                            w13##_is_loaded_);

+#define LOAD_MERGED_WEIGHT(name, dim)              \
+  weight::load_merged_weight(state_dict,           \
+                             #name,                \
+                             dim,                  \
+                             rank,                 \
+                             world_size,           \
+                             shard_tensor_count,   \
+                             shard_size,           \
+                             name##_,              \
+                             name##_is_loaded_);
 } // namespace xllm
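
A hypothetical call site for the new macro, following the pattern of the other loader macros in this header (the member name qkv_proj_weight and the in-scope variables are illustrative, not taken from this commit):

// Expands to weight::load_merged_weight(state_dict, "qkv_proj_weight", 0,
//     rank, world_size, shard_tensor_count, shard_size,
//     qkv_proj_weight_, qkv_proj_weight_is_loaded_);
// so state_dict, rank, world_size, shard_tensor_count and shard_size must
// all be in scope where the macro is used.
LOAD_MERGED_WEIGHT(qkv_proj_weight, /*dim=*/0);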

xllm/core/kernels/mlu/active.cpp

Lines changed: 9 additions & 2 deletions
@@ -25,13 +25,20 @@ void active(const torch::Tensor& input,
             bool is_gated,
             int64_t start_expert_id,
             int64_t expert_size) {
+  std::string hidden_act = act_mode;
+  // TODO: act_mode gelu_pytorch_tanh is not supported yet.
+  std::string gelu_approximate = "none";
+  if (act_mode == "gelu_pytorch_tanh") {
+    hidden_act = "gelu";
+    gelu_approximate = "tanh";
+  }
   tmo::torch_api::active(input,
                          output,
                          bias,
                          cusum_token_count,
-                         act_mode,
+                         hidden_act,
                          is_gated,
                          start_expert_id,
                          expert_size);
 }
-}  // namespace xllm::kernel::mlu
+}  // namespace xllm::kernel::mlu
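
For reference, gelu_pytorch_tanh is the Hugging Face activation name for GELU with the tanh approximation; the shim above maps it to hidden_act = "gelu" plus gelu_approximate = "tanh" (the latter is computed but not yet forwarded, per the TODO). A stand-alone sketch of the approximation itself, assuming the standard formula:

#include <cmath>
#include <cstdio>

// Tanh approximation of GELU, i.e. what "gelu_pytorch_tanh" denotes:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
double gelu_tanh(double x) {
  const double kSqrt2OverPi = std::sqrt(2.0 / 3.14159265358979323846);
  return 0.5 * x *
         (1.0 + std::tanh(kSqrt2OverPi * (x + 0.044715 * x * x * x)));
}

int main() {
  for (double x : {-2.0, -1.0, 0.0, 1.0, 2.0}) {
    std::printf("gelu_tanh(%+.1f) = %+.6f\n", x, gelu_tanh(x));
  }
}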

xllm/core/kernels/mlu/mlu_ops_api.h

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ void apply_rotary(torch::Tensor& q,
                   const torch::Tensor& sin,
                   const torch::Tensor& cos,
                   const std::optional<torch::Tensor>& position_ids,
-                  const torch::Tensor& cu_query_lens,
+                  const std::optional<torch::Tensor>& cu_query_lens,
                   bool interleaved,
                   bool discrete,
                   bool dynamic_ntk,

xllm/core/kernels/mlu/rope.cpp

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ void apply_rotary(torch::Tensor& q,
                   const torch::Tensor& sin,
                   const torch::Tensor& cos,
                   const std::optional<torch::Tensor>& position_ids,
-                  const torch::Tensor& cu_query_lens,
+                  const std::optional<torch::Tensor>& cu_query_lens,
                   bool interleaved,
                   bool discrete,
                   bool dynamic_ntk,

xllm/core/kernels/param.h

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ struct RotaryParams {
   // Required in pack mode (when q/k are 3D). Size should be [batch_size + 1].
   // Note: In current MLU implementation, this is always passed to underlying
   // API.
-  torch::Tensor cu_query_lens;
+  std::optional<torch::Tensor> cu_query_lens;
   // Whether to use interleaved rotary embedding pattern.
   bool interleaved;
   // Whether to use discrete position mode. If true, position_ids must be
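
With the field optional, pack-mode callers still supply the [batch_size + 1] prefix-sum tensor while other callers can now pass std::nullopt instead of a dummy tensor. A minimal sketch of the two shapes (trimmed stand-in struct; only the fields shown in this diff are reproduced):

#include <torch/torch.h>
#include <optional>

// Stand-in with just the fields visible in this hunk; the real RotaryParams
// has more members.
struct RotaryParamsSketch {
  std::optional<torch::Tensor> cu_query_lens;  // [batch_size + 1] in pack mode
  bool interleaved = false;
};

int main() {
  RotaryParamsSketch pack_mode, batch_mode;
  // Pack mode: q/k are 3D; cumulative query lengths for 2 sequences (5, 7).
  pack_mode.cu_query_lens = torch::tensor({0, 5, 12}, torch::kInt32);
  // Non-pack mode: no pack metadata required.
  batch_mode.cu_query_lens = std::nullopt;
}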

xllm/core/layers/common/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -5,23 +5,27 @@ cc_library(
   common_layers
   HDRS
     qwen2_attention.h
+    qwen2_vision_attention.h
     fuse_norm.h
     rotary_embedding.h
     fused_moe.h
     dense_mlp.h
     qwen2_decoder_layer.h
+    qwen2_5_vision_layer.h
     qwen3_moe_decoder_layer.h
     linear.h
     word_embedding_impl.h
     layer_utils.h
     indexer.h
   SRCS
     qwen2_attention.cpp
+    qwen2_vision_attention.cpp
     fuse_norm.cpp
     rotary_embedding.cpp
     fused_moe.cpp
     dense_mlp.cpp
     qwen2_decoder_layer.cpp
+    qwen2_5_vision_layer.cpp
     qwen3_moe_decoder_layer.cpp
     linear.cpp
     word_embedding_impl.cpp

xllm/core/layers/common/dense_mlp.cpp

Lines changed: 16 additions & 1 deletion
@@ -57,10 +57,11 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size,
   }

   // 1. gate + up
+  int64_t out_feature = is_gated_ ? intermediate_size_ * 2 : intermediate_size_;
   gate_up_proj_ =
       register_module("gate_up_proj",
                       ColumnParallelLinear(hidden_size,
-                                           intermediate_size_ * 2,
+                                           out_feature,
                                            /*bias=*/has_bias,
                                            /*gather_output=*/false,
                                            quant_args,
@@ -111,5 +112,19 @@ void DenseMLPImpl::load_state_dict(const StateDict& state_dict) {
   down_proj_->load_state_dict(state_dict.get_dict_with_prefix("down_proj."));
 }

+void DenseMLPImpl::load_state_dict(const StateDict& state_dict,
+                                   const std::vector<std::string>& gate_up_name,
+                                   const std::string& down_name) {
+  if (is_gated_) {
+    CHECK_EQ(gate_up_name.size(), 2);
+    gate_up_proj_->load_state_dict(state_dict, gate_up_name);
+  } else {
+    CHECK_EQ(gate_up_name.size(), 1);
+    gate_up_proj_->load_state_dict(
+        state_dict.get_dict_with_prefix(gate_up_name[0]));
+  }
+  down_proj_->load_state_dict(state_dict.get_dict_with_prefix(down_name));
+}
+
 }  // namespace layer
 }  // namespace xllm
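
The sizing change follows from the forward math: a gated MLP multiplies act(gate_proj(x)) elementwise with up_proj(x), so the fused gate_up projection needs 2 * intermediate_size output features, while a non-gated MLP needs only intermediate_size. A functional sketch of the two paths in plain libtorch (illustrative names, independent of the ColumnParallelLinear wrapper; the activation choices are examples only):

#include <torch/torch.h>

// w_gate_up: [out_feature, hidden]; w_down: [hidden, intermediate].
torch::Tensor dense_mlp(const torch::Tensor& x,
                        const torch::Tensor& w_gate_up,
                        const torch::Tensor& w_down,
                        bool is_gated) {
  auto h = torch::nn::functional::linear(x, w_gate_up);
  if (is_gated) {
    // out_feature == 2 * intermediate: split into gate and up halves.
    auto parts = h.chunk(2, /*dim=*/-1);
    h = torch::silu(parts[0]) * parts[1];
  } else {
    // out_feature == intermediate: single activation over the projection.
    h = torch::gelu(h);
  }
  return torch::nn::functional::linear(h, w_down);
}

int main() {
  const int64_t hidden = 16, inter = 32;
  auto x = torch::randn({2, hidden});
  auto gated = dense_mlp(x, torch::randn({2 * inter, hidden}),
                         torch::randn({hidden, inter}), /*is_gated=*/true);
  auto plain = dense_mlp(x, torch::randn({inter, hidden}),
                         torch::randn({hidden, inter}), /*is_gated=*/false);
}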
