5
5
import pytest
6
6
import torch
7
7
import torch .nn .functional as F
8
+
9
+ from vllm .utils import has_triton_kernels
10
+
11
+ if not has_triton_kernels ():
12
+ pytest .skip (
13
+ "triton_kernels not found, skipping all related tests" ,
14
+ allow_module_level = True ,
15
+ )
16
+
8
17
import triton_kernels .swiglu
9
18
from triton_kernels .matmul_ogs import FlexCtx , PrecisionConfig
10
19
from triton_kernels .numerics import InFlexData
@@ -65,7 +74,7 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
65
74
dtype_dict = {
66
75
"bf16" : torch .bfloat16 ,
67
76
"fp8_e4m3" : torch .float8_e4m3fn ,
68
- "fp8_e5m2" : torch .float8_e5m2
77
+ "fp8_e5m2" : torch .float8_e5m2 ,
69
78
}
70
79
71
80
x = x .to (dtype_dict [a_dtype ]).to (torch .bfloat16 )
@@ -97,12 +106,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
97
106
98
107
x_pad = w1_bottom_pad
99
108
100
- w1_tri = F .pad (w1_tri , (0 , w1_right_pad , 0 , w1_bottom_pad , 0 , 0 ),
101
- mode = "constant" ,
102
- value = 0 )
103
- w2_tri = F .pad (w2_tri , (0 , w2_right_pad , 0 , w2_bottom_pad , 0 , 0 ),
104
- mode = "constant" ,
105
- value = 0 )
109
+ w1_tri = F .pad (
110
+ w1_tri ,
111
+ (0 , w1_right_pad , 0 , w1_bottom_pad , 0 , 0 ),
112
+ mode = "constant" ,
113
+ value = 0 ,
114
+ )
115
+ w2_tri = F .pad (
116
+ w2_tri ,
117
+ (0 , w2_right_pad , 0 , w2_bottom_pad , 0 , 0 ),
118
+ mode = "constant" ,
119
+ value = 0 ,
120
+ )
106
121
107
122
w1_bias_tri = F .pad (w1_bias_tri , (0 , w1_right_pad , 0 , 0 ),
108
123
mode = "constant" ,
@@ -127,13 +142,19 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
127
142
128
143
w1_tri = convert_layout (wrap_torch_tensor (w1_tri , FP4 ), w_layout ,
129
144
** w_layout_opts )
130
- w1_scale_tri = convert_layout (wrap_torch_tensor (w1_scale_tri ),
131
- w_scale_layout , ** w_scale_layout_opts )
145
+ w1_scale_tri = convert_layout (
146
+ wrap_torch_tensor (w1_scale_tri ),
147
+ w_scale_layout ,
148
+ ** w_scale_layout_opts ,
149
+ )
132
150
133
151
w2_tri = convert_layout (wrap_torch_tensor (w2_tri , FP4 ), w_layout ,
134
152
** w_layout_opts )
135
- w2_scale_tri = convert_layout (wrap_torch_tensor (w2_scale_tri ),
136
- w_scale_layout , ** w_scale_layout_opts )
153
+ w2_scale_tri = convert_layout (
154
+ wrap_torch_tensor (w2_scale_tri ),
155
+ w_scale_layout ,
156
+ ** w_scale_layout_opts ,
157
+ )
137
158
138
159
pc1 = PrecisionConfig (weight_scale = w1_scale_tri ,
139
160
flex_ctx = FlexCtx (rhs_data = InFlexData ()))
@@ -149,8 +170,22 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
149
170
w1 = w1 .transpose (- 1 , - 2 ).contiguous ()
150
171
w2 = w2 .transpose (- 1 , - 2 ).contiguous ()
151
172
152
- return (x , w1 , w1_bias , w2 , w2_bias , exp_data , x_tri , w1_tri , w2_tri ,
153
- exp_data_tri , w1_bias_tri , w2_bias_tri , pc1 , pc2 )
173
+ return (
174
+ x ,
175
+ w1 ,
176
+ w1_bias ,
177
+ w2 ,
178
+ w2_bias ,
179
+ exp_data ,
180
+ x_tri ,
181
+ w1_tri ,
182
+ w2_tri ,
183
+ exp_data_tri ,
184
+ w1_bias_tri ,
185
+ w2_bias_tri ,
186
+ pc1 ,
187
+ pc2 ,
188
+ )
154
189
155
190
156
191
@dataclass
@@ -184,13 +219,14 @@ def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
184
219
185
220
186
221
def oai_moe_forward (
187
- hidden_states : torch .Tensor , # (M, K)
188
- w1 : torch .Tensor , # (E, 2N)
189
- w1_bias : torch .Tensor , # (E, 2N, K)
190
- w2 : torch .Tensor , # (E, K, N)
191
- w2_bias : torch .Tensor , # (E, N)
192
- gating_output : torch .Tensor , # (M, E)
193
- topk : int ):
222
+ hidden_states : torch .Tensor , # (M, K)
223
+ w1 : torch .Tensor , # (E, 2N)
224
+ w1_bias : torch .Tensor , # (E, 2N, K)
225
+ w2 : torch .Tensor , # (E, K, N)
226
+ w2_bias : torch .Tensor , # (E, N)
227
+ gating_output : torch .Tensor , # (M, E)
228
+ topk : int ,
229
+ ):
194
230
# model.py 309:330, assuming gating and norm
195
231
t = hidden_states
196
232
experts = torch .topk (gating_output , k = topk , dim = - 1 , sorted = True )
@@ -240,10 +276,22 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
240
276
N = ModelConfig .intermediate_size // tp
241
277
topk = ModelConfig .experts_per_token
242
278
243
- x , w1 , w1_bias , w2 , w2_bias , exp_data , \
244
- x_tri , w1_tri , w2_tri , exp_data_tri , w1_bias_tri ,\
245
- w2_bias_tri , pc1 , pc2 = init_compute_data (
246
- M , K , N , E , a_dtype , w_dtype , num_warps = 8 )
279
+ (
280
+ x ,
281
+ w1 ,
282
+ w1_bias ,
283
+ w2 ,
284
+ w2_bias ,
285
+ exp_data ,
286
+ x_tri ,
287
+ w1_tri ,
288
+ w2_tri ,
289
+ exp_data_tri ,
290
+ w1_bias_tri ,
291
+ w2_bias_tri ,
292
+ pc1 ,
293
+ pc2 ,
294
+ ) = init_compute_data (M , K , N , E , a_dtype , w_dtype , num_warps = 8 )
247
295
248
296
out_triton_monolithic = triton_kernel_moe_forward (
249
297
hidden_states = x_tri ,
@@ -255,33 +303,46 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
255
303
w1_bias = w1_bias_tri ,
256
304
w2_bias = w2_bias_tri ,
257
305
w1_precision = pc1 ,
258
- w2_precision = pc2 )
306
+ w2_precision = pc2 ,
307
+ )
259
308
out_triton_monolithic = out_triton_monolithic [..., :K ]
260
309
261
- out_ref = oai_moe_forward (hidden_states = x ,
262
- w1 = w1 ,
263
- w1_bias = w1_bias ,
264
- w2 = w2 ,
265
- w2_bias = w2_bias ,
266
- gating_output = exp_data ,
267
- topk = topk )
310
+ out_ref = oai_moe_forward (
311
+ hidden_states = x ,
312
+ w1 = w1 ,
313
+ w1_bias = w1_bias ,
314
+ w2 = w2 ,
315
+ w2_bias = w2_bias ,
316
+ gating_output = exp_data ,
317
+ topk = topk ,
318
+ )
268
319
assert_close (ref = out_ref ,
269
320
tri = out_triton_monolithic ,
270
321
maxtol = 0.025 ,
271
322
rmstol = 0.005 )
272
323
273
324
274
- def batched_moe (a : torch .Tensor , w1 , w2 , gating_output : torch .Tensor ,
275
- topk : int , renormalize : bool , w1_bias : torch .Tensor ,
276
- w2_bias : torch .Tensor , w1_precision : PrecisionConfig ,
277
- w2_precision : PrecisionConfig ) -> torch .Tensor :
325
+ def batched_moe (
326
+ a : torch .Tensor ,
327
+ w1 ,
328
+ w2 ,
329
+ gating_output : torch .Tensor ,
330
+ topk : int ,
331
+ renormalize : bool ,
332
+ w1_bias : torch .Tensor ,
333
+ w2_bias : torch .Tensor ,
334
+ w1_precision : PrecisionConfig ,
335
+ w2_precision : PrecisionConfig ,
336
+ ) -> torch .Tensor :
278
337
max_num_tokens = round_up (a .shape [0 ], 64 )
279
338
280
339
fused_experts = FusedMoEModularKernel (
281
- BatchedPrepareAndFinalize (max_num_tokens ,
282
- num_dispatchers = 1 ,
283
- num_local_experts = w1 .shape [0 ],
284
- rank = 0 ),
340
+ BatchedPrepareAndFinalize (
341
+ max_num_tokens ,
342
+ num_dispatchers = 1 ,
343
+ num_local_experts = w1 .shape [0 ],
344
+ rank = 0 ,
345
+ ),
285
346
BatchedOAITritonExperts (
286
347
None ,
287
348
max_num_tokens = max_num_tokens ,
@@ -327,30 +388,46 @@ def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
327
388
N = ModelConfig .intermediate_size
328
389
topk = ModelConfig .experts_per_token
329
390
330
- x , w1 , w1_bias , w2 , w2_bias , exp_data , \
331
- x_tri , w1_tri , w2_tri , exp_data_tri , w1_bias_tri , \
332
- w2_bias_tri , pc1 , pc2 = init_compute_data (
333
- M , K , N , E , a_dtype , w_dtype , num_warps = 4 )
334
-
335
- out_tri = batched_moe (a = x_tri ,
336
- w1 = w1_tri ,
337
- w2 = w2_tri ,
338
- gating_output = exp_data_tri ,
339
- topk = topk ,
340
- renormalize = True ,
341
- w1_bias = w1_bias_tri ,
342
- w2_bias = w2_bias_tri ,
343
- w1_precision = pc1 ,
344
- w2_precision = pc2 )
391
+ (
392
+ x ,
393
+ w1 ,
394
+ w1_bias ,
395
+ w2 ,
396
+ w2_bias ,
397
+ exp_data ,
398
+ x_tri ,
399
+ w1_tri ,
400
+ w2_tri ,
401
+ exp_data_tri ,
402
+ w1_bias_tri ,
403
+ w2_bias_tri ,
404
+ pc1 ,
405
+ pc2 ,
406
+ ) = init_compute_data (M , K , N , E , a_dtype , w_dtype , num_warps = 4 )
407
+
408
+ out_tri = batched_moe (
409
+ a = x_tri ,
410
+ w1 = w1_tri ,
411
+ w2 = w2_tri ,
412
+ gating_output = exp_data_tri ,
413
+ topk = topk ,
414
+ renormalize = True ,
415
+ w1_bias = w1_bias_tri ,
416
+ w2_bias = w2_bias_tri ,
417
+ w1_precision = pc1 ,
418
+ w2_precision = pc2 ,
419
+ )
345
420
out_tri = out_tri [..., :K ]
346
421
347
- out_ref = oai_moe_forward (hidden_states = x ,
348
- w1 = w1 ,
349
- w1_bias = w1_bias ,
350
- w2 = w2 ,
351
- w2_bias = w2_bias ,
352
- gating_output = exp_data ,
353
- topk = topk )
422
+ out_ref = oai_moe_forward (
423
+ hidden_states = x ,
424
+ w1 = w1 ,
425
+ w1_bias = w1_bias ,
426
+ w2 = w2 ,
427
+ w2_bias = w2_bias ,
428
+ gating_output = exp_data ,
429
+ topk = topk ,
430
+ )
354
431
assert_close (ref = out_ref , tri = out_tri , maxtol = 0.025 , rmstol = 0.005 )
355
432
356
433
@@ -370,6 +447,7 @@ def test_unit_shuffle():
370
447
out = triton_kernels .swiglu .swiglu_torch (
371
448
out ,
372
449
alpha = 1.702 ,
373
- precision_config = triton_kernels .swiglu .PrecisionConfig (limit = 1.0 ))
450
+ precision_config = triton_kernels .swiglu .PrecisionConfig (limit = 1.0 ),
451
+ )
374
452
375
- assert_close (ref = out_ref , tri = out )
453
+ assert_close (ref = out_ref , tri = out )
0 commit comments