From fdffb60694eb785ecb8aff07c1be4153b2d8106d Mon Sep 17 00:00:00 2001 From: Josiah Bjorgaard Date: Fri, 10 Jan 2025 11:14:16 -0800 Subject: [PATCH 1/3] Add Context Parallel Mamba2 --- mamba_ssm/modules/mamba2.py | 58 ++++++++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 36b16d47..c8a55fc9 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -29,11 +29,24 @@ from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.ops.triton.ssd_combined_cp import mamba_split_conv1d_scan_combined as mamba_split_conv1d_scan_combined_cp from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined +from mamba_ssm.modules.context_parallel import ContextParallelMixerLayer + + from huggingface_hub import PyTorchModelHubMixin +import torch.distributed as dist + +# Context Parallel - split input sequence +# Going to want to shard this outside of the Mamba2 class, so it can run over multiple layers... +# auto_shard_seq = not force_ring_reduce_off and self.auto_shard_seq and is_distributed() +# mask = None +# (u, _), batch_sizes, num_sharded_batches = sharded_batch_to_sharded_seq(u, mask, self.ring_seq_size) +# End Context Parallel + class Mamba2(nn.Module, PyTorchModelHubMixin): def __init__( self, @@ -61,6 +74,7 @@ def __init__( layer_idx=None, # Absorb kwarg for general module process_group=None, sequence_parallel=True, + context_parallel=False, device=None, dtype=None, ): @@ -73,13 +87,14 @@ def __init__( self.expand = expand self.process_group = process_group self.sequence_parallel = sequence_parallel - self.world_size = 1 if process_group is None else process_group.size() - self.local_rank = 0 if process_group is None else process_group.rank() + #FIXME this is probably for sequence parallel. For context parallel we just replace it later - we don't want to divide inner dimensions by world size since we don't shard that - but we probably could...
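The commented-out note above plans to shard the input sequence outside of the Mamba2 class so that one shard can feed a whole stack of layers. For reference only, that outer sharding could look roughly like the following hypothetical helper (equal-length shards and an already-initialized default process group are assumed; the helper name and signature are not part of this patch):

import torch
import torch.distributed as dist

def shard_sequence(u: torch.Tensor, process_group=None) -> torch.Tensor:
    """Hypothetical helper: keep only this rank's contiguous slice of a
    (batch, seqlen, dim) input. Assumes seqlen is divisible by world size."""
    world_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)
    assert u.shape[1] % world_size == 0, "pad the sequence before sharding"
    return u.chunk(world_size, dim=1)[rank].contiguous()

Each rank would then run the same stack of context-parallel Mamba2 layers on its own contiguous slice.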
+ self.world_size = 1 #if process_group is None else process_group.size() + self.local_rank = 0 #if process_group is None else process_group.rank() self.d_inner = (self.expand * self.d_model) // self.world_size assert self.d_inner * self.world_size == self.expand * self.d_model self.headdim = headdim self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size - assert ngroups % self.world_size == 0 + assert ngroups % self.world_size == 0 self.ngroups = ngroups // self.world_size assert self.d_ssm % self.headdim == 0 self.nheads = self.d_ssm // self.headdim @@ -91,12 +106,26 @@ def __init__( self.chunk_size = chunk_size self.use_mem_eff_path = use_mem_eff_path self.layer_idx = layer_idx + self.context_parallel = context_parallel + + assert not (self.context_parallel and self.sequence_parallel) + if self.context_parallel or self.sequence_parallel and not self.process_group: + #TODO clean up process group passing along with world_size/local_rank here so there is one source of truth + assert torch.distributed.is_initialized() + self.process_group = torch.distributed.group.WORLD + self.world_size = torch.distributed.get_world_size() + self.local_rank = torch.distributed.get_rank() + + if self.context_parallel: + self.cpmixer = ContextParallelMixerLayer(padding=d_conv - 1, process_group=self.process_group) + else: + self.cpmixer = None # Order: [z, x, B, C, dt] d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads - if self.process_group is None: + if not self.sequence_parallel: self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) - else: + else: #FIXME not sure why this checks the sequence_parallel flag; why would ColumnParallelLinear be used without sequence parallelism? self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, **factory_kwargs) @@ -144,7 +173,7 @@ def __init__( self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate, group_size=self.d_ssm // ngroups, **factory_kwargs) - if self.process_group is None: + if not self.sequence_parallel: #FIXME not sure why this checks the sequence_parallel flag; why would RowParallelLinear be used without sequence parallelism?
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) else: self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias, @@ -175,14 +204,20 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param out, _, _ = self.step(u, conv_state, ssm_state) return out - zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj) + if self.cpmixer: #Context parallel - transfer some tokens to mix in with the conv layer to the next GPU + u = self.cpmixer(u) + zxbcdt = self.in_proj(u) + #torch.save(zxbcdt,f'zxbcdt_{dist.get_rank() if dist.is_initialized() else 0}.pt') + #torch.save(self.norm.weight,f'norm_weight_{dist.get_rank() if dist.is_initialized() else 0}.pt') if seqlen_og is not None: zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen) # If the model is loaded in fp16, without the .float() here, A might be -inf A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state) dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) + if self.use_mem_eff_path and inference_params is None: - out = mamba_split_conv1d_scan_combined( + fn = mamba_split_conv1d_scan_combined_cp if self.context_parallel else mamba_split_conv1d_scan_combined + out = fn( zxbcdt, rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias, @@ -199,13 +234,14 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param headdim=None if self.D_has_hdim else self.headdim, ngroups=self.ngroups, norm_before_gate=self.norm_before_gate, + process_group=self.process_group, **dt_limit_kwargs, ) if seqlen_og is not None: out = rearrange(out, "b l d -> (b l) d") - if self.process_group is not None: - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - out = reduce_fn(out, self.process_group) + #if self.process_group is not None: #FIXME this was here before adding context parallel + # reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + # out = reduce_fn(out, self.process_group) else: d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2 z0, x0, z, xBC, dt = torch.split( From 2df17a7f39f930df11a48449f942c0c21a76dbc8 Mon Sep 17 00:00:00 2001 From: Josiah Bjorgaard Date: Fri, 10 Jan 2025 11:14:50 -0800 Subject: [PATCH 2/3] added --- mamba_ssm/ops/triton/ssd_combined_cp.py | 1191 +++++++++++++++++++++++ 1 file changed, 1191 insertions(+) create mode 100644 mamba_ssm/ops/triton/ssd_combined_cp.py diff --git a/mamba_ssm/ops/triton/ssd_combined_cp.py b/mamba_ssm/ops/triton/ssd_combined_cp.py new file mode 100644 index 00000000..0a770f48 --- /dev/null +++ b/mamba_ssm/ops/triton/ssd_combined_cp.py @@ -0,0 +1,1191 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. 
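ContextParallelMixerLayer itself is not included in the two patches shown here (it presumably ships in the remaining patch of the series). From the way it is constructed with padding=d_conv - 1 and called on u in the forward pass above, it is assumed to prepend the previous rank's last d_conv - 1 tokens to each local shard so the depthwise causal conv1d sees the correct left context; the kernels below then drop those prepended positions again via the lb offset. A rough, hypothetical stand-in (forward only, gradients for the transferred halo are ignored) might look like this:

import torch
import torch.distributed as dist
import torch.nn as nn

class NaiveContextParallelMixer(nn.Module):
    """Hypothetical stand-in for ContextParallelMixerLayer: prepend the previous
    rank's last `padding` tokens so the causal conv sees cross-shard context."""
    def __init__(self, padding, process_group=None):
        super().__init__()
        self.padding = padding
        self.group = process_group

    def forward(self, u):  # u: (batch, local_seqlen, dim)
        rank = dist.get_rank(self.group)
        world_size = dist.get_world_size(self.group)
        tail = u[:, -self.padding:].contiguous()
        halo = torch.empty_like(tail)
        # blocking point-to-point chain, in the same style as the _transfer()
        # helper used for state passing later in this patch
        if rank < world_size - 1:          # send my last tokens to the next rank
            dist.send(tail, dst=rank + 1, group=self.group)
        if rank > 0:                       # prepend the previous rank's tail
            dist.recv(halo, src=rank - 1, group=self.group)
            u = torch.cat([halo, u], dim=1)
        return u

Like the state passing below, this exchange is sequential along the rank chain rather than overlapped.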
+ +"""We want triton==2.1.0 or 2.2.0 for this +""" + +from typing import Optional + +import math +from packaging import version + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.amp import custom_bwd, custom_fwd + +import torch.distributed as dist + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +try: + from causal_conv1d import causal_conv1d_fn + import causal_conv1d_cuda +except ImportError: + causal_conv1d_fn, causal_conv1d_cuda = None, None + +from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable +from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref +from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen +from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd +from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable +from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev +from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd +from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_scan_chunk_state_bwd_dx_kernel( + # Pointers to matrices + x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr, + b_ptr, dstates_ptr, + dx_ptr, ddt_ptr, dD_ptr, + # Matrix 
dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_D_head, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + # However, we're getting error with the Triton compiler 2.1.0 for that code path: + # Unexpected mma -> mma layout conversion + # Triton 2.2.0 fixes this + offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 
else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate) + if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) * scale[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate + acc *= scale[:, None] + + # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + # dt_ptrs = dt_ptr + offs_m * stride_dt_csize + # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + # ddt = tl.sum(acc * x, axis=1) * dt_m + # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit + K_MIN = pid_m * BLOCK_SIZE_M + cb_ptrs += K_MIN * stride_cb_csize_k + dout_ptrs += K_MIN * stride_dout_seqlen + dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize + for k in range(K_MIN, K_MAX, BLOCK_SIZE_K): + k = tl.multiple_of(k, BLOCK_SIZE_K) + # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, + # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. + # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. + # This will cause NaN in acc, and hence NaN in dx and ddt. 
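For intuition, the failure mode described in the comment above is the usual IEEE rule that inf * 0 is NaN; a standalone two-line illustration in plain PyTorch (not part of the kernel):

import torch
print(torch.exp(torch.tensor(1000.0)) * 0.0)  # tensor(nan): exp overflows to inf, and inf * 0 -> nan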
+ mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) + cb = tl.where(mask, cb, 0.0) + cb = cb.to(dout_ptr.dtype.element_ty) + acc += tl.dot(cb, dout) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * dt_m[:, None] + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + if HAS_D: + dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + dx += dout_res * D + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if HAS_D: + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + dD = tl.sum(dout_res * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + else: + dD = tl.sum(dout_res * x) + tl.store(dD_ptr, dD) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + +def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.stride(-1) == 1 + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + if dx is None: + dx = torch.empty_like(x) + else: + assert dx.shape == x.shape + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx]( + x, CB, dout, dt, dA_cumsum, 
seq_idx, D, B, dstates, dx, ddt, dD, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + D.stride(0) if D is not None else 0, + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + D is not None, + D.dim() == 2 if D is not None else True, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + IS_TRITON_22=TRITON_22 + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return dx, ddt.to(dtype=dt.dtype), dD + + +def _transfer(tensor, rank, send_rank, recv_rank, group=torch.distributed.group.WORLD): + """ + Sends tensor to root process, which store it in tensor_list. + """ + if rank not in [send_rank, recv_rank]: + return tensor + if rank == send_rank: + dist.send(tensor, recv_rank, group=group) + else: + dist.recv(tensor, send_rank, group=group) + return tensor + + +def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, + initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, + process_group = torch.distributed.group.WORLD, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads), print(dt.shape, x.shape) + assert A.shape == (nheads,) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + + rank, world_size = dist.get_rank(process_group), dist.get_world_size(process_group) + # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + dA_cumsum, dt 
= _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) + states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True) + # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True) + # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True) + def _state_passing_fwd_wrap(states, dA_cumsum, initial_states, seq_idx, chunk_size, C): + states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, + "... p n -> ... (p n)") if initial_states is not None else None, + seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype) + states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]] + return states, final_states + + if world_size > 1: + if rank == 0: + states, final_states = _state_passing_fwd_wrap(states, dA_cumsum, initial_states, seq_idx, chunk_size, C) + dist.barrier() + #print('passing states sequentially multi-gpu') + for rank1, rank2 in zip(range(world_size-1),range(1,world_size)): + #print(f"{rank} - {rank1} - {rank2}") + if rank in [rank1, rank2]: + #print(f"transfer {rank1}:{rank2}") + if rank == rank2: + final_states = torch.zeros_like(states[:,-1]) + initial_states = _transfer(final_states, rank, rank1, rank2, process_group) + if rank == rank2: + #print(f"state passing {rank2}") + states, final_states = _state_passing_fwd_wrap(states, dA_cumsum, initial_states, seq_idx, chunk_size, C) + dist.barrier() + else: + states, final_states = _state_passing_fwd_wrap(states, dA_cumsum, initial_states, seq_idx, chunk_size, C) + #print("Done state passing") + #torch.save(final_states,f"final_states_{dist.get_rank()}.pt") + #torch.save(states,f"passed_states_{dist.get_rank()}.pt") + #torch.save(states,f"passed_states_{dist.get_rank()}.pt") + # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) + # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... 
p n", n=dstate) + CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx) + #print(f"Forward sizes: {x.shape} {out.shape}") + if cu_seqlens is None: + return out, out_x, dt, dA_cumsum, states, final_states + else: + assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0), + cu_seqlens, states.squeeze(0)) + return out, out_x, dt, dA_cumsum, states, final_states, varlen_states + + +def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, + dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False, + dt_limit=(0.0, float("inf")), + dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False, + process_group = torch.distributed.group.WORLD, + ): + #print('chunk_scan_combined_bwd') + rank, world_size = dist.get_rank(process_group), dist.get_world_size(process_group) + + if dout.stride(-1) != 1: + dout = dout.contiguous() + batch, seqlen, nheads, headdim = x.shape + nchunks = math.ceil(seqlen / chunk_size) + _, _, ngroups, dstate = B.shape + assert dout.shape == (batch, seqlen, nheads, headdim), print(dout.shape, batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads,) + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert C.shape == B.shape + assert out.shape == x.shape + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if dx is not None: + assert dx.shape == x.shape, print(f"{dx.shape = }, {x.shape = }") + if dB is not None: + assert dB.shape == B.shape + dB_given = dB + else: + dB_given = torch.empty_like(B) + if dC is not None: + assert dC.shape == C.shape + dC_given = dC + else: + dC_given = torch.empty_like(C) + if dz is not None: + assert z is not None + assert dz.shape == z.shape, print(f"{dz.shape = },{z.shape = }") + if ddt is not None: + assert ddt.shape == dt.shape + ddt_given = ddt + else: + ddt_given = torch.empty_like(dt) + # TD: For some reason Triton (2.1.0 and 2.2.0) errors with + # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why. + dt_in = dt.clone() + dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, + dt_limit=dt_limit) + CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + #torch.save(states,f"states_{dist.get_rank()}.pt") + #torch.save(dA_cumsum,f"dA_cumsum_{dist.get_rank()}.pt") + def _state_passing_fwd_wrap(states, dA_cumsum, initial_states, seq_idx, chunk_size, C): + states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, + seq_idx=seq_idx, chunk_size=chunk_size) + #states = rearrange(states, "... (p n) -> ... p n", n=dstate) + states, final_states = [rearrange(t, "... (p n) -> ... 
p n", n=dstate) for t in [states, final_states]] + return states, final_states + + if world_size > 1: + if rank == 0: + states, final_states = _state_passing_fwd_wrap(states, dA_cumsum, initial_states, seq_idx, chunk_size, C) + dist.barrier() + # print('passing states sequentially multi-gpu') + for rank1, rank2 in zip(range(world_size - 1), range(1, world_size)): + # print(f"{rank} - {rank1} - {rank2}") + if rank in [rank1, rank2]: + # print(f"transfer {rank1}:{rank2}") + if rank == rank2: + final_states = torch.zeros_like(states[:, -1]) + initial_states = _transfer(final_states, rank, rank1, rank2, process_group) + if rank == rank2: + # print(f"state passing {rank2}") + states, final_states = _state_passing_fwd_wrap(states, dA_cumsum, initial_states, seq_idx, chunk_size, C) + dist.barrier() + #initial_states = None + else: + states, final_states = _state_passing_fwd_wrap(states, dA_cumsum, initial_states, seq_idx, chunk_size, C) + #torch.save(dout, f"dout_{dist.get_rank()}.pt") + #torch.save(states,f"passed_states_{dist.get_rank()}.pt") + #torch.save(final_states,f"final_states_{dist.get_rank()}.pt") + + if z is not None: + #print(f"z not none, chunk scan bwd dz, {dist.get_rank()}") + dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output) + outz = rest[0] if recompute_output else out + else: + dz = None + outz = out + + #FIXME: if this is not the GPU where a gradient is calculated, dout is zero, + #FIXME: But the rest of the pass still needs calculation if rank < grad rank (causal ordering) + dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype) + # dstates has length nchunks, containing the gradient to initial states at index 0 and + # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1) + # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states + # will be used in matmul in the next kernels. + + #torch.save(dstates,f"dstates_{dist.get_rank()}.pt") + + def _state_passing_bwd_wrap(states, dA_cumsum, dstates, dfinal_states, initial_states, seq_idx, chunk_size, x): + dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + rearrange(dstates, "... p n -> ... (p n)"), + dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None, + seq_idx=seq_idx, + has_initial_states=initial_states is not None, + dstates_dtype=x.dtype, + states_dtype=x.dtype, + chunk_size=chunk_size, + ) + # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and + # gradient to the final states at index (nchunks - 1) + # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1) + # The final states is not stored. + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate) + dinitial_states = rearrange(dinitial_states, "... (p n) -> ... 
p n", n=dstate) if dinitial_states is not None else None + return states, dstates, dinitial_states, ddA_chunk_cumsum + + #Reusue initial states from fwd recalc above, hopefully they haven't been overwritten + if world_size > 1: + if rank == world_size-1: + states, dstates, dinitial_states, ddA_chunk_cumsum = _state_passing_bwd_wrap(states, dA_cumsum, dstates, dfinal_states, initial_states, seq_idx, chunk_size, x) + dist.barrier() + #print(f'passing states sequentially multi-gpu on {dist.get_rank()}') + for rank1, rank2 in zip(range(world_size - 1,0,-1), range(world_size-2,-1,-1)): + #print(f"d_on {rank} - {rank1} - {rank2}") + if rank in [rank1, rank2]: + #print(f"transfer {rank1}:{rank2}") + if rank == rank2: + dinitial_states = torch.zeros_like(states[:, -1]) + dfinal_states = _transfer(dinitial_states, rank, rank1, rank2, process_group) + if rank == rank2: + #print(f"state passing backward {rank2}") + states, dstates, dinitial_states, ddA_chunk_cumsum = _state_passing_bwd_wrap(states, dA_cumsum, dstates, + dfinal_states, + initial_states, seq_idx, + chunk_size, x) + dist.barrier() + #print('done') + dinitial_states = None #FIXME hack to because no initial states are used except between GPUs + else: + states, dstates, dinitial_states, ddA_chunk_cumsum = _state_passing_bwd_wrap(states, dA_cumsum, dstates, + dfinal_states, initial_states, + seq_idx, chunk_size, x) + + #torch.save(states,f"bw_states_{dist.get_rank()}.pt") + #torch.save(dstates,f"dpassed_states_{dist.get_rank()}.pt") + #torch.save(dinitial_states,f"dinitial_states_{dist.get_rank()}.pt") + #torch.save(ddA_chunk_cumsum,f"ddA_chunk_cumsum_{dist.get_rank()}.pt") + + dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx) + + #torch.save(dx, f"dx_{dist.get_rank()}.pt") + #torch.save(ddt,f"ddt_{dist.get_rank()}.pt") + #torch.save(dD_from_x,f"fdD_from_x_{dist.get_rank()}.pt") + + + # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups) + dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups) + # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) + dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups) + + #torch.save(dB, f"dB_{dist.get_rank()}.pt") + #torch.save(ddA_next,f"ddA_next_{dist.get_rank()}.pt") + #torch.save(dC,f"dC_{dist.get_rank()}.pt") + #torch.save(ddA_cumsum_prev,f"ddA_cumsum_prev_{dist.get_rank()}.pt") + + # Computing ddA with the dcb kernel is much slower, so we're not using it for now + dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) + # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups) + dCB = dCB.to(CB.dtype) + _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given) + _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given) + # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate + # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16 + if z is None: + dD = dD_from_x + # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. + # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt + # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might + # be a lot of underflow. 
+ + #torch.save(dCB, f"dCB_{dist.get_rank()}.pt") + #torch.save(C,f"C_{dist.get_rank()}.pt") + #torch.save(B,f"B_{dist.get_rank()}.pt") + #torch.save(x,f"x_{dist.get_rank()}.pt") + + # This is already done as part of bwd_dC kernel + # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx) + ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum + ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1]) + # This is already done as part of bwd_dB kernel + # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx) + # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j] + ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB) + ddA += ddA_next + ddA_prev + + ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given) + + # These 2 lines are just to test ddt and dA being computed by old code + # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z) + # ddt_given.copy_(ddt) + #torch.save(dC_given,f"dC_given_{dist.get_rank()}.pt") + #torch.save(dB_given,f"dB_given_{dist.get_rank()}.pt") + #torch.save(ddt_given,f"ddt_given_{dist.get_rank()}.pt") + #torch.save(dz,f"dz_{dist.get_rank()}.pt") + return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states) + #print('returning') + return return_vals if not recompute_output else (*return_vals, outz) + + +def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None): + """ + Argument: + dout: (batch, seqlen, nheads, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size) + A: (nheads) or (dim, dstate) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + import selective_scan + + batch, seqlen, nheads, headdim = x.shape + chunk_size = dt.shape[-1] + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + x = rearrange(x, "b l h p -> b (h p) l") + squeeze_dt = dt.dim() == 4 + if dt.dim() == 4: + dt = repeat(dt, "b h c l -> b h p c l", p=headdim) + dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim) + squeeze_A = A.dim() == 1 + if A.dim() == 1: + A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32) + else: + A = A.to(dtype=torch.float32) + B = rearrange(B, "b l g n -> b g n l") + C = rearrange(C, "b l g n -> b g n l") + if D is not None: + if D.dim() == 2: + D = rearrange(D, "h p -> (h p)") + else: + D = repeat(D, "h -> (h p)", p=headdim) + if z is not None: + z = rearrange(z, "b l h p -> b (h p) l") + + if x.stride(-1) != 1: + x = x.contiguous() + if dt.stride(-1) != 1: + dt = dt.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + _, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False) + if z is not None: + out = rest[0] + else: + out = None + + dout = rearrange(dout, "b l h p -> b (h p) l") + + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan with the backward of chunk). 
+ # Here we just pass in None and dz will be allocated in the C++ code. + _, ddt, dA, *rest = selective_scan.bwd( + x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False, + False # option to recompute out_z, not used here + ) + ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size) + if squeeze_dt: + ddt = ddt.float().sum(dim=2) + if squeeze_A: + dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2)) + return ddt, dA + + +class MambaChunkScanCombinedFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, + dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False, + process_group = torch.distributed.group.WORLD): + ctx.process_group = process_group + ctx.dt_dtype = dt.dtype + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) + ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx) + ctx.dt_softplus = dt_softplus + ctx.chunk_size = chunk_size + ctx.dt_limit = dt_limit + ctx.return_final_states = return_final_states + ctx.return_varlen_states = return_varlen_states + if not return_varlen_states: + return out if not return_final_states else (out, final_states) + else: + varlen_states = rest[0] + return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states) + + @staticmethod + def backward(ctx, dout, *args): + out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors + assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward" + dfinal_states = args[0] if ctx.return_final_states else None + dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit) + return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None + + +def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + dt_softplus: Whether to apply softplus to dt + Return: + out: (batch, seqlen, nheads, headdim) + """ + return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states) + + +def 
mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + if seqlen % chunk_size != 0: + dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size)) + dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size) + dt = dt.float() # We want high precision for this before cumsum + if dt_bias is not None: + dt = dt + rearrange(dt_bias, "h -> h 1 1") + if dt_softplus: + dt = F.softplus(dt) + dA = dt * rearrange(A, "h -> h 1 1") + dA = dt * rearrange(A, "h -> h 1 1") + dA_cumsum = torch.cumsum(dA, dim=-1) + # 1. Compute the state for each chunk + states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True) + # 2. Pass the state to all the chunks by weighted cumsum. + states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0], + "... (p n) -> ... p n", n=dstate) + # 3. Compute the output for each chunk + out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z) + return out + + +def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + if seqlen % chunk_size != 0: + dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size)) + dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size) + dt = dt.float() # We want high precision for this before cumsum + if dt_bias is not None: + dt = dt + rearrange(dt_bias, "h -> h 1 1") + if dt_softplus: + dt = F.softplus(dt) + dA = dt * rearrange(A, "h -> h 1 1") + dA_cumsum = torch.cumsum(dA, dim=-1) + # 1. Compute the state for each chunk + states = chunk_state_ref(B, x, dt, dA_cumsum) + states_dtype = states.dtype + if states.dtype not in [torch.float32, torch.float64]: + states = states.to(torch.float32) + # 2. Pass the state to all the chunks by weighted cumsum. + # state_passing_ref is much less numerically stable + states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0], + "... (p n) -> ... p n", n=dstate) + states = states.to(states_dtype) + # 3. 
Compute the output for each chunk + out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z) + return out + + +def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) + A: (nheads) or (dim, dstate) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) or (nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn + + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + x = rearrange(x, "b l h p -> b (h p) l") + if dt.dim() == 3: + dt = repeat(dt, "b l h -> b l h p", p=headdim) + dt = rearrange(dt, "b l h p -> b (h p) l") + if A.dim() == 1: + A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32) + else: + A = A.to(dtype=torch.float32) + B = rearrange(B, "b l g n -> b g n l") + C = rearrange(C, "b l g n -> b g n l") + if D is not None: + if D.dim() == 2: + D = rearrange(D, "h p -> (h p)") + else: + D = repeat(D, "h -> (h p)", p=headdim) + if z is not None: + z = rearrange(z, "b l h p -> b (h p) l") + if dt_bias is not None: + if dt_bias.dim() == 1: + dt_bias = repeat(dt_bias, "h -> h p", p=headdim) + dt_bias = rearrange(dt_bias, "h p -> (h p)") + if dt_limit != (0.0, float("inf")): + if dt_bias is not None: + dt = dt + rearrange(dt_bias, "d -> d 1") + if dt_softplus: + dt = F.softplus(dt) + dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype) + dt_bias = None + dt_softplus = None + out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus) + return rearrange(out, "b (h p) l -> b l h p", p=headdim) + + +def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None, + dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), + activation="silu", headdim=None, ngroups=1): + """ + Argument: + xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim + conv1d_weight: (dim + 2 * ngroups * dstate, width) + conv1d_bias: (dim + 2 * ngroups * dstate,) + dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) + A: (nheads) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, dim) + dt_bias: (nheads) or (nheads, headdim) + headdim: if D is 1D and z is None, headdim must be passed in + Return: + out: (batch, seqlen, dim) + """ + batch, seqlen, nheads = dt.shape[:3] + assert nheads % ngroups == 0 + if z is not None: + dim = z.shape[-1] + assert dim % nheads == 0 + headdim = dim // nheads + else: + if D.dim() == 1: + assert headdim is not None + else: + headdim = D.shape[1] + dim = nheads * headdim + xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), + "b d s -> b s d") + dstate = (xBC.shape[-1] - dim) // ngroups // 2 + x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None + out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) + return rearrange(out, "b s h p 
-> b s (h p)") + + +class MambaSplitConv1dScanCombinedFn(torch.autograd.Function): + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", + rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, + ngroups=1, norm_before_gate=True, process_group = torch.distributed.group.WORLD): + assert activation in [None, "silu", "swish"] + rank, world_size = dist.get_rank(process_group), dist.get_world_size(process_group) + lb = 0 if rank == 0 else conv1d_weight.shape[1] - 1 #Added for context parallel + if D.dim() == 1: + assert headdim is not None + nheads, = D.shape + else: + nheads, headdim = D.shape + batch, seqlen, _ = zxbcdt.shape + dim = nheads * headdim + assert nheads % ngroups == 0 + dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2 + d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2 + assert d_nonssm >= 0 + assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads) + assert dt_bias.shape == (nheads,) + assert A.shape == (nheads,) + zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1) + zx0, z, dt = zx0[:,lb:].contiguous(), z[:,lb:].contiguous(), dt[:,lb:].contiguous() # Context parallel + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + + seqlen -= lb #context parallel + lb = 0 if rank == 0 else conv1d_weight.shape[1] - 1 #Added for context parallel + xBC_conv = rearrange( + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), + conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]), + "b d s -> b s d" + )[:, lb:, :].contiguous() + x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None + if rmsnorm_weight is None: + out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, + D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, + dt_softplus=True, dt_limit=dt_limit, + process_group = process_group) + out = rearrange(out, "b s h p -> b s (h p)") + rstd = None + if d_nonssm > 0: + out = torch.cat([_swiglu_fwd(zx0), out], dim=-1) + else: + out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, + initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit, + process_group = process_group) + # reshape input data into 2D tensor + x_rms = rearrange(out_x, "b s h p -> (b s) (h p)") + z_rms = rearrange(z, "b s h p -> (b s) (h p)") + rmsnorm_weight = rmsnorm_weight.contiguous() + if d_nonssm == 0: + out = None + else: + out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device) + out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d") + _swiglu_fwd(zx0, out=out01[..., :d_nonssm]) + out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out, + group_size=dim // ngroups, + norm_before_gate=norm_before_gate, is_rms_norm=True) + if d_nonssm == 0: + out = rearrange(out, 
"(b s) d -> b s d", b=batch) + else: + out = out01 + ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None + if outproj_weight is not None: + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + out, outproj_weight = out.to(dtype), outproj_weight.to(dtype) + outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None + out = F.linear(out, outproj_weight, outproj_bias) + else: + assert outproj_bias is None + ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias, + out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias) + ctx.dt_limit = dt_limit + ctx.return_final_states = return_final_states + ctx.activation = activation + ctx.rmsnorm_eps = rmsnorm_eps + ctx.norm_before_gate = norm_before_gate + ctx.chunk_size = chunk_size + ctx.headdim = headdim + ctx.ngroups = ngroups + ctx.process_group = process_group + return out if not return_final_states else (out, final_states) + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, dout, *args): + zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + process_group = ctx.process_group + rank, world_size = dist.get_rank(process_group), dist.get_world_size(process_group) + headdim = ctx.headdim + nheads = D.shape[0] + dim = nheads * headdim + assert nheads % ctx.ngroups == 0 + dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2 + d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2 + assert d_nonssm >= 0 + recompute_output = outproj_weight is not None + if recompute_output: + out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype) + out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1) + zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) + # Recompute x, B, C + lb = 0 if rank == 0 else conv1d_weight.shape[1] - 1 #Added for context parallel + zx0, z, dt = zx0[:,lb:], z[:,lb:], dt[:,lb:] # Context parallel + #torch.save(z,f'z_{dist.get_rank()}.pt') + xBC_conv = rearrange( + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), + conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]), + "b d s -> b s d" + )[:, lb:, :].contiguous() + x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups) + + #torch.save(zxbcdt, f"zxbcdt_{dist.get_rank()}.pt") + #torch.save(xBC_conv, f"xBC_conv_{dist.get_rank()}.pt") + #Create empty output tensors for gradients + #Notes added on whether the memory locations are preserved or not - this is useful for later optimization + dzxbcdt = torch.empty_like(zxbcdt) + dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) #This doesn't make a copy of the memory, it returns a view + dzx0, dz, ddt_given = dzx0[:,lb:], dz[:,lb:], ddt_given[:,lb:] #This makes a copy of the memory... 
+ dxBC = torch.empty_like(xBC_conv) #xBC is memory from zxbcdt and has sequence in it whereas _conv is already removed padding + dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) + #dxBC = F.pad(dxBC, (0,0,lb,0), mode='constant', value=None) + #Now, dx, dB, and dC will pass data into dxBC, + # dzx0, dz, and ddt_given won't affect dzxbcdt, + # dxBC_given will affect dzxbcdt + #dxBC = torch.empty_like(xBC) + #dx, dB, dC = torch.split(dxBC[:, lb:], [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) + dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads) + dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups) + dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups) + + if outproj_weight is not None: + dout_og = dout + dout = F.linear(dout, outproj_weight.t()) + if d_nonssm > 0: + dout0, dout = dout.split([d_nonssm, dim], dim=-1) + _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute) + + dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim) + if rmsnorm_weight is None: + raise Exception("Deleted path here") #This could be added back later but probably needs some testing + else: + batch = dout.shape[0] + dy_rms = rearrange(dout, "b s h p -> (b s) (h p)") #dout (input) + dz = rearrange(dz, "b l d -> (b l) d") #dz (copied) + x_rms = rearrange(out, "b s h p -> (b s) (h p)") #out (calculated) + z_rms = rearrange(z, "b s h p -> (b s) (h p)") #z (calculated) + + #TODO - dz here doesn't directly write into xdzbcdt so needs to be copied later + out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None + dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, + None, rstd, z_rms, group_size=dim//ctx.ngroups, + norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, + recompute_output=recompute_output, dz=dz, + out=out1_recompute if recompute_output else None) + + out_for_linear = out_recompute if recompute_output else None + dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim) + dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd( + dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, + dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, + process_group = process_group + ) #TODO - dx, dB, dC, and ddt here don't directly write into xdzbcdt so needs to be copied later + + #for k,v in {'dx':dx,'dA':dA,'dB':dB,'dC':dC,'dxBC_p':dxBC + # }.items(): + # torch.save(v,f'{k}_{dist.get_rank()}.pt') + + + if outproj_weight is not None: + doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear) + doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None + else: + doutproj_weight, doutproj_bias = None, None + dxBC_given = rearrange(dxBC_given, "b s d -> b d s") + + #This line is added for context parallel since the prepended tokens are dropped out of the conv layer in the forward pass + dxBC = F.pad(dxBC, (0,0,lb,0), mode='constant', value=0.0) + #print(f"{dxBC_given.shape = }, {xBC.shape = }, {dxBC.shape = }") + + dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, + rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"] + ) + + #These are added for context 
parallel (the coping into dzxbcdt) + #dxBC_given, dweight, dbias = dxBC, None, None #Context parallel test + #dzxbcdt[:,:,:2*d_nonssm] = F.pad(dzx0, (0,0,lb,0), mode='constant', value=0.0) + dzxbcdt[ :, :, 2 * d_nonssm + dim + dim + 2 * ctx.ngroups * dstate:] = F.pad(ddt, (0,0,lb,0), mode='constant',value=0.0) + dzxbcdt[ :,:, 2 * d_nonssm: dim] = F.pad(rearrange(dz, '(b l) d -> b l d', b = batch), (0,0,lb,0), mode='constant', value=0.0) + + dxBC_given = rearrange(dxBC_given, "b d s -> b s d") + #torch.save(dxBC,f'dxBC_{dist.get_rank()}.pt') + #torch.save(dxBC_given,f'dxBC_given_{dist.get_rank()}.pt') + #torch.save(dzxbcdt,f'dzxbcdt_{dist.get_rank()}.pt') + return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None, None + + +def mamba_split_conv1d_scan_combined(zxbcdt, + conv1d_weight, + conv1d_bias, + dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), + return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, + outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True, + process_group = torch.distributed.group.WORLD + ): + """ + Argument: + zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim + conv1d_weight: (dim + 2 * ngroups * dstate, width) + conv1d_bias: (dim + 2 * ngroups * dstate,) + dt_bias: (nheads,) + A: (nheads) + D: (nheads, headdim) or (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen), int32 + rmsnorm_weight: (dim,) + outproj_weight: (out_dim, dim) + outproj_bias: (out_dim,) + headdim: if D is 1D, headdim must be passed in + norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z)) + Return: + out: (batch, seqlen, dim) + """ + return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, + seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, + outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate, process_group) + + +def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True): + """ + Argument: + zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim + conv1d_weight: (dim + 2 * ngroups * dstate, width) + conv1d_bias: (dim + 2 * ngroups * dstate,) + dt_bias: (nheads,) + A: (nheads) + D: (nheads, headdim) or (nheads,) + rmsnorm_weight: (dim,) + outproj_weight: (out_dim, dim) + outproj_bias: (out_dim,) + headdim: if D is 1D, headdim must be passed in + norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). 
If False, we do RMSNorm(x * F.silu(z)) + Return: + out: (batch, seqlen, dim) + """ + if D.dim() == 1: + assert headdim is not None + nheads, = D.shape + else: + nheads, headdim = D.shape + assert nheads % ngroups == 0 + batch, seqlen, _ = zxbcdt.shape + dim = nheads * headdim + dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2 + assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) + assert dt_bias.shape == (nheads,) + assert A.shape == (nheads,) + if rmsnorm_weight is not None: + assert rmsnorm_weight.shape == (dim,) + z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1) + xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), + "b d s -> b s d") + x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) + out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), + z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit) + out = rearrange(out, "b s h p -> b s (h p)") + if rmsnorm_weight is not None: + out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps, + norm_before_gate=norm_before_gate) + if outproj_weight is not None: + out = F.linear(out, outproj_weight, outproj_bias) + return out + From bc53b782929efcb66f972a5a59140d10f9d71f01 Mon Sep 17 00:00:00 2001 From: Josiah Bjorgaard Date: Fri, 10 Jan 2025 11:16:02 -0800 Subject: [PATCH 3/3] Add missing file --- mamba_ssm/modules/context_parallel.py | 83 +++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 mamba_ssm/modules/context_parallel.py diff --git a/mamba_ssm/modules/context_parallel.py b/mamba_ssm/modules/context_parallel.py new file mode 100644 index 00000000..a080ddba --- /dev/null +++ b/mamba_ssm/modules/context_parallel.py @@ -0,0 +1,83 @@ +from typing import Optional + +import torch +from torch import nn, Tensor +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +def send_and_receive_(x, receive_buffer, send_to_rank, receive_from_rank, group): + assert send_to_rank is not None or receive_from_rank is not None + ops = [] + if send_to_rank is not None: + ops.append(dist.P2POp(dist.isend, x, send_to_rank, group)) + if receive_from_rank is not None: + ops.append(dist.P2POp(dist.irecv, receive_buffer, receive_from_rank, group)) + + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + dist.barrier() + +class ContextParallelMixerFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x, padding=0, process_group=torch.distributed.group.WORLD): + #Prepends the last n_padding tokens from layer_n to layer_{n+1} + #These are mixed into subsequent tokens of layer n+1 by convolution, but their index is then discarded + # the convolution is causal, so the mixing only goes in one direction + rank, world_size = dist.get_rank(process_group), dist.get_world_size(process_group) + if world_size == 1: + return x + + send_to_rank = rank + 1 if rank < world_size - 1 else None + receive_from_rank = rank - 1 if rank > 0 else None + #print('dist', rank, 'send',send_to_rank, 'recieve',receive_from_rank) + #_, pre_tokens = x.split(x.shape[1]-self.padding, dim=1) + pre_tokens = x[:,-padding:].contiguous() + 
#print('dist',rank,'padding',padding)
+        assert pre_tokens.shape[1] == padding
+        receive_buffer = torch.zeros_like(pre_tokens, requires_grad=True).contiguous() #TODO this isn't used by rank=0
+        send_and_receive_(pre_tokens, receive_buffer, send_to_rank, receive_from_rank, process_group)
+        if rank > 0:
+            x = F.pad(x, (0, 0, padding, 0), 'constant', 0)
+            x[:, :padding] = receive_buffer
+        #print('x', rank, x.shape)
+        ctx.padding = padding
+        ctx.process_group = process_group
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_x):
+        """
+        grad_x arrives with gradients for the padding tokens that were prepended in forward
+        (on ranks > 0). The forward input was not padded, so those gradients are popped off
+        here and transferred back to the previous rank, which owns the corresponding tokens.
+        """
+        process_group = ctx.process_group
+        rank, world_size = dist.get_rank(process_group), dist.get_world_size(process_group)
+        padding = ctx.padding
+        #print('grad_x', rank, grad_x.shape)
+        if world_size == 1:
+            return grad_x, None, None
+        send_to_rank = rank - 1 if rank > 0 else None
+        receive_from_rank = rank + 1 if rank < world_size - 1 else None
+        pre_tokens_grad = grad_x[:, :padding].contiguous()
+        if rank > 0:
+            grad_x_out = grad_x[:, padding:].contiguous()
+        else:
+            grad_x_out = grad_x.clone()
+        assert pre_tokens_grad.shape[1] == ctx.padding
+        receive_buffer = torch.zeros_like(pre_tokens_grad).contiguous() #TODO this isn't used by rank=0
+        send_and_receive_(pre_tokens_grad, receive_buffer, send_to_rank, receive_from_rank, process_group)
+        if rank < world_size - 1:
+            grad_x_out[:, -padding:] += receive_buffer
+        return grad_x_out, None, None
+
+class ContextParallelMixerLayer(nn.Module):
+    def __init__(self, padding=0, process_group=torch.distributed.group.WORLD):
+        super(ContextParallelMixerLayer, self).__init__()
+        self.padding = padding
+        self.process_group = process_group
+
+    def forward(self, x):
+        return ContextParallelMixerFn.apply(x, self.padding, self.process_group)
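
A minimal driver sketch follows, illustrating how the ContextParallelMixerLayer added above can be exercised on its own under torchrun. It is not part of the patch: the tensor sizes, the even split of the sequence across ranks, and the NCCL backend are assumptions made for the example; the only APIs taken from the patch are ContextParallelMixerLayer and its padding/process_group arguments.

# context_parallel_demo.py -- hypothetical example, run with:
#   torchrun --nproc_per_node=2 context_parallel_demo.py
import os

import torch
import torch.distributed as dist

from mamba_ssm.modules.context_parallel import ContextParallelMixerLayer


def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # Assumed sizes for the demo; each rank holds one contiguous shard of the sequence.
    batch, seqlen, d_model, d_conv = 2, 4096, 256, 4
    assert seqlen % world_size == 0
    shard_len = seqlen // world_size
    local = torch.randn(batch, shard_len, d_model, device="cuda", requires_grad=True)

    mixer = ContextParallelMixerLayer(padding=d_conv - 1, process_group=dist.group.WORLD)

    # Ranks > 0 receive the last (d_conv - 1) tokens of the previous rank's shard,
    # prepended so a causal depthwise conv sees the correct left context.
    mixed = mixer(local)
    expected = shard_len if rank == 0 else shard_len + (d_conv - 1)
    assert mixed.shape == (batch, expected, d_model)

    # Gradients for the prepended positions are routed back to the rank that owns them.
    mixed.sum().backward()
    assert local.grad.shape == local.shape

    dist.destroy_process_group()


if __name__ == "__main__":
    main()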