diff --git a/python/triton_kernels/tests/test_swiglu.py b/python/triton_kernels/tests/test_swiglu.py
index ad5f3e931ec9..b353c727a3a8 100644
--- a/python/triton_kernels/tests/test_swiglu.py
+++ b/python/triton_kernels/tests/test_swiglu.py
@@ -1,4 +1,4 @@
-from triton_kernels.swiglu import swiglu, swiglu_torch, PrecisionConfig
+from triton_kernels.swiglu import swiglu, swiglu_torch, standard_swiglu, standard_swiglu_torch, PrecisionConfig
 from triton_kernels.testing import assert_close
 import torch
 import pytest
@@ -22,11 +22,23 @@ def alloc_rand(shape, device, dtype, requires_grad=True):
 
 @pytest.mark.parametrize("M, N", [(1311, 4352)])
 @pytest.mark.parametrize("limit", [1e-2, 10])
-def test_op(M, N, limit, device, alpha=0.5):
+@pytest.mark.parametrize("add_bias", [True, False])
+def test_op(M, N, limit, add_bias, device, alpha=0.5):
     torch.manual_seed(2)
     # initialize data
     x = alloc_rand([M, N], device=device, dtype=torch.bfloat16)
     precision_config = PrecisionConfig(limit=limit)
-    tri_y = swiglu(x, alpha, precision_config)
-    ref_y = swiglu_torch(x, alpha, precision_config)
+    tri_y = swiglu(x, alpha, precision_config, add_bias=add_bias)
+    ref_y = swiglu_torch(x, alpha, precision_config, add_bias=add_bias)
+    assert_close(tri_y, ref_y)
+
+
+@pytest.mark.parametrize("M, N", [(1311, 4352)])
+def test_op_standard_swiglu(M, N, device):
+    torch.manual_seed(2)
+    # initialize data
+    x = alloc_rand([M, N], device=device, dtype=torch.bfloat16)
+    precision_config = PrecisionConfig(limit=None)
+    tri_y = standard_swiglu(x, precision_config)
+    ref_y = standard_swiglu_torch(x)
     assert_close(tri_y, ref_y)
diff --git a/python/triton_kernels/triton_kernels/swiglu.py b/python/triton_kernels/triton_kernels/swiglu.py
index 3f26427873d4..bb4123685a25 100644
--- a/python/triton_kernels/triton_kernels/swiglu.py
+++ b/python/triton_kernels/triton_kernels/swiglu.py
@@ -1,8 +1,8 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from triton_kernels.numerics import InFlexData, OutFlexData
 import torch
 import triton
-from .swiglu_details._swiglu import _swiglu, _swiglu_fn
+from .swiglu_details._swiglu import _swiglu, _swiglu_fn, _standard_swiglu_fn
 from triton_kernels import target_info
 
 
@@ -20,12 +20,13 @@ class PrecisionConfig:
 
 
 swiglu_fn = _swiglu_fn
+standard_swiglu_fn = _standard_swiglu_fn
 
 
 class SwiGLU(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, a, alpha, precision_config, routing_data):
+    def forward(ctx, a, alpha, precision_config, routing_data, add_bias=True):
         N = a.shape[-1]
         M = a.numel() // N
         assert a.stride()[-1] == 1
@@ -68,6 +69,7 @@ def forward(ctx, a, alpha, precision_config, routing_data):
             out.shape[-1],
             1,
             precision_config.limit,
+            add_bias,
             n_tokens,
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
@@ -82,11 +84,16 @@ def forward(ctx, a, alpha, precision_config, routing_data):
         return out
 
 
-def swiglu(a, alpha, precision_config, routing_data=None):
-    return SwiGLU.apply(a, alpha, precision_config, routing_data)
+def swiglu(a, alpha, precision_config, routing_data=None, add_bias=True):
+    return SwiGLU.apply(a, alpha, precision_config, routing_data, add_bias)
 
 
-def swiglu_torch(a, alpha, precision_config):
+def standard_swiglu(a, precision_config, routing_data=None):
+    pc = replace(precision_config, limit=None)
+    return SwiGLU.apply(a, 1.0, pc, routing_data, False)
+
+
+def swiglu_torch(a, alpha, precision_config, add_bias=True):
     limit = precision_config.limit
     a_gelu = a[..., ::2]
     if limit is not None:
@@ -96,5 +103,16 @@ def swiglu_torch(a, alpha, precision_config):
         a_linear = a_linear.clamp(min=-limit, max=limit)
 
     out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
-    out = out_gelu * (a_linear + 1)
+    if add_bias:
+        out = out_gelu * (a_linear + 1)
+    else:
+        out = out_gelu * a_linear
+    return out
+
+
+def standard_swiglu_torch(a):
+    a_gelu = a[..., ::2]
+    a_linear = a[..., 1::2]
+    out_gelu = a_gelu * torch.sigmoid(a_gelu)
+    out = out_gelu * a_linear
     return out
diff --git a/python/triton_kernels/triton_kernels/swiglu_details/_swiglu.py b/python/triton_kernels/triton_kernels/swiglu_details/_swiglu.py
index fbcea076f7dc..789dc15514ea 100644
--- a/python/triton_kernels/triton_kernels/swiglu_details/_swiglu.py
+++ b/python/triton_kernels/triton_kernels/swiglu_details/_swiglu.py
@@ -35,7 +35,7 @@ def swiglu_launch_metadata(grid, kernel, args):
 
 
 @triton.jit
-def compute_swiglu(gelu, linear, scale, alpha, limit):
+def compute_swiglu(gelu, linear, scale, alpha, limit, add_bias):
     gelu = gelu.to(tl.float32) * scale
     if limit is not None:
         gelu = clip(gelu, limit, clip_lower=False)
@@ -43,19 +43,28 @@ def compute_swiglu(gelu, linear, scale, alpha, limit):
     linear = linear.to(tl.float32) * scale
     if limit is not None:
         linear = clip(linear, limit, clip_lower=True)
     s = gelu / (1 + tl.exp(-alpha * gelu))
-    return tl.fma(s, linear, s)  # (s * (linear + 1))
+    if add_bias:
+        return tl.fma(s, linear, s)  # (s * (linear + 1))
+    else:
+        return s * linear
 
 
 @triton.jit(repr=lambda _: "_swiglu")
-def _swiglu_fn(input, alpha, limit):
+def _swiglu_fn(input, alpha, limit, add_bias=True):
     gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
-    return compute_swiglu(gelu, linear, 1.0, alpha, limit)
+    return compute_swiglu(gelu, linear, 1.0, alpha, limit, add_bias)
+
+
+@triton.jit(repr=lambda _: "_standard_swiglu")
+def _standard_swiglu_fn(input):
+    gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
+    return compute_swiglu(gelu, linear, 1.0, 1.0, None, False)
 
 
 @triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
 def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale, alpha, M, N, stride_am, stride_an,
-            stride_outm, stride_outn, limit: tl.constexpr, NTokens, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-            EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
+            stride_outm, stride_outn, limit: tl.constexpr, add_bias: tl.constexpr, NTokens, BLOCK_M: tl.constexpr,
+            BLOCK_N: tl.constexpr, EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
     if NTokens is not None:
         M = tl.load(NTokens)
         M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M
@@ -87,7 +96,7 @@ def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale,
         packed_mask = mask_m[:, None] & packed_mask_n[None, :]
         a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.)
         a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
-        out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
+        out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit, add_bias)
         # update flexpoint stats and divide by scale
         # we don't need masking because of the `other` when loading `A`
         if OutActualScale is not None:
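For context, a minimal usage sketch of the entry points this diff adds (not part of the patch; the shape, alpha, limit, and device values below are illustrative and assume a CUDA build of triton_kernels):

import torch
from triton_kernels.swiglu import swiglu, swiglu_torch, standard_swiglu, standard_swiglu_torch, PrecisionConfig

# rows of x pack interleaved (gelu, linear) pairs along the last dimension; last dim must be even and contiguous
x = torch.randn(1311, 4352, device="cuda", dtype=torch.bfloat16)

# gated variant with the +1 bias: out = s * (linear + 1), where s = gelu * sigmoid(alpha * gelu),
# with optional clamping controlled by PrecisionConfig.limit
pc = PrecisionConfig(limit=10)
y_tri = swiglu(x, 0.5, pc, add_bias=True)        # Triton kernel path
y_ref = swiglu_torch(x, 0.5, pc, add_bias=True)  # PyTorch reference

# "standard" SwiGLU: alpha = 1, no limit, no +1 bias, i.e. out = gelu * sigmoid(gelu) * linear
y_std_tri = standard_swiglu(x, PrecisionConfig(limit=None))
y_std_ref = standard_swiglu_torch(x)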