Skip to content

Commit 0edc0cd

Browse files
authored
[Bugfix] Fix CI moe kernel failure (#22556)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 7920e9b commit 0edc0cd

File tree

1 file changed

+142
-64
lines changed

1 file changed

+142
-64
lines changed

tests/kernels/moe/test_gpt_oss_triton_kernels.py

Lines changed: 142 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55
import pytest
66
import torch
77
import torch.nn.functional as F
8+
9+
from vllm.utils import has_triton_kernels
10+
11+
if not has_triton_kernels():
12+
pytest.skip(
13+
"triton_kernels not found, skipping all related tests",
14+
allow_module_level=True,
15+
)
16+
817
import triton_kernels.swiglu
918
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
1019
from triton_kernels.numerics import InFlexData
@@ -65,7 +74,7 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
6574
dtype_dict = {
6675
"bf16": torch.bfloat16,
6776
"fp8_e4m3": torch.float8_e4m3fn,
68-
"fp8_e5m2": torch.float8_e5m2
77+
"fp8_e5m2": torch.float8_e5m2,
6978
}
7079

7180
x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
@@ -97,12 +106,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
97106

98107
x_pad = w1_bottom_pad
99108

100-
w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
101-
mode="constant",
102-
value=0)
103-
w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
104-
mode="constant",
105-
value=0)
109+
w1_tri = F.pad(
110+
w1_tri,
111+
(0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
112+
mode="constant",
113+
value=0,
114+
)
115+
w2_tri = F.pad(
116+
w2_tri,
117+
(0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
118+
mode="constant",
119+
value=0,
120+
)
106121

107122
w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
108123
mode="constant",
@@ -127,13 +142,19 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
127142

128143
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
129144
**w_layout_opts)
130-
w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
131-
w_scale_layout, **w_scale_layout_opts)
145+
w1_scale_tri = convert_layout(
146+
wrap_torch_tensor(w1_scale_tri),
147+
w_scale_layout,
148+
**w_scale_layout_opts,
149+
)
132150

133151
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
134152
**w_layout_opts)
135-
w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
136-
w_scale_layout, **w_scale_layout_opts)
153+
w2_scale_tri = convert_layout(
154+
wrap_torch_tensor(w2_scale_tri),
155+
w_scale_layout,
156+
**w_scale_layout_opts,
157+
)
137158

138159
pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
139160
flex_ctx=FlexCtx(rhs_data=InFlexData()))
@@ -149,8 +170,22 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
149170
w1 = w1.transpose(-1, -2).contiguous()
150171
w2 = w2.transpose(-1, -2).contiguous()
151172

152-
return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
153-
exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)
173+
return (
174+
x,
175+
w1,
176+
w1_bias,
177+
w2,
178+
w2_bias,
179+
exp_data,
180+
x_tri,
181+
w1_tri,
182+
w2_tri,
183+
exp_data_tri,
184+
w1_bias_tri,
185+
w2_bias_tri,
186+
pc1,
187+
pc2,
188+
)
154189

155190

156191
@dataclass
@@ -184,13 +219,14 @@ def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
184219

185220

186221
def oai_moe_forward(
187-
hidden_states: torch.Tensor, # (M, K)
188-
w1: torch.Tensor, # (E, 2N)
189-
w1_bias: torch.Tensor, # (E, 2N, K)
190-
w2: torch.Tensor, # (E, K, N)
191-
w2_bias: torch.Tensor, # (E, N)
192-
gating_output: torch.Tensor, # (M, E)
193-
topk: int):
222+
hidden_states: torch.Tensor, # (M, K)
223+
w1: torch.Tensor, # (E, 2N)
224+
w1_bias: torch.Tensor, # (E, 2N, K)
225+
w2: torch.Tensor, # (E, K, N)
226+
w2_bias: torch.Tensor, # (E, N)
227+
gating_output: torch.Tensor, # (M, E)
228+
topk: int,
229+
):
194230
# model.py 309:330, assuming gating and norm
195231
t = hidden_states
196232
experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
@@ -240,10 +276,22 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
240276
N = ModelConfig.intermediate_size // tp
241277
topk = ModelConfig.experts_per_token
242278

243-
x, w1, w1_bias, w2, w2_bias, exp_data, \
244-
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri,\
245-
w2_bias_tri, pc1, pc2 = init_compute_data(
246-
M, K, N, E, a_dtype, w_dtype, num_warps=8)
279+
(
280+
x,
281+
w1,
282+
w1_bias,
283+
w2,
284+
w2_bias,
285+
exp_data,
286+
x_tri,
287+
w1_tri,
288+
w2_tri,
289+
exp_data_tri,
290+
w1_bias_tri,
291+
w2_bias_tri,
292+
pc1,
293+
pc2,
294+
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
247295

248296
out_triton_monolithic = triton_kernel_moe_forward(
249297
hidden_states=x_tri,
@@ -255,33 +303,46 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
255303
w1_bias=w1_bias_tri,
256304
w2_bias=w2_bias_tri,
257305
w1_precision=pc1,
258-
w2_precision=pc2)
306+
w2_precision=pc2,
307+
)
259308
out_triton_monolithic = out_triton_monolithic[..., :K]
260309

