20 changes: 16 additions & 4 deletions python/triton_kernels/tests/test_swiglu.py
@@ -1,4 +1,4 @@
-from triton_kernels.swiglu import swiglu, swiglu_torch, PrecisionConfig
+from triton_kernels.swiglu import swiglu, swiglu_torch, standard_swiglu, standard_swiglu_torch, PrecisionConfig
 from triton_kernels.testing import assert_close
 import torch
 import pytest
@@ -22,11 +22,23 @@ def alloc_rand(shape, device, dtype, requires_grad=True):
 
 @pytest.mark.parametrize("M, N", [(1311, 4352)])
 @pytest.mark.parametrize("limit", [1e-2, 10])
-def test_op(M, N, limit, device, alpha=0.5):
+@pytest.mark.parametrize("add_bias", [True, False])
+def test_op(M, N, limit, add_bias, device, alpha=0.5):
     torch.manual_seed(2)
     # initialize data
     x = alloc_rand([M, N], device=device, dtype=torch.bfloat16)
     precision_config = PrecisionConfig(limit=limit)
-    tri_y = swiglu(x, alpha, precision_config)
-    ref_y = swiglu_torch(x, alpha, precision_config)
+    tri_y = swiglu(x, alpha, precision_config, add_bias=add_bias)
+    ref_y = swiglu_torch(x, alpha, precision_config, add_bias=add_bias)
     assert_close(tri_y, ref_y)
+
+
+@pytest.mark.parametrize("M, N", [(1311, 4352)])
+def test_op_standard_swiglu(M, N, device):
+    torch.manual_seed(2)
+    # initialize data
+    x = alloc_rand([M, N], device=device, dtype=torch.bfloat16)
+    precision_config = PrecisionConfig(limit=None)
+    tri_y = standard_swiglu(x, precision_config)
+    ref_y = standard_swiglu_torch(x)
+    assert_close(tri_y, ref_y)
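For reference, the two paths the `add_bias` parametrization now exercises differ only in the `+1` added to the linear half before the product. A minimal standalone sketch in plain PyTorch (the `ref_swiglu` helper is illustrative, not part of the library):

import torch

def ref_swiglu(x, alpha=0.5, limit=None, add_bias=True):
    # Even columns feed the sigmoid gate, odd columns the linear branch,
    # mirroring swiglu_torch in the file below.
    x_gelu, x_linear = x[..., ::2], x[..., 1::2]
    if limit is not None:
        x_gelu = x_gelu.clamp(max=limit)
        x_linear = x_linear.clamp(min=-limit, max=limit)
    gated = x_gelu * torch.sigmoid(alpha * x_gelu)
    return gated * (x_linear + 1) if add_bias else gated * x_linear

x = torch.randn(4, 8, dtype=torch.bfloat16)
# With alpha=1.0, no limit, and add_bias=False this is exactly standard SwiGLU.
expected = (x[..., ::2] * torch.sigmoid(x[..., ::2])) * x[..., 1::2]
assert torch.equal(ref_swiglu(x, alpha=1.0, add_bias=False), expected)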
32 changes: 25 additions & 7 deletions python/triton_kernels/triton_kernels/swiglu.py
@@ -1,8 +1,8 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from triton_kernels.numerics import InFlexData, OutFlexData
 import torch
 import triton
-from .swiglu_details._swiglu import _swiglu, _swiglu_fn
+from .swiglu_details._swiglu import _swiglu, _swiglu_fn, _standard_swiglu_fn
 from triton_kernels import target_info
 
 
@@ -20,12 +20,13 @@ class PrecisionConfig:
 
 
 swiglu_fn = _swiglu_fn
+standard_swiglu_fn = _standard_swiglu_fn
 
 
 class SwiGLU(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, a, alpha, precision_config, routing_data):
+    def forward(ctx, a, alpha, precision_config, routing_data, add_bias=True):
         N = a.shape[-1]
         M = a.numel() // N
         assert a.stride()[-1] == 1
@@ -68,6 +69,7 @@ def forward(ctx, a, alpha, precision_config, routing_data):
             out.shape[-1],
             1,
             precision_config.limit,
+            add_bias,
             n_tokens,
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
@@ -82,11 +84,16 @@ def forward(ctx, a, alpha, precision_config, routing_data):
         return out
 
 
-def swiglu(a, alpha, precision_config, routing_data=None):
-    return SwiGLU.apply(a, alpha, precision_config, routing_data)
+def swiglu(a, alpha, precision_config, routing_data=None, add_bias=True):
+    return SwiGLU.apply(a, alpha, precision_config, routing_data, add_bias)
 
 
-def swiglu_torch(a, alpha, precision_config):
+def standard_swiglu(a, precision_config, routing_data=None):
+    pc = replace(precision_config, limit=None)
+    return SwiGLU.apply(a, 1.0, pc, routing_data, False)
+
+
+def swiglu_torch(a, alpha, precision_config, add_bias=True):
     limit = precision_config.limit
     a_gelu = a[..., ::2]
     if limit is not None:
@@ -96,5 +103,16 @@ def swiglu_torch(a, alpha, precision_config):
         a_linear = a_linear.clamp(min=-limit, max=limit)
 
     out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
-    out = out_gelu * (a_linear + 1)
+    if add_bias:
+        out = out_gelu * (a_linear + 1)
+    else:
+        out = out_gelu * a_linear
     return out
+
+
+def standard_swiglu_torch(a):
+    a_gelu = a[..., ::2]
+    a_linear = a[..., 1::2]
+    out_gelu = a_gelu * torch.sigmoid(a_gelu)
+    out = out_gelu * a_linear
+    return out
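Summarizing the two public entry points after this change: `swiglu` keeps the scaled sigmoid gate, optional clamping, and the `(linear + 1)` bias, while `standard_swiglu` pins `alpha` to 1.0, forces `limit=None` via `dataclasses.replace`, and drops the bias. A hedged usage sketch (the shapes and CUDA device mirror the tests; everything else is as defined in this file):

import torch
from triton_kernels.swiglu import swiglu, standard_swiglu, PrecisionConfig

x = torch.randn(1311, 4352, device="cuda", dtype=torch.bfloat16)

# Gated variant: alpha scales the sigmoid argument, limit clamps both halves,
# and add_bias=True multiplies the gate by (linear + 1).
y = swiglu(x, 0.5, PrecisionConfig(limit=10), add_bias=True)

# Standard SwiGLU: limit is overridden to None internally, alpha is 1.0,
# and no bias is added, so the limit passed here is ignored.
y_std = standard_swiglu(x, PrecisionConfig(limit=10))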
23 changes: 16 additions & 7 deletions python/triton_kernels/triton_kernels/swiglu_details/_swiglu.py
@@ -35,27 +35,36 @@ def swiglu_launch_metadata(grid, kernel, args):
 
 
 @triton.jit
-def compute_swiglu(gelu, linear, scale, alpha, limit):
+def compute_swiglu(gelu, linear, scale, alpha, limit, add_bias):
     gelu = gelu.to(tl.float32) * scale
     if limit is not None:
         gelu = clip(gelu, limit, clip_lower=False)
     linear = linear.to(tl.float32) * scale
     if limit is not None:
         linear = clip(linear, limit, clip_lower=True)
     s = gelu / (1 + tl.exp(-alpha * gelu))
-    return tl.fma(s, linear, s)  # (s * (linear + 1))
+    if add_bias:
+        return tl.fma(s, linear, s)  # (s * (linear + 1))
+    else:
+        return s * linear
 
 
 @triton.jit(repr=lambda _: "_swiglu")
-def _swiglu_fn(input, alpha, limit):
+def _swiglu_fn(input, alpha, limit, add_bias=True):
     gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
-    return compute_swiglu(gelu, linear, 1.0, alpha, limit)
+    return compute_swiglu(gelu, linear, 1.0, alpha, limit, add_bias)
+
+
+@triton.jit(repr=lambda _: "_standard_swiglu")
+def _standard_swiglu_fn(input):
+    gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
+    return compute_swiglu(gelu, linear, 1.0, 1.0, None, False)
 
 
 @triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
 def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale, alpha, M, N, stride_am, stride_an,
-            stride_outm, stride_outn, limit: tl.constexpr, NTokens, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-            EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
+            stride_outm, stride_outn, limit: tl.constexpr, add_bias: tl.constexpr, NTokens, BLOCK_M: tl.constexpr,
+            BLOCK_N: tl.constexpr, EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
     if NTokens is not None:
         M = tl.load(NTokens)
         M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M
@@ -87,7 +96,7 @@ def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale,
         packed_mask = mask_m[:, None] & packed_mask_n[None, :]
         a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.)
         a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
-        out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
+        out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit, add_bias)
         # update flexpoint stats and divide by scale
         # we don't need masking because of the `other` when loading `A`
         if OutActualScale is not None:
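Since `add_bias` is a `tl.constexpr`, the branch in `compute_swiglu` is resolved when the kernel is specialized, so the biased path keeps its single fused multiply-add and the standard path compiles to a plain multiply with no runtime branch. A quick scalar check in plain PyTorch of the identity the fused path relies on (up to floating-point rounding):

import torch

s = torch.randn(16)
linear = torch.randn(16)
# fma(s, linear, s) == s * linear + s == s * (linear + 1)
assert torch.allclose(s * linear + s, s * (linear + 1))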