
Commit 4d4061b

[Kernel] Add cuda kernel for gpt_oss activation (#22951)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 87f4862 commit 4d4061b

File tree: 9 files changed (+157, −42 lines)


csrc/activation_kernels.cu

Lines changed: 59 additions & 0 deletions

@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
   }
 }
 
+template <typename T>
+__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
+                                               float alpha, float limit) {
+  // clamp gate: min=None, max=limit
+  const float gate_f = (float)gate;
+  const float clamped_gate = gate_f > limit ? limit : gate_f;
+
+  // clamp up: min=-limit, max=limit
+  const float up_f = (float)up;
+  const float clamped_up =
+      up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
+
+  // glu = gate * sigmoid(gate * alpha)
+  const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
+  const float glu = clamped_gate * sigmoid_val;
+
+  // (up + 1) * glu
+  return (T)((clamped_up + 1.0f) * glu);
+}
+
+template <typename scalar_t,
+          scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
+                             const float)>
+__global__ void swigluoai_and_mul_kernel(
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const int d, const float alpha, const float limit) {
+  const int64_t token_idx = blockIdx.x;
+  // TODO: Vectorize loads and stores.
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    // gate = x[..., ::2] (even indices)
+    const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
+    // up = x[..., 1::2] (odd indices)
+    const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
+
+    out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
+  }
+}
+
 }  // namespace vllm
 
 #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
           PARAM);                                               \
   });
 
+#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT)                  \
+  int d = input.size(-1) / 2;                                          \
+  int64_t num_tokens = input.numel() / input.size(-1);                 \
+  dim3 grid(num_tokens);                                               \
+  dim3 block(std::min(d, 1024));                                       \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));    \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();        \
+  VLLM_DISPATCH_FLOATING_TYPES(                                        \
+      input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] {    \
+        vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>>     \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),     \
+                                         input.data_ptr<scalar_t>(), d, ALPHA, \
+                                         LIMIT);                       \
+      });
+
 void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                      torch::Tensor& input,  // [..., 2 * d]
                      double threshold) {
   LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
 }
+void swigluoai_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input,  // [..., 2 * d]
+                       double alpha, double limit) {
+  LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
+}
 namespace vllm {
 
 // Element-wise activation kernel template.
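For readers who prefer the math to the CUDA, the per-element computation of the device function above can be written in a few lines of PyTorch. This is an illustrative sketch only (the function name is ours, not part of the commit); it mirrors the kernel's interleaved gate/up indexing and the clamp/sigmoid sequence, and matches the forward_native implementation added to vllm/model_executor/layers/activation.py later in this commit.

import torch

def swigluoai_reference(x: torch.Tensor,
                        alpha: float = 1.702,
                        limit: float = 7.0) -> torch.Tensor:
    # Input is [..., 2 * d] with gate/up interleaved: gate at even
    # positions, up at odd positions, as in the kernel's indexing.
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(max=limit)              # clamp gate: min=None, max=limit
    up = up.clamp(min=-limit, max=limit)      # clamp up: min=-limit, max=limit
    glu = gate * torch.sigmoid(gate * alpha)  # gate * sigmoid(alpha * gate)
    return (up + 1) * glu                     # output is [..., d]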

csrc/ops.h

Lines changed: 2 additions & 0 deletions

@@ -138,6 +138,8 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                      double threshold);
+void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
+                       double alpha = 1.702, double limit = 7.0);
 
 void gelu_new(torch::Tensor& out, torch::Tensor& input);
 
csrc/torch_bindings.cpp

Lines changed: 6 additions & 0 deletions

@@ -130,6 +130,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
   ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
 
+  ops.def(
+      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
+      "limit=7.0) "
+      "-> ()");
+  ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
+
   // GELU implementation used in GPT-2.
   ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_new", torch::kCUDA, &gelu_new);
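Once this binding is compiled into vLLM's _C extension, the op is callable straight from torch.ops. The snippet below is a minimal illustrative sketch, not part of the commit; it assumes a CUDA device and that importing vllm loads the custom-op registrations, and it relies on the schema defaults alpha=1.702 and limit=7.0.

import torch
import vllm  # assumption: importing vllm loads the compiled _C custom ops

d = 128
x = torch.randn(4, 2 * d, dtype=torch.float16, device="cuda")  # [..., 2 * d]
out = torch.empty(4, d, dtype=x.dtype, device=x.device)        # [..., d]

# alpha and limit fall back to the schema defaults (1.702, 7.0).
torch.ops._C.swigluoai_and_mul(out, x)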

tests/kernels/core/test_activation.py

Lines changed: 39 additions & 6 deletions

@@ -11,7 +11,7 @@
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
                                                    GeluAndMul, MulAndSilu,
                                                    NewGELU, QuickGELU,
-                                                   SiluAndMul)
+                                                   SiluAndMul, SwigluOAIAndMul)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -25,7 +25,15 @@
 
 @pytest.mark.parametrize(
     "activation",
-    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
+    [
+        "silu_and_mul",
+        "mul_and_silu",
+        "gelu",
+        "gelu_tanh",
+        "fatrelu",
+        "swigluoai_and_mul",
+    ],
+)
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -59,18 +67,43 @@ def test_act_and_mul(
         threshold = random.uniform(0, 1)
         layer = FatreluAndMul(threshold)
         fn = torch.ops._C.fatrelu_and_mul
+    elif activation == "swigluoai_and_mul":
+        layer = SwigluOAIAndMul()
+        fn = torch.ops._C.swigluoai_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
-    # equivalent to the native PyTorch implementations, so we can do exact
-    # comparison.
-    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
+    if activation == "swigluoai_and_mul":
+
+        rtol = {
+            # For fp16, change the relative tolerance from 1e-3 to 2e-3.
+            torch.float16:
+            2e-3,
+            torch.bfloat16:
+            2e-2,
+            torch.float:
+            1.3e-6
+        }
+
+        def _get_rtol(output) -> float:
+            return rtol[output.dtype]
+
+        torch.testing.assert_close(out,
+                                   ref_out,
+                                   atol=get_default_atol(out),
+                                   rtol=_get_rtol(out))
+    else:
+        # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
+        # equivalent to the native PyTorch implementations, so we can do exact
+        # comparison.
+        torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
     d = x.shape[-1] // 2
     output_shape = (x.shape[:-1] + (d, ))
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
     if activation == "fatrelu":
         opcheck(fn, (out, x, threshold))
+    elif activation == "swigluoai_and_mul":
+        opcheck(fn, (out, x, layer.alpha, layer.limit))
     else:
         opcheck(fn, (out, x))
 

vllm/model_executor/layers/activation.py

Lines changed: 38 additions & 3 deletions

@@ -239,6 +239,35 @@ def extra_repr(self) -> str:
         return f'approximate={repr(self.approximate)}'
 
 
+@CustomOp.register("swigluoai_and_mul")
+class SwigluOAIAndMul(CustomOp):
+    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
+    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
+        super().__init__()
+        self.alpha = alpha
+        self.limit = limit
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+
+        gate, up = x[..., ::2], x[..., 1::2]
+        gate = gate.clamp(min=None, max=self.limit)
+        up = up.clamp(min=-self.limit, max=self.limit)
+        glu = gate * torch.sigmoid(gate * self.alpha)
+        gated_output = (up + 1) * glu
+        return gated_output
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
+        return out
+
+    def extra_repr(self) -> str:
+        return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
+
+
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
@@ -330,6 +359,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return torch.square(F.relu(x))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: implement CUDA kernels
         return self.forward_native(x)
 
 
@@ -406,9 +436,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
 
 
 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
-    "gelu": lambda: GeluAndMul(),
-    "silu": lambda: SiluAndMul(),
-    "geglu": lambda: GeluAndMul(),
+    "gelu":
+    lambda: GeluAndMul(),
+    "silu":
+    lambda: SiluAndMul(),
+    "geglu":
+    lambda: GeluAndMul(),
+    "swigluoai":
+    lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
 })
 
 
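A short usage sketch for the new layer (illustrative, not part of the commit): forward_native is the pure-PyTorch path, so it shows the [..., 2 * d] -> [..., d] shape contract without needing a GPU, while forward_cuda dispatches to the new kernel.

import torch
from vllm.model_executor.layers.activation import SwigluOAIAndMul

layer = SwigluOAIAndMul()      # defaults: alpha=1.702, limit=7.0
x = torch.randn(8, 2 * 256)    # [..., 2 * d] with gate/up interleaved
out = layer.forward_native(x)  # pure-PyTorch reference path, runs on CPU
assert out.shape == (8, 256)   # output is [..., d]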

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 5 additions & 17 deletions

@@ -161,25 +161,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     if activation == "silu":
         torch.ops._C.silu_and_mul(intermediate_cache2,
                                   intermediate_cache1.view(-1, 2 * N))
-    elif activation == "swiglu_oai":
-        # NOTE: in gpt-oss, the gate_proj and up_proj is interleaved
-        # - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-        # - origin: gate, up = gate_up[..., :N], gate_up[..., N:]
-
-        @torch.compile(dynamic=True)
-        def swiglu_oai(gate_up):
-            alpha = 1.702
-            limit = 7.0
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=limit)
-            up = up.clamp(min=-limit, max=limit)
-            glu = gate * torch.sigmoid(gate * alpha)
-            return (up + 1) * glu
-
-        intermediate_cache2 = swiglu_oai(intermediate_cache1)
+    elif activation == "swigluoai":
+        # alpha = 1.702, limit = 7.0
+        torch.ops._C.swigluoai_and_mul(intermediate_cache2,
+                                       intermediate_cache1.view(-1, 2 * N))
     else:
         raise ValueError(f"Unsupported activation: {activation}. "
-                         "Only silu and swiglu_oai activations are supported.")
+                         "Only silu and swigluoai activations are supported.")
 
     if expert_map is not None:
         intermediate_cache3.zero_()
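The NOTE removed above is still the key reading aid for this kernel: in gpt-oss the gate and up projections are interleaved along the last dimension rather than stored as two contiguous halves. An illustrative sketch (not part of the commit) contrasting the two layouts:

import torch

gate_up = torch.randn(4, 2 * 8)  # [num_tokens, 2 * N]

# Interleaved layout used by gpt-oss (what swigluoai_and_mul expects):
gate_i, up_i = gate_up[..., ::2], gate_up[..., 1::2]

# Conventional "split halves" layout consumed by e.g. silu_and_mul:
N = gate_up.shape[-1] // 2
gate_h, up_h = gate_up[..., :N], gate_up[..., N:]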

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 5 additions & 13 deletions

@@ -1621,31 +1621,23 @@ def fused_experts_impl(
             block_shape=block_shape,
             B_bias=w1_bias)
 
-        # TODO fused kernel
-        def swiglu_oai(gate_up):
-            alpha = 1.702
-            limit = 7.0
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=limit)
-            up = up.clamp(min=-limit, max=limit)
-            glu = gate * torch.sigmoid(gate * alpha)
-            gated_output = (up + 1) * glu
-            return gated_output
-
         # Activation function with multiplication
         if activation == "silu" and is_act_and_mul:
             torch.ops._C.silu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
         elif activation == "gelu" and is_act_and_mul:
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
+        elif activation == "swigluoai" and is_act_and_mul:
+            # alpha = 1.702, limit = 7.0
+            torch.ops._C.swigluoai_and_mul(intermediate_cache2,
+                                           intermediate_cache1.view(-1, N))
         # Activation function without multiplication
         elif activation == "silu":
             intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
         elif activation == "gelu":
             intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
-        elif activation == "swiglu_oai":
-            intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
+
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
                              f"with is_act_and_mul={is_act_and_mul}.")

vllm/model_executor/layers/quantization/utils/mxfp4_utils.py

Lines changed: 2 additions & 2 deletions

@@ -61,14 +61,14 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
                        e_score_correction_bias: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        scoring_func: str = "softmax",
-                       activation: str = "swiglu_oai",
+                       activation: str = "swigluoai",
                        expert_load_view: Optional[torch.Tensor] = None,
                        logical_to_physical_map: Optional[torch.Tensor] = None,
                        logical_replica_count: Optional[torch.Tensor] = None):
     return not (use_grouped_topk or topk_group or num_expert_group
                 or expert_map or custom_routing_function
                 or e_score_correction_bias or apply_router_weight_on_input
-                or scoring_func != "softmax" or activation != "swiglu_oai"
+                or scoring_func != "softmax" or activation != "swigluoai"
                 or expert_load_view or logical_to_physical_map
                 or logical_replica_count)

vllm/model_executor/models/gpt_oss.py

Lines changed: 1 addition & 1 deletion

@@ -159,7 +159,7 @@ def __init__(
             prefix=f"{prefix}.experts",
             apply_router_weight_on_input=False,
             has_bias=True,
-            activation="swiglu_oai")
+            activation="swigluoai")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         t = self.norm(x)
