# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, fields

import pytest
import torch
import torch.nn.functional as F
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
                                                  upcast_from_mxfp)
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    BatchedOAITritonExperts, triton_kernel_moe_forward)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.utils import round_up


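# Undo the column interleaving applied by shuffle_weight: even columns are
# taken as the first half of the last dim and odd columns as the second half,
# so the reference weights line up with the unshuffled layout again.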
def deshuffle(w: torch.Tensor):
    first = w[..., ::2]
    second = w[..., 1::2]

    deshuffled = torch.concat((first, second), dim=-1)
    return deshuffled


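# Build matching inputs for the reference implementation and the triton
# kernels: bf16 reference tensors plus padded / shuffled / quantized copies
# (the *_tri tensors) together with their PrecisionConfigs.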
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
    randbits = [torch.randperm(E) for _ in range(M)]
    x_list = [
        (-1)**i *
        ((16384 +
          ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
        for i, bits in enumerate(randbits)
    ]
    exp_data = torch.stack(x_list).to(
        device="cuda")  # simulating gate_output (M, E)

    # create input tensors
    x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
    w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
    w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")

    w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
    w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")

    exp_data_tri = exp_data.clone()
    x_tri = x.clone()
    w1_tri = w1.clone()
    w2_tri = w2.clone()

    w1_bias_tri = w1_bias.clone()
    w2_bias_tri = w2_bias.clone()
    w1_bias_tri = w1_bias_tri.to(torch.float32)
    w2_bias_tri = w2_bias_tri.to(torch.float32)

    dtype_dict = {
        "bf16": torch.bfloat16,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2
    }

    x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
    if w_dtype != "mx4":
        # simulate quantization support on the reference impl
        w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
        w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)

    # the triton moe kernel uses a transposed shape for the matmul
    w1_tri = w1_tri.transpose(-2, -1)
    w2_tri = w2_tri.transpose(-2, -1)

    # shuffle weights
    w1_tri = shuffle_weight(w1_tri)
    w1_bias_tri = shuffle_weight(w1_bias_tri)

    # quantize the triton weights
    x_tri = x.to(dtype_dict[a_dtype])
    if w_dtype != "mx4":
        pytest.skip("NYI")
    else:  # quantize to mx4
        # Careful with the padding here: the activation padding needs to be
        # a multiple of 64; the actual engine is not implemented.
        w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
        w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]

        w2_bottom_pad = w1_right_pad // 2
        w2_right_pad = w1_bottom_pad

        x_pad = w1_bottom_pad

        w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
                       mode="constant",
                       value=0)
        w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
                       mode="constant",
                       value=0)

        w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
                            mode="constant",
                            value=0)
        w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
                            mode="constant",
                            value=0)

        x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)

        w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
            mx_axis=1)
        w_scale_layout, w_scale_layout_opts = (
            layout.make_default_matmul_mxfp4_w_scale_layout(
                mx_axis=1, num_warps=num_warps))

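        # Round-trip the reference weights through mxfp4 (downcast, then
        # upcast back to bf16) so the reference path carries the same
        # quantization error as the triton path.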
        w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
        w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)

        w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
        w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)

        w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
                                **w_layout_opts)
        w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
                                      w_scale_layout, **w_scale_layout_opts)

        w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
                                **w_layout_opts)
        w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
                                      w_scale_layout, **w_scale_layout_opts)

        pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
                              flex_ctx=FlexCtx(rhs_data=InFlexData()))
        pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
                              flex_ctx=FlexCtx(rhs_data=InFlexData()))

        # truncate so the rest can run properly
        w1 = w1[..., :K, :2 * N]
        w2 = w2[..., :N, :K]

        w1 = deshuffle(w1)

        w1 = w1.transpose(-1, -2).contiguous()
        w2 = w2.transpose(-1, -2).contiguous()

    return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
            exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)


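# Model dimensions used by the tests below; the defaults mirror the gpt-oss
# configuration (128 experts, 4 experts per token, hidden size 2880).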
@dataclass
class ModelConfig:
    num_hidden_layers: int = 36
    num_experts: int = 128
    experts_per_token: int = 4
    vocab_size: int = 201088
    hidden_size: int = 2880
    intermediate_size: int = 2880
    head_dim: int = 64
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    sliding_window: int = 128
    initial_context_length: int = 4096
    rope_theta: float = 150000.0
    rope_scaling_factor: float = 32.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0


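# Reference clamped SwiGLU:
#   out = x_glu * sigmoid(alpha * x_glu) * (x_linear + 1)
# with x_glu clamped from above and x_linear clamped to [-limit, limit].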
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
    # Note we add an extra bias of 1 to the linear layer
    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    if limit is not None:
        x_linear = x_linear.clamp(min=-limit, max=limit)
    return out_glu * (x_linear + 1)


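# Reference MoE forward written with dense gather + einsum: top-k gating with
# softmax over the selected logits, both expert MLPs per selected expert, then
# the routing-weighted sum over experts.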
def oai_moe_forward(
        hidden_states: torch.Tensor,  # (M, K)
        w1: torch.Tensor,  # (E, 2N, K)
        w1_bias: torch.Tensor,  # (E, 2N)
        w2: torch.Tensor,  # (E, K, N)
        w2_bias: torch.Tensor,  # (E, K)
        gating_output: torch.Tensor,  # (M, E)
        topk: int):
    # model.py 309:330, assuming gating and norm
    t = hidden_states
    experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
    expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
    expert_indices = experts.indices

    # MLP #1
    mlp1_weight = w1[expert_indices, ...]
    mlp1_bias = w1_bias[expert_indices, ...]
    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
    t = swiglu(t, limit=7)

    # MLP #2
    mlp2_weight = w2[expert_indices, ...]
    mlp2_bias = w2_bias[expert_indices, ...]
    t = torch.einsum("beck,bek->bec", mlp2_weight, t)
    t += mlp2_bias

    # Weighted sum of experts
    t = torch.einsum("bec,be->bc", t, expert_weights)

    return t


@dataclass
class Case:
    a_dtype: str
    w_dtype: str


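# Compare the monolithic triton MoE kernel against the einsum reference for
# several tensor-parallel shard sizes (tp shards the intermediate size).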
@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            # Case(a_dtype="bf16", w_dtype="bf16"),
            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
            Case(a_dtype="bf16", w_dtype="mx4")
        ]
    ],
)
@pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp):
    M = num_token
    E = ModelConfig.num_experts
    K = ModelConfig.hidden_size
    N = ModelConfig.intermediate_size // tp
    topk = ModelConfig.experts_per_token

    x, w1, w1_bias, w2, w2_bias, exp_data, \
        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
        w2_bias_tri, pc1, pc2 = init_compute_data(
            M, K, N, E, a_dtype, w_dtype, num_warps=8)

    out_triton_monolithic = triton_kernel_moe_forward(
        hidden_states=x_tri,
        w1=w1_tri,
        w2=w2_tri,
        gating_output=exp_data_tri,
        topk=topk,
        renormalize=True,
        w1_bias=w1_bias_tri,
        w2_bias=w2_bias_tri,
        w1_precision=pc1,
        w2_precision=pc2)
    out_triton_monolithic = out_triton_monolithic[..., :K]

    out_ref = oai_moe_forward(hidden_states=x,
                              w1=w1,
                              w1_bias=w1_bias,
                              w2=w2,
                              w2_bias=w2_bias,
                              gating_output=exp_data,
                              topk=topk)
    assert_close(ref=out_ref,
                 tri=out_triton_monolithic,
                 maxtol=0.025,
                 rmstol=0.005)


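# Build a FusedMoEModularKernel from the batched prepare/finalize stage and
# the batched OAI triton experts, then run it on the top-k routed tokens.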
def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
                topk: int, renormalize: bool, w1_bias: torch.Tensor,
                w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
                w2_precision: PrecisionConfig) -> torch.Tensor:
    max_num_tokens = round_up(a.shape[0], 64)

    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(max_num_tokens,
                                  num_dispatchers=1,
                                  num_local_experts=w1.shape[0],
                                  rank=0),
        BatchedOAITritonExperts(
            None,
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            w1_precision=w1_precision,
            w2_precision=w2_precision,
        ),
    )

    extra_expert_args = {
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
    }

    topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)

    return fused_experts(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        extra_expert_args=extra_expert_args,
    )


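# Same equivalence check as above, but through the batched modular kernel and
# with the expert count sharded as under expert parallelism (ep).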
@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            # Case(a_dtype="bf16", w_dtype="bf16"),
            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
            Case(a_dtype="bf16", w_dtype="mx4")
        ]
    ],
)
@pytest.mark.parametrize("num_token", [64])
@pytest.mark.parametrize("ep", [1, 2, 4, 8])
def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
    M = num_token
    E = ModelConfig.num_experts // ep
    K = ModelConfig.hidden_size
    N = ModelConfig.intermediate_size
    topk = ModelConfig.experts_per_token

    x, w1, w1_bias, w2, w2_bias, exp_data, \
        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
        w2_bias_tri, pc1, pc2 = init_compute_data(
            M, K, N, E, a_dtype, w_dtype, num_warps=4)

    out_tri = batched_moe(a=x_tri,
                          w1=w1_tri,
                          w2=w2_tri,
                          gating_output=exp_data_tri,
                          topk=topk,
                          renormalize=True,
                          w1_bias=w1_bias_tri,
                          w2_bias=w2_bias_tri,
                          w1_precision=pc1,
                          w2_precision=pc2)
    out_tri = out_tri[..., :K]

    out_ref = oai_moe_forward(hidden_states=x,
                              w1=w1,
                              w1_bias=w1_bias,
                              w2=w2,
                              w2_bias=w2_bias,
                              gating_output=exp_data,
                              topk=topk)
    assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)


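# Sanity-check shuffle_weight on a single matmul: the shuffled weight followed
# by the triton swiglu must match the unshuffled weight followed by the
# reference swiglu.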
def test_unit_shuffle():
    N = ModelConfig.intermediate_size
    K = ModelConfig.hidden_size
    m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")

    x = torch.randn(K, dtype=torch.bfloat16, device="cuda")

    m_shuffled = shuffle_weight(m)

    out_ref = x @ m
    out_ref = swiglu(out_ref, limit=1.0)

    out = x @ m_shuffled
    out = triton_kernels.swiglu.swiglu_torch(
        out,
        alpha=1.702,
        precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))

    assert_close(ref=out_ref, tri=out)