diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index fae2257a..9f7187a6 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -26,20 +26,25 @@ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None +from typing import Dict, List, Optional, Tuple, Union +import torch +from torch import nn, Tensor + + def create_block( - d_model, - d_intermediate, - ssm_cfg=None, - attn_layer_idx=None, - attn_cfg=None, - norm_epsilon=1e-5, - rms_norm=False, - residual_in_fp32=False, - fused_add_norm=False, - layer_idx=None, - device=None, - dtype=None, -): + d_model: int, + d_intermediate: int, + ssm_cfg: Optional[Dict] = None, + attn_layer_idx: Optional[List[int]] = None, + attn_cfg: Optional[Dict] = None, + norm_epsilon: float = 1e-5, + rms_norm: bool = False, + residual_in_fp32: bool = False, + fused_add_norm: bool = False, + layer_idx: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: if ssm_cfg is None: ssm_cfg = {} if attn_layer_idx is None: @@ -88,7 +93,7 @@ def _init_weights( n_layer, initializer_range=0.02, # Now only used for embedding layer. rescale_prenorm_residual=True, - n_residuals_per_layer=1, # Change to 2 if we have MLP + n_residuals_per_layer=1, ): if isinstance(module, nn.Linear): if module.bias is not None: @@ -122,16 +127,16 @@ def __init__( n_layer: int, d_intermediate: int, vocab_size: int, - ssm_cfg=None, - attn_layer_idx=None, - attn_cfg=None, + ssm_cfg: Optional[Dict] = None, + attn_layer_idx: Optional[List[int]] = None, + attn_cfg: Optional[Dict] = None, norm_epsilon: float = 1e-5, rms_norm: bool = False, - initializer_cfg=None, - fused_add_norm=False, - residual_in_fp32=False, - device=None, - dtype=None, + initializer_cfg: Optional[Dict] = None, + fused_add_norm: bool = False, + residual_in_fp32: bool = False, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -187,7 +192,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) for i, layer in enumerate(self.layers) } - def forward(self, input_ids, inference_params=None, **mixer_kwargs): + def forward( + self, + input_ids: Tensor, + inference_params = None, + **mixer_kwargs + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: hidden_states = self.embedding(input_ids) residual = None for layer in self.layers: @@ -213,13 +223,12 @@ def forward(self, input_ids, inference_params=None, **mixer_kwargs): class MambaLMHeadModel(nn.Module, GenerationMixin): - def __init__( self, config: MambaConfig, - initializer_cfg=None, - device=None, - dtype=None, + initializer_cfg: Optional[Dict] = None, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, ) -> None: self.config = config d_model = config.d_model @@ -271,7 +280,14 @@ def tie_weights(self): def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs): + def forward( + self, + input_ids: Tensor, + position_ids: Optional[Tensor] = None, + inference_params = None, + num_last_tokens: int = 0, + **mixer_kwargs + ) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]: """ "position_ids" is just to be compatible with 
Transformer generation. We don't use it. num_last_tokens: if > 0, only return the logits for the last n tokens diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 36b16d47..6fd8d893 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -10,18 +10,31 @@ try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None, None +except ImportError as e: + raise ImportError( + "causal_conv1d package not found. Please install it with: " + "pip install causal-conv1d>=1.4.0" + ) from e try: from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states -except ImportError: +except ImportError as e: causal_conv1d_varlen_states = None + import warnings + warnings.warn( + "causal_conv1d_varlen module not found. Variable length sequences will not be supported. " + "Install the latest causal_conv1d package for full functionality." + ) try: from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: +except ImportError as e: selective_state_update = None + import warnings + warnings.warn( + "selective_state_update module not found. Performance may be degraded. " + "Make sure to install with the 'triton' extra: pip install mamba-ssm[triton]" + ) from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated @@ -221,9 +234,12 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) else: assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package" - assert batch == 1, "varlen inference only supports batch dimension 1" + # The 'batch' variable here might be misleading when cu_seqlens is used. + # The actual number of sequences is cu_seqlens.shape[0] - 1. + # conv_state is already shaped (inference_batch, ...). + # xBC should be (total_tokens, features) when cu_seqlens is present. 
conv_varlen_states = causal_conv1d_varlen_states( - xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] + xBC, cu_seqlens, state_len=conv_state.shape[-1] ) conv_state.copy_(conv_varlen_states) assert self.activation in ["silu", "swish"] @@ -308,16 +324,55 @@ def step(self, hidden_states, conv_state, ssm_state): # SSM step if selective_state_update is None: - assert self.ngroups == 1, "Only support ngroups=1 for this inference code path" + assert self.nheads % self.ngroups == 0, "nheads must be divisible by ngroups for PyTorch step fallback" + k = self.nheads // self.ngroups + # Discretize A and B + # dt is already (batch, nheads) from xBC split and projection dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) - dA = torch.exp(dt * A) # (batch, nheads) - x = rearrange(x, "b (h p) -> b h p", p=self.headdim) - dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) - ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) - y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) - y = y + rearrange(self.D.to(dtype), "h -> h 1") * x - y = rearrange(y, "b h p -> b (h p)") + # A is (nheads,) + + # Reshape for grouped operations + # x: (B, d_ssm) -> (B, ngroups, k, headdim) + x_r = rearrange(x, "b (g k p) -> b g k p", g=self.ngroups, k=k, p=self.headdim) + # dt: (B, nheads) -> (B, ngroups, k) + dt_r = rearrange(dt, "b (g k) -> b g k", g=self.ngroups, k=k) + # A: (nheads,) -> (ngroups, k) + A_r = rearrange(A, "(g k) -> g k", g=self.ngroups, k=k) + # dA: (B, ngroups, k) + dA_r = torch.exp(dt_r * A_r.unsqueeze(0)) + + # B: (B, ngroups * d_state) -> (B, ngroups, d_state) + B_r = rearrange(B, "b (g n) -> b g n", g=self.ngroups) + # C: (B, ngroups * d_state) -> (B, ngroups, d_state) + C_r = rearrange(C, "b (g n) -> b g n", g=self.ngroups) + # ssm_state: (B, nheads, headdim, d_state) -> (B, ngroups, k, headdim, d_state) + ssm_state_r = rearrange(ssm_state, "b (g k) p n -> b g k p n", g=self.ngroups, k=k) + + # SSM recurrence: h_new = dA * h_old + dB * x + # dB = dt * B + # dB_scaled_by_dt: (B, ngroups, k, d_state) + dB_scaled_by_dt = torch.einsum("bgk,bgn->bgkn", dt_r, B_r) + # dBx: (B, ngroups, k, headdim, d_state) + dBx = torch.einsum("bgkp,bgkn->bgkpn", x_r, dB_scaled_by_dt) + + ssm_state_new_r = dA_r.unsqueeze(-1).unsqueeze(-1) * ssm_state_r + dBx + ssm_state.copy_(rearrange(ssm_state_new_r, "b g k p n -> b (g k) p n")) + + # Output: y = C * h_new + D * x + # y_interim: (B, ngroups, k, headdim) + y_interim = torch.einsum("bgkpn,bgn->bgkp", ssm_state_new_r.to(dtype), C_r) + + D_param = self.D.to(dtype) + if self.D_has_hdim: # D is (d_ssm) = (nheads * headdim) + D_r = rearrange(D_param, "(g k p) -> g k p", g=self.ngroups, k=k, p=self.headdim) + y_r = y_interim + D_r.unsqueeze(0) * x_r + else: # D is (nheads) + D_r = rearrange(D_param, "(g k) -> g k", g=self.ngroups, k=k) + y_r = y_interim + D_r.unsqueeze(0).unsqueeze(-1) * x_r + + y = rearrange(y_r, "b g k p -> b (g k p)") # (B, d_ssm) + if not self.rmsnorm: y = y * self.act(z) # (B D) else: @@ -376,8 +431,26 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) else: conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? 
-            if initialize_states:
+            # Re-allocate the cached states if the batch size changed between generations
+            if conv_state.shape[0] != batch_size or ssm_state.shape[0] != batch_size:
+                # Cached shapes no longer match: allocate fresh zero states for the new batch size
+                conv_state = torch.zeros(
+                    batch_size,
+                    self.conv1d.weight.shape[0],  # out_channels
+                    self.d_conv,  # kernel_size
+                    device=self.conv1d.weight.device,
+                    dtype=self.conv1d.weight.dtype,
+                )
+                ssm_state = torch.zeros(
+                    batch_size,
+                    self.nheads,
+                    self.headdim,
+                    self.d_state,
+                    device=self.in_proj.weight.device,
+                    dtype=self.in_proj.weight.dtype,
+                )
+                inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
+            elif initialize_states:  # Same batch size: just reset the existing buffers in place
                 conv_state.zero_()
                 ssm_state.zero_()
         return conv_state, ssm_state
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index a41f1359..162ca3b7 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -173,7 +173,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
         if y.is_complex():
             y = y.real * 2
         ys.append(y)
-    y = torch.stack(ys, dim=2) # (batch dim L)
+    y = torch.stack(ys, dim=2) # (batch, dim, L)
     out = y if D is None else y + u * rearrange(D, "d -> d 1")
     if z is not None:
         out = out * F.silu(z)
@@ -385,7 +385,8 @@ def mamba_inner_ref(
     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
     out_proj_weight, out_proj_bias,
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True
+    C_proj_bias=None, delta_softplus=True,
+    b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-6
 ):
     assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
L = xz.shape[-1] @@ -399,21 +400,39 @@ def mamba_inner_ref( x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() delta = rearrange(delta, "d (b l) -> b d l", l=L) + + if dt_rms_weight is not None: + delta_reshaped = rearrange(delta, "b d l -> (b l) d").contiguous() + delta_reshaped = rms_norm_forward(delta_reshaped, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps) + delta = rearrange(delta_reshaped, "(b l) d -> b d l", l=L).contiguous() + if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): - B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + B = rearrange(B, "(b l) dstate -> b dstate l", l=L) + else: + B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2) + if b_rms_weight is not None: + B_reshaped = rearrange(B, "b dstate l -> (b l) dstate").contiguous() + B_reshaped = rms_norm_forward(B_reshaped, b_rms_weight, bias=None, eps=b_c_dt_rms_eps) + B = rearrange(B_reshaped, "(b l) dstate -> b dstate l", l=L).contiguous() else: - B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + B = B.contiguous() # Ensure contiguity if not already handled by RMSNorm path if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): - C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=L) + else: + C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2) + if c_rms_weight is not None: + C_reshaped = rearrange(C, "b dstate l -> (b l) dstate").contiguous() + C_reshaped = rms_norm_forward(C_reshaped, c_rms_weight, bias=None, eps=b_c_dt_rms_eps) + C = rearrange(C_reshaped, "(b l) dstate -> b dstate l", l=L).contiguous() else: - C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + C = C.contiguous() # Ensure contiguity if not already handled by RMSNorm path y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) diff --git a/mamba_ssm/utils/hf.py b/mamba_ssm/utils/hf.py index 0d7555ac..778bd185 100644 --- a/mamba_ssm/utils/hf.py +++ b/mamba_ssm/utils/hf.py @@ -15,7 +15,7 @@ def load_state_dict_hf(model_name, device=None, dtype=None): # If not fp32, then we don't want to load directly to the GPU mapped_device = "cpu" if dtype not in [torch.float32, None] else device resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) - return torch.load(resolved_archive_file, map_location=mapped_device) + state_dict = torch.load(resolved_archive_file, map_location=mapped_device) # Convert dtype before moving to GPU to save memory if dtype is not None: state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} diff --git a/setup.py b/setup.py index f61ca90d..26bd83f3 100755 --- a/setup.py +++ b/setup.py @@ -99,27 +99,28 @@ def get_torch_hip_version(): return None -def check_if_hip_home_none(global_option: str) -> None: - - if HIP_HOME is not None: - return - # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary - # in that case. 
+def check_if_hip_home_none(global_option: str):
+    if HIP_HOME is None:
+        raise RuntimeError(
+            f"{global_option} was requested, but the ROCm/HIP installation is incomplete. "
+            'Please make sure ROCm is properly installed and the HIP_HOME environment variable is set.\n'
+            'On Ubuntu, you may need to install: rocm-libs hipcc hiprt hipcub rocprim rocrand rocthrust rocblas hipblas rocsolver hipsparse rocsparse hipfft rocfft'
+        )
-    warnings.warn(
-        f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
-    )


 def check_if_cuda_home_none(global_option: str) -> None:
-    if CUDA_HOME is not None:
-        return
-    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
-    # in that case.
-    warnings.warn(
-        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
-        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
-        "only images whose names contain 'devel' will provide nvcc."
-    )
+    if CUDA_HOME is None:
+        raise RuntimeError(
+            f"{global_option} was requested, but CUDA installation was not found. "
+            'Please ensure CUDA is properly installed and the CUDA_HOME environment variable is set.\n'
+            'Common solutions include:\n'
+            '1. Install CUDA from NVIDIA: https://developer.nvidia.com/cuda-downloads\n'
+            '2. Set CUDA_HOME to your CUDA installation directory (e.g., /usr/local/cuda-11.8)\n'
+            '3. Add CUDA to your PATH: export PATH=$PATH:$CUDA_HOME/bin'
+        )


 def append_nvcc_threads(nvcc_extra_args):
@@ -158,8 +159,6 @@ def append_nvcc_threads(nvcc_extra_args):
                     UserWarning
                 )

-        cc_flag.append("-DBUILD_PYTHON_PACKAGE")
-
     else:
         check_if_cuda_home_none(PACKAGE_NAME)
         # Check, if CUDA11 is installed for compute capability 8.0
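
Note on the grouped SSM step fallback in mamba_ssm/modules/mamba2.py: the rewritten PyTorch path in Mamba2.step shares each group's B and C across its heads while keeping dt, A and D per head. Below is a minimal standalone sketch of that single-token recurrence, not the module's code; the shapes, the random tensors, and the names B_h, C_h and heads_per_group are illustrative assumptions.

import torch

torch.manual_seed(0)
batch, nheads, headdim, d_state, ngroups = 2, 8, 16, 32, 2
heads_per_group = nheads // ngroups

x = torch.randn(batch, nheads, headdim)            # per-head input for one token
dt = torch.rand(batch, nheads)                     # stands in for softplus(dt + dt_bias)
A = -torch.rand(nheads)                            # negative real A, one scalar per head
B = torch.randn(batch, ngroups, d_state)           # grouped input projection
C = torch.randn(batch, ngroups, d_state)           # grouped output projection
D = torch.randn(nheads)                            # skip connection
ssm_state = torch.zeros(batch, nheads, headdim, d_state)

# Head i belongs to group i // heads_per_group, matching the "(g k)" rearranges in the diff.
B_h = B.repeat_interleave(heads_per_group, dim=1)  # (batch, nheads, d_state)
C_h = C.repeat_interleave(heads_per_group, dim=1)

dA = torch.exp(dt * A)                                   # (batch, nheads)
dBx = torch.einsum("bh,bhn,bhp->bhpn", dt, B_h, x)       # discretized input term
ssm_state = dA[..., None, None] * ssm_state + dBx        # state update
y = torch.einsum("bhpn,bhn->bhp", ssm_state, C_h) + D[:, None] * x
print(y.shape)  # torch.Size([2, 8, 16])

Broadcasting B and C with repeat_interleave reproduces the "(g k)" head ordering used by the rearrange calls above, where heads of the same group sit next to each other.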
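
Note on _get_states_from_cache: with the change above, the cached conv/ssm states are re-allocated only when the cached batch size no longer matches the request, while initialize_states with an unchanged batch size just zeroes the existing buffers. A toy sketch of that rule with a hypothetical one-tensor cache (names and shapes are illustrative, not the module's):

import torch

def get_state(cache: dict, batch_size: int, dim: int, initialize: bool = False) -> torch.Tensor:
    state = cache.get("state")
    if state is None or state.shape[0] != batch_size:
        # Batch size changed (or first call): allocate a fresh zero state.
        state = torch.zeros(batch_size, dim)
        cache["state"] = state
    elif initialize:
        # Same batch size: reset the existing buffer in place.
        state.zero_()
    return state

cache = {}
print(get_state(cache, batch_size=2, dim=4).shape)                   # torch.Size([2, 4])
print(get_state(cache, batch_size=3, dim=4).shape)                   # reallocated: torch.Size([3, 4])
print(get_state(cache, batch_size=3, dim=4, initialize=True).shape)  # zeroed in place: torch.Size([3, 4])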
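
Note on mamba_ssm/utils/hf.py: for non-fp32 targets the loader keeps the checkpoint on CPU, casts it there, and only then moves it to the device, so the full-precision copy never occupies GPU memory. A simplified sketch of that order of operations (the path, dtype and device are placeholders, and the real load_state_dict_hf handles the fp32/None case differently):

import torch

def load_checkpoint_cpu_first(path: str, dtype: torch.dtype, device: str) -> dict:
    # Load to CPU, cast there, then move; the fp32 copy never touches the GPU.
    state_dict = torch.load(path, map_location="cpu")
    state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    return {k: v.to(device=device) for k, v in state_dict.items()}

# state_dict = load_checkpoint_cpu_first("pytorch_model.bin", torch.bfloat16, "cuda")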