From 65ae8eceef192b8a47e648ae3d1986839b88f28c Mon Sep 17 00:00:00 2001 From: bhack Date: Wed, 18 Dec 2024 15:17:08 +0000 Subject: [PATCH 1/5] Add selective_scan compilable/exportable custom_ops --- .../selective_scan_interface_ compilable.py | 272 ++++++++++++++++++ tests/ops/test_selective_scan.py | 120 +------- 2 files changed, 286 insertions(+), 106 deletions(-) create mode 100644 mamba_ssm/ops/selective_scan_interface_ compilable.py diff --git a/mamba_ssm/ops/selective_scan_interface_ compilable.py b/mamba_ssm/ops/selective_scan_interface_ compilable.py new file mode 100644 index 00000000..60a59b7e --- /dev/null +++ b/mamba_ssm/ops/selective_scan_interface_ compilable.py @@ -0,0 +1,272 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from typing import Optional, Tuple + +import selective_scan_cuda + + +@torch.library.custom_op( + "custom_ops::selective_scan_fwd", + device_types=["cuda"], + mutates_args=(), +) +def custom_selective_scan_fwd( + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: Optional[torch.Tensor], + z: Optional[torch.Tensor], + delta_bias: Optional[torch.Tensor], + delta_softplus: bool, + return_last_state: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool, bool]: + pass + +@torch.library.register_fake("custom_ops::selective_scan_fwd") +def custom_selective_scan_fwd_fake( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + return_last_state, +): + final_out = torch.empty_like(u) + dstate = A.size(1) * (2 if A.is_complex() else 1) + last_state_fake = u.new_empty((u.size(0), u.size(1), dstate)) if return_last_state else u.new_empty(0) + out_fake = torch.empty_like(u) + x_fake = u.new_empty((u.size(0), u.size(1), u.size(2), 2 * dstate)) + return final_out, last_state_fake, out_fake, x_fake, False, False, z is not None + +@torch.library.register_kernel("custom_ops::selective_scan_fwd", "cuda") +def custom_selective_scan_fwd_cuda( + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: Optional[torch.Tensor], + z: Optional[torch.Tensor], + delta_bias: Optional[torch.Tensor], + delta_softplus: bool, + return_last_state: bool, +): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + + squeeze_B = False + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l").contiguous() + squeeze_B = True + + squeeze_C = False + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l").contiguous() + squeeze_C = True + + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + has_z = z is not None + final_out = rest[0].clone() if has_z else out.clone() + last_state = x[:, :, -1, 1::2].clone() if return_last_state else u.new_empty(0) + return final_out, last_state, out, x, squeeze_B, squeeze_C, has_z + +@torch.library.custom_op( + "custom_ops::selective_scan_bwd", + device_types=["cuda"], + mutates_args=(), +) +def custom_selective_scan_bwd( + dout: torch.Tensor, + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: Optional[torch.Tensor], + z: Optional[torch.Tensor], + delta_bias: Optional[torch.Tensor], + delta_softplus: bool, + out: torch.Tensor, + x: 
torch.Tensor, + squeeze_B: bool, + squeeze_C: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + pass + +@torch.library.register_fake("custom_ops::selective_scan_bwd") +def custom_selective_scan_bwd_fake( + dout, + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + out, + x, + squeeze_B, + squeeze_C, +): + du = torch.empty_like(u) + ddelta = torch.empty_like(delta) + dA = torch.empty_like(A) + dB = torch.empty_like(B) + dC = torch.empty_like(C) + dD = torch.empty_like(D) if (D is not None and D.numel() > 0) else u.new_empty(0) + dz = torch.empty_like(z) if (z is not None and z.numel() > 0) else u.new_empty(0) + ddelta_bias = torch.empty_like(delta_bias) if (delta_bias is not None and delta_bias.numel() > 0) else u.new_empty(0) + return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias + +@torch.library.register_kernel("custom_ops::selective_scan_bwd", "cuda") +def custom_selective_scan_bwd_cuda( + dout: torch.Tensor, + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: Optional[torch.Tensor], + z: Optional[torch.Tensor], + delta_bias: Optional[torch.Tensor], + delta_softplus: bool, + out: torch.Tensor, + x: torch.Tensor, + squeeze_B: bool, + squeeze_C: bool, +): + if dout.stride(-1) != 1: + dout = dout.contiguous() + B = B.contiguous() + C = C.contiguous() + + results = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, delta_softplus, False + ) + has_z = z is not None + if has_z: + du, ddelta, dA, dB, dC, dD, ddelta_bias, dz = results + else: + du, ddelta, dA, dB, dC, dD, ddelta_bias = results + dz = u.new_empty(0) + + if squeeze_B and dB.numel() > 0: + dB = dB.squeeze(1) + if squeeze_C and dC.numel() > 0: + dC = dC.squeeze(1) + + return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias + +def custom_bridge(ctx, *grads): + dout = grads[0] if grads else ctx.saved_tensors[0].new_empty(0) + saved = ctx.saved_tensors + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x, out = saved + z = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = saved + + du, ddelta, dA, dB, dC, dD, dz, ddelta_bias = torch.ops.custom_ops.selective_scan_bwd( + dout, + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + ctx.delta_softplus, + out, + x, + ctx.squeeze_B, + ctx.squeeze_C + ) + + return ( + du, + ddelta, + dA, + dB, + dC, + dD if D is not None else None, + dz if z is not None else None, + ddelta_bias if delta_bias is not None else None, + None, + None, + ) + +def custom_setup_context(ctx, inputs, output): + (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) = inputs + (final_out, last_state, out, x, squeeze_B, squeeze_C, has_z) = output + + ctx.delta_softplus = delta_softplus + ctx.squeeze_B = squeeze_B + ctx.squeeze_C = squeeze_C + ctx.has_z = has_z + + B = B.contiguous() + C = C.contiguous() + if squeeze_B and B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l").contiguous() + if squeeze_C and C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l").contiguous() + + if not has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, out) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + +torch.library.register_autograd( + "custom_ops::selective_scan_fwd", custom_bridge, setup_context=custom_setup_context +) + +def selective_scan_fn_custom_op( + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: 
Optional[torch.Tensor], + z: Optional[torch.Tensor], + delta_bias: Optional[torch.Tensor], + delta_softplus: bool, + return_last_state: bool, +) -> torch.Tensor: + # Pass all arguments positionally, exactly in schema order: + final_out, last_state, _, _, _, _, _ = torch.ops.custom_ops.selective_scan_fwd( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + return_last_state + ) + + if return_last_state: + return final_out, last_state + else: + return final_out diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py index 8a834b3c..83ba3eca 100644 --- a/tests/ops/test_selective_scan.py +++ b/tests/ops/test_selective_scan.py @@ -9,9 +9,17 @@ from einops import rearrange from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref -from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref - - +from mamba_ssm.ops.selective_scan_interface_compilable import selective_scan_fn_custom_op + +@pytest.mark.parametrize( + "op_impl", + [ + selective_scan_fn, + selective_scan_fn_custom_op, + torch.compile(selective_scan_fn_custom_op), + ], + ids=["original", "custom", "compiled"], +) # @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) @pytest.mark.parametrize('wtype', [torch.float32]) # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) @@ -35,7 +43,7 @@ @pytest.mark.parametrize("is_variable_C", [True]) # @pytest.mark.parametrize("is_variable_B", [False, True]) @pytest.mark.parametrize("is_variable_B", [True]) -def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, +def test_selective_scan(op_impl, is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, delta_softplus, return_last_state, seqlen, itype, wtype): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable @@ -92,7 +100,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z u_ref = u.detach().clone().requires_grad_() delta_ref = delta.detach().clone().requires_grad_() delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - out, *rest = selective_scan_fn( + out, *rest = op_impl( u, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus, return_last_state=return_last_state @@ -144,104 +152,4 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z if has_z: assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) if has_delta_bias: - assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) - - -@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) -# @pytest.mark.parametrize('wtype', [torch.complex64]) -# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('itype', [torch.float32]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -@pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize("is_variable_C", [False, True]) -# @pytest.mark.parametrize("is_variable_C", [False]) -@pytest.mark.parametrize("is_variable_B", [False, True]) -# @pytest.mark.parametrize("is_variable_B", [True]) -def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): - device = 'cuda' - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 3e-2, 
5e-2 - rtolw, atolw = (1e-3, 1e-3) - # If we have z, the errors on the weights seem higher - rtolw = max(rtolw, rtol) - atolw = max(atolw, atol) - # set seed - torch.random.manual_seed(0) - batch_size = 2 - dim = 768 - dstate = 8 - dt_rank = 48 - is_complex = wtype == torch.complex64 - xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True) - conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True) - conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate - * (1 if not is_complex else 2), - dim, device=device, dtype=itype, requires_grad=True) - delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True) - out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True) - out_proj_bias = None - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() - B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) - if not is_variable_B else None) - C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) - if not is_variable_C else None) - D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() - B_proj_bias = None - C_proj_bias = None - xz_ref = xz.detach().clone().requires_grad_() - conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() - conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() - x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() - delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() - out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() - out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_() - if out_proj_bias is not None else None) - A_ref = A.detach().clone().requires_grad_() - B_ref = B.detach().clone().requires_grad_() if B is not None else None - C_ref = C.detach().clone().requires_grad_() if C is not None else None - D_ref = D.detach().clone().requires_grad_() - delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias=delta_bias, delta_softplus=True) - out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref, - delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref, - A_ref, B_ref, C_ref, D_ref, - delta_bias=delta_bias_ref, delta_softplus=True) - # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - # dt_u = delta * u - - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - - g = torch.randn_like(out) - out_ref.backward(g) - out.backward(g) - - print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}') - print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') - if not is_variable_B: - print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') - if not is_variable_C: - print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') - print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') - print(f'ddelta_bias max diff: {(delta_bias.grad - 
delta_bias_ref.grad).abs().max().item()}') - print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}') - print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}') - print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}') - print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}') - print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}') - - # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) - # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) - # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) - # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, - # atol=atolw if not is_variable_B else atol) - # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, - # atol=atolw if not is_variable_C else atol) - # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) - # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) \ No newline at end of file From 265d4e253acc744d407677a502032e7a5a085aee Mon Sep 17 00:00:00 2001 From: bhack Date: Wed, 18 Dec 2024 15:20:31 +0000 Subject: [PATCH 2/5] Fix removed code --- tests/ops/test_selective_scan.py | 103 ++++++++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 1 deletion(-) diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py index 83ba3eca..296629b7 100644 --- a/tests/ops/test_selective_scan.py +++ b/tests/ops/test_selective_scan.py @@ -9,6 +9,7 @@ from einops import rearrange from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref +from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref from mamba_ssm.ops.selective_scan_interface_compilable import selective_scan_fn_custom_op @pytest.mark.parametrize( @@ -152,4 +153,104 @@ def test_selective_scan(op_impl, is_variable_B, is_variable_C, varBC_groups, has if has_z: assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) if has_delta_bias: - assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) \ No newline at end of file + assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + + +@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) +# @pytest.mark.parametrize('wtype', [torch.complex64]) +# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('itype', [torch.float32]) +# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) +@pytest.mark.parametrize('seqlen', [128]) +@pytest.mark.parametrize("is_variable_C", [False, True]) +# @pytest.mark.parametrize("is_variable_C", [False]) +@pytest.mark.parametrize("is_variable_B", [False, True]) +# @pytest.mark.parametrize("is_variable_B", [True]) +def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + # If we have z, the errors on the weights seem higher + rtolw = 
max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 768 + dstate = 8 + dt_rank = 48 + is_complex = wtype == torch.complex64 + xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True) + conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True) + conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate + * (1 if not is_complex else 2), + dim, device=device, dtype=itype, requires_grad=True) + delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True) + out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True) + out_proj_bias = None + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) + if not is_variable_B else None) + C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) + if not is_variable_C else None) + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + B_proj_bias = None + C_proj_bias = None + xz_ref = xz.detach().clone().requires_grad_() + conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() + conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() + x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() + delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() + out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() + out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_() + if out_proj_bias is not None else None) + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() if B is not None else None + C_ref = C.detach().clone().requires_grad_() if C is not None else None + D_ref = D.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B, C, D, delta_bias=delta_bias, delta_softplus=True) + out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref, + delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref, + A_ref, B_ref, C_ref, D_ref, + delta_bias=delta_bias_ref, delta_softplus=True) + # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + # dt_u = delta * u + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}') + print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') + if not is_variable_B: + print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') + if not is_variable_C: + print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') + print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') + print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') + print(f'dout_proj_weight max diff: {(out_proj_weight.grad 
- out_proj_weight_ref.grad).abs().max().item()}') + print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}') + print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}') + print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}') + print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}') + + # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) + # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) + # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) + # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, + # atol=atolw if not is_variable_B else atol) + # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, + # atol=atolw if not is_variable_C else atol) + # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) + # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) From f9945b2811940308b674b8e028f852a86bed5adb Mon Sep 17 00:00:00 2001 From: bhack Date: Fri, 20 Dec 2024 15:22:30 +0000 Subject: [PATCH 3/5] Add op_check tests --- .../selective_scan_interface_ compilable.py | 187 +++++++++++------- tests/ops/test_selective_scan.py | 25 +++ 2 files changed, 146 insertions(+), 66 deletions(-) diff --git a/mamba_ssm/ops/selective_scan_interface_ compilable.py b/mamba_ssm/ops/selective_scan_interface_ compilable.py index 60a59b7e..e244a3a3 100644 --- a/mamba_ssm/ops/selective_scan_interface_ compilable.py +++ b/mamba_ssm/ops/selective_scan_interface_ compilable.py @@ -1,15 +1,15 @@ import torch -import torch.nn.functional as F from einops import rearrange from typing import Optional, Tuple -import selective_scan_cuda +from mamba_ssm.ops.selective_scan_interface import selective_scan_cuda @torch.library.custom_op( "custom_ops::selective_scan_fwd", device_types=["cuda"], mutates_args=(), + schema="(Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? 
delta_bias, bool delta_softplus, bool return_last_state) -> (Tensor, Tensor, Tensor, Tensor, bool, bool, bool)", ) def custom_selective_scan_fwd( u: torch.Tensor, @@ -22,28 +22,33 @@ def custom_selective_scan_fwd( delta_bias: Optional[torch.Tensor], delta_softplus: bool, return_last_state: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool, bool]: +): pass + @torch.library.register_fake("custom_ops::selective_scan_fwd") def custom_selective_scan_fwd_fake( - u, - delta, - A, - B, - C, - D, - z, - delta_bias, - delta_softplus, - return_last_state, + u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state ): - final_out = torch.empty_like(u) dstate = A.size(1) * (2 if A.is_complex() else 1) - last_state_fake = u.new_empty((u.size(0), u.size(1), dstate)) if return_last_state else u.new_empty(0) - out_fake = torch.empty_like(u) - x_fake = u.new_empty((u.size(0), u.size(1), u.size(2), 2 * dstate)) - return final_out, last_state_fake, out_fake, x_fake, False, False, z is not None + seqlen = u.size(2) + n_chunks = (seqlen + 2048 - 1) // 2048 + + squeeze_B = B.dim() == 3 + squeeze_C = C.dim() == 3 + has_z = z is not None + + final_out = torch.empty_like(delta) + out_fake = torch.empty_like(delta) + last_state_fake = ( + u.new_empty((u.size(0), u.size(1), dstate)) + if return_last_state + else u.new_empty(0) + ) + x_fake = u.new_empty((u.size(0), u.size(1), n_chunks, 2 * A.size(1)), dtype=A.dtype) + + return final_out, last_state_fake, out_fake, x_fake, squeeze_B, squeeze_C, has_z + @torch.library.register_kernel("custom_ops::selective_scan_fwd", "cuda") def custom_selective_scan_fwd_cuda( @@ -81,16 +86,23 @@ def custom_selective_scan_fwd_cuda( C = rearrange(C, "b dstate l -> b 1 dstate l").contiguous() squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + out, x, *rest = selective_scan_cuda.fwd( + u, delta, A, B, C, D, z, delta_bias, delta_softplus + ) has_z = z is not None - final_out = rest[0].clone() if has_z else out.clone() + if has_z: + final_out = rest[0].clone() + else: + final_out = out.clone() last_state = x[:, :, -1, 1::2].clone() if return_last_state else u.new_empty(0) return final_out, last_state, out, x, squeeze_B, squeeze_C, has_z + @torch.library.custom_op( "custom_ops::selective_scan_bwd", device_types=["cuda"], mutates_args=(), + schema="(Tensor dout, Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? 
delta_bias, bool delta_softplus, Tensor out, Tensor x, bool squeeze_B, bool squeeze_C, bool recompute_out_z) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, Tensor?)", ) def custom_selective_scan_bwd( dout: torch.Tensor, @@ -107,9 +119,11 @@ def custom_selective_scan_bwd( x: torch.Tensor, squeeze_B: bool, squeeze_C: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + recompute_out_z: bool, +): pass + @torch.library.register_fake("custom_ops::selective_scan_bwd") def custom_selective_scan_bwd_fake( dout, @@ -126,16 +140,33 @@ def custom_selective_scan_bwd_fake( x, squeeze_B, squeeze_C, + recompute_out_z, ): + # Here we just return shape-compatible fake tensors du = torch.empty_like(u) ddelta = torch.empty_like(delta) dA = torch.empty_like(A) - dB = torch.empty_like(B) - dC = torch.empty_like(C) - dD = torch.empty_like(D) if (D is not None and D.numel() > 0) else u.new_empty(0) - dz = torch.empty_like(z) if (z is not None and z.numel() > 0) else u.new_empty(0) - ddelta_bias = torch.empty_like(delta_bias) if (delta_bias is not None and delta_bias.numel() > 0) else u.new_empty(0) - return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias + + # Decide if variable B/C + is_variable_B = B.dim() > 3 + is_variable_C = C.dim() > 3 + + dB = torch.empty_like( + B, dtype=B.dtype + ) # If variable_B, still float32 is okay for fake + dC = torch.empty_like(C, dtype=C.dtype) + + dD = torch.empty_like(D) if (D is not None) else None + ddelta_bias_out = torch.empty_like(delta_bias) if (delta_bias is not None) else None + dz = torch.empty_like(z) if (z is not None) else None + + if squeeze_B and dB.numel() > 0: + dB = dB.squeeze(1) + if squeeze_C and dC.numel() > 0: + dC = dC.squeeze(1) + + return du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz + @torch.library.register_kernel("custom_ops::selective_scan_bwd", "cuda") def custom_selective_scan_bwd_cuda( @@ -153,68 +184,101 @@ def custom_selective_scan_bwd_cuda( x: torch.Tensor, squeeze_B: bool, squeeze_C: bool, + recompute_out_z: bool, ): if dout.stride(-1) != 1: dout = dout.contiguous() - B = B.contiguous() - C = C.contiguous() results = selective_scan_cuda.bwd( - u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, delta_softplus, False + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + dout, + x, + out, + None, + delta_softplus, + recompute_out_z, ) + has_z = z is not None if has_z: - du, ddelta, dA, dB, dC, dD, ddelta_bias, dz = results + du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz = results else: - du, ddelta, dA, dB, dC, dD, ddelta_bias = results - dz = u.new_empty(0) + du, ddelta, dA, dB, dC, dD, ddelta_bias_out = results + dz = None if squeeze_B and dB.numel() > 0: dB = dB.squeeze(1) if squeeze_C and dC.numel() > 0: dC = dC.squeeze(1) - return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias + return du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz + def custom_bridge(ctx, *grads): dout = grads[0] if grads else ctx.saved_tensors[0].new_empty(0) saved = ctx.saved_tensors + if not ctx.has_z: u, delta, A, B, C, D, delta_bias, x, out = saved z = None else: u, delta, A, B, C, D, z, delta_bias, x, out = saved - du, ddelta, dA, dB, dC, dD, dz, ddelta_bias = torch.ops.custom_ops.selective_scan_bwd( - dout, - u, - delta, - A, - B, - C, - D, - z, - delta_bias, - ctx.delta_softplus, - out, - x, - ctx.squeeze_B, - ctx.squeeze_C + du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz = ( + torch.ops.custom_ops.selective_scan_bwd( + dout, + u, + delta, + A, + B, + C, + D, + z, 
+ delta_bias, + ctx.delta_softplus, + out, + x, + ctx.squeeze_B, + ctx.squeeze_C, + False, + ) ) + # For optional inputs, return None if not provided in forward + if D is None: + dD = None + if z is None: + dz = None + if delta_bias is None: + ddelta_bias_out = None + + # Return gradients in the order of forward inputs: + # (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + # `delta_softplus` and `return_last_state` are bools -> gradient = None + d_delta_softplus = None + d_return_last_state = None + return ( du, ddelta, dA, dB, dC, - dD if D is not None else None, - dz if z is not None else None, - ddelta_bias if delta_bias is not None else None, - None, - None, + dD, + dz, + ddelta_bias_out, + d_delta_softplus, + d_return_last_state, ) + def custom_setup_context(ctx, inputs, output): (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) = inputs (final_out, last_state, out, x, squeeze_B, squeeze_C, has_z) = output @@ -236,10 +300,12 @@ def custom_setup_context(ctx, inputs, output): else: ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + torch.library.register_autograd( "custom_ops::selective_scan_fwd", custom_bridge, setup_context=custom_setup_context ) + def selective_scan_fn_custom_op( u: torch.Tensor, delta: torch.Tensor, @@ -252,20 +318,9 @@ def selective_scan_fn_custom_op( delta_softplus: bool, return_last_state: bool, ) -> torch.Tensor: - # Pass all arguments positionally, exactly in schema order: final_out, last_state, _, _, _, _, _ = torch.ops.custom_ops.selective_scan_fwd( - u, - delta, - A, - B, - C, - D, - z, - delta_bias, - delta_softplus, - return_last_state + u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state ) - if return_last_state: return final_out, last_state else: diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py index 296629b7..cb57ad9c 100644 --- a/tests/ops/test_selective_scan.py +++ b/tests/ops/test_selective_scan.py @@ -254,3 +254,28 @@ def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): # atol=atolw if not is_variable_C else atol) # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + +def test_selective_scan_opcheck(): + from torch.library import opcheck + + device = "cuda" + # small inputs for opcheck + u = torch.randn(1, 2, 8, device=device, requires_grad=True) + delta = torch.randn(1, 2, 8, device=device, requires_grad=True) + A = torch.randn(2, 8, device=device, requires_grad=True) + B = torch.randn(1, 1, 8, 8, device=device, requires_grad=True) + C = torch.randn(1, 1, 8, 8, device=device, requires_grad=True) + D = torch.randn(2, device=device, requires_grad=True) + z = torch.randn(1, 2, 8, device=device, requires_grad=True) + delta_bias = torch.randn(2, device=device, requires_grad=True) + delta_softplus = False + return_last_state = False + + # Run opcheck + result = opcheck( + torch.ops.custom_ops.selective_scan_fwd, + (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state), + test_utils=("test_schema", "test_faketensor", "test_aot_dispatch_dynamic", "test_autograd_registration"), + raise_exception=True + ) + print("Opcheck result:", result) From 8e4ec7293e25f3d32879278ccff04b5efa4d956a Mon Sep 17 00:00:00 2001 From: bhack Date: Fri, 20 Dec 2024 15:55:00 +0000 Subject: [PATCH 4/5] Add dispatch static --- tests/ops/test_selective_scan.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) 
diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py index cb57ad9c..27b99d85 100644 --- a/tests/ops/test_selective_scan.py +++ b/tests/ops/test_selective_scan.py @@ -257,7 +257,7 @@ def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): def test_selective_scan_opcheck(): from torch.library import opcheck - + device = "cuda" # small inputs for opcheck u = torch.randn(1, 2, 8, device=device, requires_grad=True) @@ -275,7 +275,13 @@ def test_selective_scan_opcheck(): result = opcheck( torch.ops.custom_ops.selective_scan_fwd, (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state), - test_utils=("test_schema", "test_faketensor", "test_aot_dispatch_dynamic", "test_autograd_registration"), - raise_exception=True + test_utils=( + "test_schema", + "test_faketensor", + "test_aot_dispatch_static", + "test_aot_dispatch_dynamic", + "test_autograd_registration", + ), + raise_exception=True, ) print("Opcheck result:", result) From 92aa16a14083be802e201470e758b8f749e0ec9b Mon Sep 17 00:00:00 2001 From: bhack Date: Fri, 20 Dec 2024 16:12:40 +0000 Subject: [PATCH 5/5] Fix import --- mamba_ssm/ops/selective_scan_interface_ compilable.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mamba_ssm/ops/selective_scan_interface_ compilable.py b/mamba_ssm/ops/selective_scan_interface_ compilable.py index e244a3a3..de6e146b 100644 --- a/mamba_ssm/ops/selective_scan_interface_ compilable.py +++ b/mamba_ssm/ops/selective_scan_interface_ compilable.py @@ -2,8 +2,7 @@ from einops import rearrange from typing import Optional, Tuple -from mamba_ssm.ops.selective_scan_interface import selective_scan_cuda - +import selective_scan_cuda @torch.library.custom_op( "custom_ops::selective_scan_fwd",
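
Usage sketch (not part of any patch in this series): a minimal smoke test for the custom-op wrapper, run eagerly and under torch.compile. It assumes a CUDA device, a built selective_scan_cuda extension, and that the new module is importable as mamba_ssm.ops.selective_scan_interface_compilable (the path the updated test file imports). Tensor layouts mirror the tests above: u/delta/z are (batch, dim, seqlen), A is (dim, dstate), variable B/C are (batch, groups, dstate, seqlen), and D/delta_bias are (dim,).

    # Assumed setup, not defined by the patches: CUDA available, extension built,
    # import path matching tests/ops/test_selective_scan.py.
    import torch

    from mamba_ssm.ops.selective_scan_interface_compilable import selective_scan_fn_custom_op

    device = "cuda"
    batch, dim, dstate, seqlen = 2, 4, 8, 128

    u = torch.randn(batch, dim, seqlen, device=device, requires_grad=True)
    delta = torch.randn(batch, dim, seqlen, device=device, requires_grad=True)
    A = (-0.5 * torch.rand(dim, dstate, device=device)).requires_grad_()
    B = torch.randn(batch, 1, dstate, seqlen, device=device, requires_grad=True)
    C = torch.randn(batch, 1, dstate, seqlen, device=device, requires_grad=True)
    D = torch.randn(dim, device=device, requires_grad=True)
    z = torch.randn(batch, dim, seqlen, device=device, requires_grad=True)
    delta_bias = (0.5 * torch.rand(dim, device=device)).requires_grad_()

    # Arguments are positional, in the schema order of custom_ops::selective_scan_fwd:
    # (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state).
    out_eager = selective_scan_fn_custom_op(u, delta, A, B, C, D, z, delta_bias, True, False)

    compiled = torch.compile(selective_scan_fn_custom_op, fullgraph=True)
    out_compiled = compiled(u, delta, A, B, C, D, z, delta_bias, True, False)

    # Gradients flow through the autograd bridge registered for the forward op.
    out_compiled.sum().backward()
    print(torch.allclose(out_eager, out_compiled, rtol=1e-4, atol=1e-4), u.grad is not None)

Because delta_softplus and return_last_state are plain Python bools, Dynamo specializes on them, so the `if return_last_state:` branch in the wrapper is resolved at trace time rather than causing a graph break.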