@@ -16,14 +16,15 @@
 # limitations under the License.
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
-from typing import Optional
+
+from typing import Optional, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, CompilationLevel, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_tp_group)
 from vllm.forward_context import get_forward_context
@@ -44,8 +45,11 @@
 from vllm.model_executor.models.utils import (
     PPMissingLayer, extract_layer_index,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
+from vllm.sequence import IntermediateTensors

 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
+                                               init_metadata_for_sp)


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -96,6 +100,7 @@ def forward(
         self,
         hidden_states,
         attn_metadata=None,
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
     ):
         if attn_metadata is None:
             attn_metadata = get_forward_context().attn_metadata
@@ -114,6 +119,7 @@ def forward(
             top_k=self.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=None,
+            _metadata_for_padding=_metadata_for_padding,
         )

         return hidden_states
@@ -155,14 +161,14 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
-        use_aclgraph = (vllm_config is not None
-                        and vllm_config.compilation_config.level
-                        == CompilationLevel.PIECEWISE
-                        and not vllm_config.model_config.enforce_eager)
+        self.use_aclgraph = (vllm_config is not None
+                             and vllm_config.compilation_config.level
+                             == CompilationLevel.PIECEWISE
+                             and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            if not use_aclgraph:
+            if not self.use_aclgraph:
                 # FIXME: custom sparse moe block doesn't work with aclgraph.
                 self.mlp = CustomSparseMoeBlock(config=config,
                                                 quant_config=quant_config,
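
The hunks below lean on three helpers of `MetadataForPadding` (`padding_slice`, `allgather_unpadding_aligned`, `padding_aligned_reduce_scatter`) and on its `not_dummy_and_is_prefill` flag, all defined in `vllm_ascend/ops/sequence_parallel.py`, which is not part of this diff. As a reading aid only, the sketch below shows one plausible shape for that interface, assuming the token dimension is padded to a multiple of the tensor-parallel world size and each rank keeps one contiguous shard; the real class may differ.

# Illustrative sketch, not part of the diff and not the real MetadataForPadding.
from dataclasses import dataclass

import torch
import torch.distributed as dist
import torch.nn.functional as F


@dataclass
class MetadataForPaddingSketch:
    not_dummy_and_is_prefill: bool  # SP path is taken only for real prefill runs
    pad_size: int                   # rows appended so tokens % tp_size == 0
    tp_group: dist.ProcessGroup
    tp_rank: int
    tp_size: int

    def padding_slice(self, x: torch.Tensor) -> torch.Tensor:
        # Pad the token dim to the aligned length, keep this rank's shard.
        x = F.pad(x, (0, 0, 0, self.pad_size))
        chunk = x.shape[0] // self.tp_size
        return x[self.tp_rank * chunk:(self.tp_rank + 1) * chunk]

    def allgather_unpadding_aligned(self, x: torch.Tensor) -> torch.Tensor:
        # Reassemble the full padded batch from all ranks, then drop the padding.
        out = torch.empty((x.shape[0] * self.tp_size, *x.shape[1:]),
                          dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(out, x, group=self.tp_group)
        return out[:out.shape[0] - self.pad_size]

    def padding_aligned_reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        # Pad back to the aligned length, then reduce-scatter so each rank
        # ends up with the summed shard of tokens it owns.
        x = F.pad(x, (0, 0, 0, self.pad_size))
        out = torch.empty((x.shape[0] // self.tp_size, *x.shape[1:]),
                          dtype=x.dtype, device=x.device)
        dist.reduce_scatter_tensor(out, x, group=self.tp_group)
        return out
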
@@ -182,6 +188,60 @@ def __init__(
182
188
self .post_attention_layernorm = RMSNorm (config .hidden_size ,
183
189
eps = config .rms_norm_eps )
184
190
191
+ self .enable_sequence_parallelism = (
192
+ vllm_config .compilation_config .pass_config .
193
+ enable_sequence_parallelism if vllm_config is not None else False )
194
+
195
+ def forward (
196
+ self ,
197
+ positions : torch .Tensor ,
198
+ hidden_states : torch .Tensor ,
199
+ residual : Optional [torch .Tensor ],
200
+ _metadata_for_padding : Optional [MetadataForPadding ] = None ,
201
+ ) -> torch .Tensor :
202
+
203
+ # To prevent precision issues during the decoder phase when only prefilling enables SP
204
+ if not self .enable_sequence_parallelism :
205
+ self .self_attn .o_proj .reduce_results = True
206
+ else :
207
+ self .self_attn .o_proj .reduce_results = not _metadata_for_padding .not_dummy_and_is_prefill if _metadata_for_padding is not None else True
208
+
209
+ # Self Attention
210
+ if residual is None :
211
+ residual = hidden_states
212
+ if _metadata_for_padding and _metadata_for_padding .not_dummy_and_is_prefill :
213
+ residual = _metadata_for_padding .padding_slice (residual )
214
+
215
+ hidden_states = self .input_layernorm (hidden_states )
216
+ else :
217
+ hidden_states , residual = self .input_layernorm (
218
+ hidden_states , residual )
219
+
220
+ if _metadata_for_padding and _metadata_for_padding .not_dummy_and_is_prefill :
221
+ hidden_states = _metadata_for_padding .allgather_unpadding_aligned (
222
+ hidden_states )
223
+
224
+ hidden_states = self .self_attn (
225
+ positions = positions ,
226
+ hidden_states = hidden_states ,
227
+ )
228
+
229
+ if _metadata_for_padding and _metadata_for_padding .not_dummy_and_is_prefill :
230
+ hidden_states = _metadata_for_padding .padding_aligned_reduce_scatter (
231
+ hidden_states )
232
+
233
+ # Fully Connected
234
+ hidden_states , residual = self .post_attention_layernorm (
235
+ hidden_states , residual )
236
+
237
+ if not self .use_aclgraph :
238
+ hidden_states = self .mlp (
239
+ hidden_states , _metadata_for_padding = _metadata_for_padding )
240
+ else :
241
+ hidden_states = self .mlp (hidden_states )
242
+
243
+ return hidden_states , residual
244
+
185
245
186
246
@support_torch_compile
187
247
class CustomQwen3MoeModel (Qwen3MoeModel ):
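
In the decoder-layer hunk above, `o_proj.reduce_results` is switched off whenever the sequence-parallel prefill path is active: the all-reduce that the row-parallel output projection would normally perform is replaced by `padding_aligned_reduce_scatter` after attention plus `allgather_unpadding_aligned` before the next attention, since an all-reduce is algebraically a reduce-scatter followed by an all-gather. The single-process toy snippet below is not part of the change; it only illustrates that equivalence on random data.

# Toy check of all-reduce == reduce-scatter + all-gather; illustrative only.
import torch

tp_size, tokens, hidden = 4, 8, 16
partial = torch.randn(tp_size, tokens, hidden)   # per-rank partial o_proj outputs

# Path 1: reduce_results=True, i.e. an all-reduce inside o_proj.
all_reduced = partial.sum(dim=0)

# Path 2: SP prefill, i.e. each "rank" keeps its summed token shard
# (reduce-scatter), then the shards are gathered before the next attention.
shard = tokens // tp_size
scattered = [partial[:, r * shard:(r + 1) * shard].sum(dim=0)
             for r in range(tp_size)]
gathered = torch.cat(scattered, dim=0)

assert torch.allclose(all_reduced, gathered)
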
@@ -216,6 +276,45 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))

+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+                _metadata_for_padding=_metadata_for_padding)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
+                hidden_states)
+
+        return hidden_states
+

 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -253,6 +352,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []

@@ -273,3 +373,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_moe_layers = len(self.moe_layers)
         self.num_expert_groups = 1
         self.num_shared_experts = 0
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        _metadata_for_padding = init_metadata_for_sp(
+            input_ids, self.enable_sequence_parallelism)
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds, _metadata_for_padding)
+        return hidden_states
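
The new top-level `forward` on `CustomQwen3MoeForCausalLM` builds the sequence-parallel metadata once per step via `init_metadata_for_sp(input_ids, self.enable_sequence_parallelism)` and threads it through the model. Assuming the alignment target is the tensor-parallel world size (an assumption; the actual rule lives in `vllm_ascend/ops/sequence_parallel.py` and is not shown in this diff), the padding arithmetic reduces to rounding the prefill token count up to the next multiple:

# Padding arithmetic only -- an illustrative assumption about init_metadata_for_sp.
# 10 prefill tokens with tp_size=4 are padded to 12, so each rank owns a
# 3-token shard during the sequence-parallel sections of every layer.
def sp_padding(num_tokens: int, tp_size: int) -> tuple[int, int]:
    pad_size = (-num_tokens) % tp_size          # rows appended to reach the next multiple
    shard = (num_tokens + pad_size) // tp_size  # tokens owned by each rank
    return pad_size, shard


assert sp_padding(10, 4) == (2, 3)
assert sp_padding(8, 4) == (0, 2)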