261-
out_ref = oai_moe_forward(hidden_states=x,
262-
w1=w1,
263-
w1_bias=w1_bias,
264-
w2=w2,
265-
w2_bias=w2_bias,
266-
gating_output=exp_data,
267-
topk=topk)
310+
out_ref = oai_moe_forward(
311+
hidden_states=x,
312+
w1=w1,
313+
w1_bias=w1_bias,
314+
w2=w2,
315+
w2_bias=w2_bias,
316+
gating_output=exp_data,
317+
topk=topk,
318+
)
268319
assert_close(ref=out_ref,
269320
tri=out_triton_monolithic,
270321
maxtol=0.025,
271322
rmstol=0.005)
272323

273324

274-
def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
275-
topk: int, renormalize: bool, w1_bias: torch.Tensor,
276-
w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
277-
w2_precision: PrecisionConfig) -> torch.Tensor:
325+
def batched_moe(
326+
a: torch.Tensor,
327+
w1,
328+
w2,
329+
gating_output: torch.Tensor,
330+
topk: int,
331+
renormalize: bool,
332+
w1_bias: torch.Tensor,
333+
w2_bias: torch.Tensor,
334+
w1_precision: PrecisionConfig,
335+
w2_precision: PrecisionConfig,
336+
) -> torch.Tensor:
278337
max_num_tokens = round_up(a.shape[0], 64)
279338

280339
fused_experts = FusedMoEModularKernel(
281-
BatchedPrepareAndFinalize(max_num_tokens,
282-
num_dispatchers=1,
283-
num_local_experts=w1.shape[0],
284-
rank=0),
340+
BatchedPrepareAndFinalize(
341+
max_num_tokens,
342+
num_dispatchers=1,
343+
num_local_experts=w1.shape[0],
344+
rank=0,
345+
),
285346
BatchedOAITritonExperts(
286347
None,
287348
max_num_tokens=max_num_tokens,
@@ -327,30 +388,46 @@ def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
327388
N = ModelConfig.intermediate_size
328389
topk = ModelConfig.experts_per_token
329390

330-
x, w1, w1_bias, w2, w2_bias, exp_data, \
331-
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
332-
w2_bias_tri, pc1, pc2 = init_compute_data(
333-
M, K, N, E, a_dtype, w_dtype, num_warps=4)
334-
335-
out_tri = batched_moe(a=x_tri,
336-
w1=w1_tri,
337-
w2=w2_tri,
338-
gating_output=exp_data_tri,
339-
topk=topk,
340-
renormalize=True,
341-
w1_bias=w1_bias_tri,
342-
w2_bias=w2_bias_tri,
343-
w1_precision=pc1,
344-
w2_precision=pc2)
391+
(
392+
x,
393+
w1,
394+
w1_bias,
395+
w2,
396+
w2_bias,
397+
exp_data,
398+
x_tri,
399+
w1_tri,
400+
w2_tri,
401+
exp_data_tri,
402+
w1_bias_tri,
403+
w2_bias_tri,
404+
pc1,
405+
pc2,
406+
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4)
407+
408+
out_tri = batched_moe(
409+
a=x_tri,
410+
w1=w1_tri,
411+
w2=w2_tri,
412+
gating_output=exp_data_tri,
413+
topk=topk,
414+
renormalize=True,
415+
w1_bias=w1_bias_tri,
416+
w2_bias=w2_bias_tri,
417+
w1_precision=pc1,
418+
w2_precision=pc2,
419+
)
345420
out_tri = out_tri[..., :K]
346421

347-
out_ref = oai_moe_forward(hidden_states=x,
348-
w1=w1,
349-
w1_bias=w1_bias,
350-
w2=w2,
351-
w2_bias=w2_bias,
352-
gating_output=exp_data,
353-
topk=topk)
422+
out_ref = oai_moe_forward(
423+
hidden_states=x,
424+
w1=w1,
425+
w1_bias=w1_bias,
426+
w2=w2,
427+
w2_bias=w2_bias,
428+
gating_output=exp_data,
429+
topk=topk,
430+
)
354431
assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
355432

356433

@@ -370,6 +447,7 @@ def test_unit_shuffle():
370447
out = triton_kernels.swiglu.swiglu_torch(
371448
out,
372449
alpha=1.702,
373-
precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))
450+
precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0),
451+
)
374452

375-
assert_close(ref=out_ref, tri=out)
453+
assert_close(ref=out_ref, tri=out)

0 commit comments

Comments
 (0)