Skip to content

Commit cb96f11

Browse files
authored
[KERNELS] Save some instructions in swiglu (#8801)
Say `gelu` has n registers per thread. Currently, `exp(-alpha * gelu)` takes 1 `sub.f32` and 2n `mul.f32` instructions since `exp(x)` gets expanded to `exp2(log2(e) * x)` by the time we get to ptx. We can rewrite this as scaling `gelu` by the scalar value `(-alpha * log2(e))` which is just 1 + n `mul.f32` instructions. I also use `ex2.approx.ftz.f32` which is a single `MUFU.EX2` in SASS, compared to the non-ftz variant which requires multiple SASS instructions to handle denormal values. This is fine numerically since we add 1 to the result anyway which will round out anything below epsilon.
1 parent 0f235ee commit cb96f11

File tree

1 file changed

+20
-1
lines changed
  • python/triton_kernels/triton_kernels/swiglu_details

1 file changed

+20
-1
lines changed

python/triton_kernels/triton_kernels/swiglu_details/_swiglu.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ def swiglu_launch_metadata(grid, kernel, args):
3535
return ret
3636

3737

38+
@triton.jit
39+
def exp2_ftz(x):
40+
if tl.target_info.is_cuda():
41+
return tl.inline_asm_elementwise(
42+
"ex2.approx.ftz.f32 $0, $1;",
43+
"=r, r",
44+
[x],
45+
dtype=tl.float32,
46+
is_pure=True,
47+
pack=1,
48+
)
49+
else:
50+
return tl.exp2(x)
51+
52+
3853
@triton.jit
3954
def compute_swiglu(gelu, linear, scale, alpha, limit):
4055
gelu = gelu.to(tl.float32) * scale
@@ -43,7 +58,11 @@ def compute_swiglu(gelu, linear, scale, alpha, limit):
4358
linear = linear.to(tl.float32) * scale
4459
if limit is not None:
4560
linear = clip(linear, limit, clip_lower=True)
46-
s = gelu / (1 + tl.exp(-alpha * gelu))
61+
62+
# exp(x) becomes exp2(log2(e) * x) in ptx. By expanding it early, we can factor
63+
# (-alpha * log2_e) into a single scalar factor.
64+
log2_e: tl.constexpr = 1.4426950408889634
65+
s = gelu / (1 + exp2_ftz((-alpha * log2_e) * gelu))
4766
return tl.fma(s, linear, s) # (s * (linear + 1))
4867

4968

0 commit comments

Comments
 (0)