@@ -17,25 +17,102 @@
 # This file is a part of the vllm-ascend project.
 from typing import Optional
 
+import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, CompilationLevel, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
+                                             get_tp_group)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel)
+                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 
-from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
-from vllm_ascend.platform import VllmConfig
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
+
+class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        nn.Module.__init__(self)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+        self,
+        hidden_states,
+        attn_metadata=None,
+    ):
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        enable_force_load_balance = get_forward_context().in_profile_run
+        is_prefill = get_forward_context().with_prefill
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
 
 
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
@@ -45,6 +122,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: Optional[VllmConfig] = None,
         prefix: str = "",
     ) -> None:
 
@@ -73,12 +151,22 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
+        use_aclgraph = (vllm_config is not None
+                        and vllm_config.compilation_config.level
+                        == CompilationLevel.PIECEWISE
+                        and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
+            if not use_aclgraph:
+                # FIXME: custom sparse moe block doesn't work with aclgraph.
+                self.mlp = CustomSparseMoeBlock(config=config,
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.mlp")
+            else:
+                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -115,6 +203,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=prefix),
             prefix=f"{prefix}.layers",
         )
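
To make the new gating easier to follow, below is a minimal, standalone sketch of the condition that decides between `CustomSparseMoeBlock` and the upstream `Qwen3MoeSparseMoeBlock` in `CustomQwen3MoeDecoderLayer.__init__`. The `_VllmConfig`, `_CompilationConfig`, and `_ModelConfig` dataclasses are hypothetical stand-ins that model only the fields the check reads; this illustrates the selection logic under those assumptions rather than the real vllm config API.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class CompilationLevel(Enum):
    # Stand-in for vllm.config.CompilationLevel; member values are illustrative.
    NO_COMPILATION = "no_compilation"
    PIECEWISE = "piecewise"


@dataclass
class _CompilationConfig:  # stand-in: only the field used by the check
    level: CompilationLevel


@dataclass
class _ModelConfig:  # stand-in: only the field used by the check
    enforce_eager: bool


@dataclass
class _VllmConfig:  # stand-in for vllm.config.VllmConfig
    compilation_config: _CompilationConfig
    model_config: _ModelConfig


def selects_custom_moe_block(vllm_config: Optional[_VllmConfig]) -> bool:
    """Mirror the check added in CustomQwen3MoeDecoderLayer.__init__."""
    use_aclgraph = (vllm_config is not None
                    and vllm_config.compilation_config.level
                    == CompilationLevel.PIECEWISE
                    and not vllm_config.model_config.enforce_eager)
    # The custom MoE block is only used when aclgraph is not in play.
    return not use_aclgraph


# Piecewise compilation without enforce_eager -> aclgraph path -> upstream block.
assert not selects_custom_moe_block(
    _VllmConfig(_CompilationConfig(CompilationLevel.PIECEWISE),
                _ModelConfig(enforce_eager=False)))

# enforce_eager disables aclgraph -> CustomSparseMoeBlock is selected.
assert selects_custom_moe_block(
    _VllmConfig(_CompilationConfig(CompilationLevel.PIECEWISE),
                _ModelConfig(enforce_eager=True)))

# No vllm_config passed -> also falls back to the custom block.
assert selects_custom_moe_block(None)
```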