Commit cb96f11
authored
[KERNELS] Save some instructions in swiglu (#8801)
Say `gelu` has n registers per thread. Currently, `exp(-alpha * gelu)`
takes 1 `sub.f32` and 2n `mul.f32` instructions since `exp(x)` gets
expanded to `exp2(log2(e) * x)` by the time we get to ptx. We can
rewrite this as scaling `gelu` by the scalar value `(-alpha * log2(e))`
which is just 1 + n `mul.f32` instructions.
I also use `ex2.approx.ftz.f32` which is a single `MUFU.EX2` in SASS,
compared to the non-ftz variant which requires multiple SASS
instructions to handle denormal values. This is fine numerically since
we add 1 to the result anyway which will round out anything below
epsilon.1 parent 0f235ee commit cb96f11
File tree
1 file changed
+20
-1
lines changed- python/triton_kernels/triton_kernels/swiglu_details
1 file changed
+20
-1
lines changedLines changed: 20 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
35 | 35 | | |
36 | 36 | | |
37 | 37 | | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
38 | 53 | | |
39 | 54 | | |
40 | 55 | | |
| |||
43 | 58 | | |
44 | 59 | | |
45 | 60 | | |
46 | | - | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
47 | 66 | | |
48 | 67 | | |
49 | 68 | | |
| |||
0 commit comments