
Commit e789cad

[gpt-oss] triton kernel mxfp4 (#22421)
Signed-off-by: <[email protected]>
Signed-off-by: Yongye Zhu <[email protected]>
1 parent e5ebeeb commit e789cad

File tree

8 files changed, +755 -9 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@
 # vllm-flash-attn built from source
 vllm/vllm_flash_attn/*

+# triton jit
+.triton
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
Lines changed: 375 additions & 0 deletions
@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, fields

import pytest
import torch
import torch.nn.functional as F
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
                                                  upcast_from_mxfp)
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    BatchedOAITritonExperts, triton_kernel_moe_forward)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.utils import round_up

def deshuffle(w: torch.Tensor):
    first = w[..., ::2]
    second = w[..., 1::2]

    deshuffled = torch.concat((first, second), dim=-1)
    return deshuffled

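# A minimal illustrative sketch (hypothetical helper, assuming shuffle_weight
# interleaves the two halves of the last dim, [g0, g1, ..., u0, u1, ...] ->
# [g0, u0, g1, u1, ...]): deshuffle is a pure column permutation that recovers
# the original ordering, so a round trip is exact even in bf16.
def _sketch_deshuffle_roundtrip():
    w = torch.randn((4, 8), dtype=torch.bfloat16, device="cuda")
    assert torch.equal(deshuffle(shuffle_weight(w)), w)
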
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
    randbits = [torch.randperm(E) for _ in range(M)]
    x_list = [
        (-1)**i *
        ((16384 +
          ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
        for i, bits in enumerate(randbits)
    ]
    exp_data = torch.stack(x_list).to(
        device="cuda")  # simulating gate_output (M, E)

    # create input tensor
    x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
    w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
    w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")

    w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
    w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")

    exp_data_tri = exp_data.clone()
    x_tri = x.clone()
    w1_tri = w1.clone()
    w2_tri = w2.clone()

    w1_bias_tri = w1_bias.clone()
    w2_bias_tri = w2_bias.clone()
    w1_bias_tri = w1_bias_tri.to(torch.float32)
    w2_bias_tri = w2_bias_tri.to(torch.float32)

    dtype_dict = {
        "bf16": torch.bfloat16,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2
    }

    x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
    if w_dtype != "mx4":
        # simulate quantization support in the reference impl
        w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
        w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)

    # the triton moe kernel uses transposed shapes for the matmul
    w1_tri = w1_tri.transpose(-2, -1)
    w2_tri = w2_tri.transpose(-2, -1)

    # shuffle weights
    w1_tri = shuffle_weight(w1_tri)
    w1_bias_tri = shuffle_weight(w1_bias_tri)

    # quantize the triton weights
    x_tri = x.to(dtype_dict[a_dtype])
    if w_dtype != "mx4":
        pytest.skip("NYI")
    else:  # quantize to mx4
        # careful with the padding here: the activation padding needs to be a
        # multiple of 64; the actual engine is not implemented
        w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
        w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]

        w2_bottom_pad = w1_right_pad // 2
        w2_right_pad = w1_bottom_pad

        x_pad = w1_bottom_pad

        w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
                       mode="constant",
                       value=0)
        w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
                       mode="constant",
                       value=0)

        w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
                            mode="constant",
                            value=0)
        w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
                            mode="constant",
                            value=0)

        x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)

        w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
            mx_axis=1)
        w_scale_layout, w_scale_layout_opts = (
            layout.make_default_matmul_mxfp4_w_scale_layout(
                mx_axis=1, num_warps=num_warps))

        w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
        w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)

        w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
        w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)

        w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
                                **w_layout_opts)
        w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
                                      w_scale_layout, **w_scale_layout_opts)

        w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
                                **w_layout_opts)
        w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
                                      w_scale_layout, **w_scale_layout_opts)

        pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
                              flex_ctx=FlexCtx(rhs_data=InFlexData()))
        pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
                              flex_ctx=FlexCtx(rhs_data=InFlexData()))

        # truncate so the rest can run properly
        w1 = w1[..., :K, :2 * N]
        w2 = w2[..., :N, :K]

        w1 = deshuffle(w1)

        w1 = w1.transpose(-1, -2).contiguous()
        w2 = w2.transpose(-1, -2).contiguous()

    return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
            exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)

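# A summary sketch of the mx4 path above (illustrative, grounded in the code):
# the padded bf16 weights are quantized with downcast_to_mxfp and the bf16
# reference weights are rebuilt with upcast_from_mxfp, so the Triton kernel
# and the reference MoE operate on the same quantization-lossy values, e.g.
#
#     w_q, w_scale = downcast_to_mxfp(w, torch.uint8, axis=1)
#     w_ref = upcast_from_mxfp(w_q, w_scale, torch.bfloat16, axis=1)
#
# The F.pad calls round w1's reduction dim (K) up to a multiple of 64 and its
# output dim (2N) up to a multiple of 128, with w2 and the activations padded
# to match; the reference copies are then truncated back to the logical
# shapes and deshuffled before being returned.
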
@dataclass
class ModelConfig:
    num_hidden_layers: int = 36
    num_experts: int = 128
    experts_per_token: int = 4
    vocab_size: int = 201088
    hidden_size: int = 2880
    intermediate_size: int = 2880
    head_dim: int = 64
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    sliding_window: int = 128
    initial_context_length: int = 4096
    rope_theta: float = 150000.0
    rope_scaling_factor: float = 32.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0

def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
    # Note we add an extra bias of 1 to the linear layer
    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    if limit is not None:
        x_linear = x_linear.clamp(min=-limit, max=limit)
    return out_glu * (x_linear + 1)

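# In short, this reference activation computes
#     out = (x_glu * sigmoid(alpha * x_glu)) * (x_linear + 1)
# with x_glu clamped from above at `limit` and x_linear clamped to
# [-limit, limit]; test_unit_shuffle below checks it against
# triton_kernels.swiglu.swiglu_torch on shuffled weights.
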
def oai_moe_forward(
        hidden_states: torch.Tensor,  # (M, K)
        w1: torch.Tensor,  # (E, 2N, K)
        w1_bias: torch.Tensor,  # (E, 2N)
        w2: torch.Tensor,  # (E, K, N)
        w2_bias: torch.Tensor,  # (E, K)
        gating_output: torch.Tensor,  # (M, E)
        topk: int):
    # model.py 309:330, assuming gating and norm
    t = hidden_states
    experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
    expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
    expert_indices = experts.indices

    # MLP #1
    mlp1_weight = w1[expert_indices, ...]
    mlp1_bias = w1_bias[expert_indices, ...]
    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
    t = swiglu(t, limit=7)

    # MLP #2
    mlp2_weight = w2[expert_indices, ...]
    mlp2_bias = w2_bias[expert_indices, ...]
    t = torch.einsum("beck,bek->bec", mlp2_weight, t)
    t += mlp2_bias

    # Weighted sum of experts
    t = torch.einsum("bec,be->bc", t, expert_weights)

    return t

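# In the einsums above, b indexes tokens, e the top-k expert slots, c output
# channels and k input channels: "beck,bk->bec" applies each selected expert's
# weight matrix to its token, and "bec,be->bc" takes the softmax-weighted sum
# over the selected experts.
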
@dataclass
class Case:
    a_dtype: str
    w_dtype: str


@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            # Case(a_dtype="bf16", w_dtype="bf16"),
            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
            Case(a_dtype="bf16", w_dtype="mx4")
        ]
    ],
)
@pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp):
    M = num_token
    E = ModelConfig.num_experts
    K = ModelConfig.hidden_size
    N = ModelConfig.intermediate_size // tp
    topk = ModelConfig.experts_per_token

    x, w1, w1_bias, w2, w2_bias, exp_data, \
        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
        w2_bias_tri, pc1, pc2 = init_compute_data(
            M, K, N, E, a_dtype, w_dtype, num_warps=8)

    out_triton_monolithic = triton_kernel_moe_forward(
        hidden_states=x_tri,
        w1=w1_tri,
        w2=w2_tri,
        gating_output=exp_data_tri,
        topk=topk,
        renormalize=True,
        w1_bias=w1_bias_tri,
        w2_bias=w2_bias_tri,
        w1_precision=pc1,
        w2_precision=pc2)
    out_triton_monolithic = out_triton_monolithic[..., :K]

    out_ref = oai_moe_forward(hidden_states=x,
                              w1=w1,
                              w1_bias=w1_bias,
                              w2=w2,
                              w2_bias=w2_bias,
                              gating_output=exp_data,
                              topk=topk)
    assert_close(ref=out_ref,
                 tri=out_triton_monolithic,
                 maxtol=0.025,
                 rmstol=0.005)

def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
                topk: int, renormalize: bool, w1_bias: torch.Tensor,
                w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
                w2_precision: PrecisionConfig) -> torch.Tensor:
    max_num_tokens = round_up(a.shape[0], 64)

    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(max_num_tokens,
                                  num_dispatchers=1,
                                  num_local_experts=w1.shape[0],
                                  rank=0),
        BatchedOAITritonExperts(
            None,
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            w1_precision=w1_precision,
            w2_precision=w2_precision,
        ),
    )

    extra_expert_args = {
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
    }

    topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)

    return fused_experts(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        extra_expert_args=extra_expert_args,
    )

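# A rough sketch of the composition above: FusedMoEModularKernel pairs a
# prepare/finalize stage (BatchedPrepareAndFinalize, which groups tokens into
# fixed-capacity per-expert batches and gathers results back) with an experts
# stage (BatchedOAITritonExperts, which runs the triton_kernels matmuls under
# the given PrecisionConfigs); fused_topk supplies the routing weights and ids
# consumed by both stages.
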
@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            # Case(a_dtype="bf16", w_dtype="bf16"),
            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
            Case(a_dtype="bf16", w_dtype="mx4")
        ]
    ],
)
@pytest.mark.parametrize("num_token", [64])
@pytest.mark.parametrize("ep", [1, 2, 4, 8])
def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
    M = num_token
    E = ModelConfig.num_experts // ep
    K = ModelConfig.hidden_size
    N = ModelConfig.intermediate_size
    topk = ModelConfig.experts_per_token

    x, w1, w1_bias, w2, w2_bias, exp_data, \
        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
        w2_bias_tri, pc1, pc2 = init_compute_data(
            M, K, N, E, a_dtype, w_dtype, num_warps=4)

    out_tri = batched_moe(a=x_tri,
                          w1=w1_tri,
                          w2=w2_tri,
                          gating_output=exp_data_tri,
                          topk=topk,
                          renormalize=True,
                          w1_bias=w1_bias_tri,
                          w2_bias=w2_bias_tri,
                          w1_precision=pc1,
                          w2_precision=pc2)
    out_tri = out_tri[..., :K]

    out_ref = oai_moe_forward(hidden_states=x,
                              w1=w1,
                              w1_bias=w1_bias,
                              w2=w2,
                              w2_bias=w2_bias,
                              gating_output=exp_data,
                              topk=topk)
    assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)

def test_unit_shuffle():
    N = ModelConfig.intermediate_size
    K = ModelConfig.hidden_size
    m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")

    x = torch.randn(K, dtype=torch.bfloat16, device="cuda")

    m_shuffled = shuffle_weight(m)

    out_ref = x @ m
    out_ref = swiglu(out_ref, limit=1.0)

    out = x @ m_shuffled
    out = triton_kernels.swiglu.swiglu_torch(
        out,
        alpha=1.702,
        precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))

    assert_close(ref=out_ref, tri=out)
