
TypeError: custom_fwd() takes from 0 to 1 positional arguments but 21 positional arguments (and 1 keyword-only argument) were given #609

@saurabh-kataria

Description


I am unable to run the sample Mamba2 code. Even the following simple snippet fails during the forward pass.

import torch
from mamba_ssm import Mamba2
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2).to('cuda')
y = model(x)
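
For reference, a quick environment check along these lines (torch.__version__ and torch.version.cuda are standard attributes; the getattr fallback is only there in case this mamba_ssm build does not expose __version__):

import torch
import mamba_ssm

# Print the versions that matter for a torch / mamba_ssm mismatch.
print("torch:", torch.__version__)
print("torch CUDA:", torch.version.cuda)
print("mamba_ssm:", getattr(mamba_ssm, "__version__", "unknown"))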

---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[5], line 5
3 x = torch.randn(batch, length, dim).to("cuda")
4 model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2).to('cuda')
----> 5 y = model(x)

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py:185, in Mamba2.forward(self, u, seqlen, seq_idx, cu_seqlens, inference_params)
183 dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
184 if self.use_mem_eff_path and inference_params is None:
--> 185 out = mamba_split_conv1d_scan_combined(
186 zxbcdt,
187 rearrange(self.conv1d.weight, "d 1 w -> d w"),
188 self.conv1d.bias,
189 self.dt_bias,
190 A,
191 D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
192 chunk_size=self.chunk_size,
193 seq_idx=seq_idx,
194 activation=self.activation,
195 rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
196 rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
197 outproj_weight=self.out_proj.weight,
198 outproj_bias=self.out_proj.bias,
199 headdim=None if self.D_has_hdim else self.headdim,
200 ngroups=self.ngroups,
201 norm_before_gate=self.norm_before_gate,
202 **dt_limit_kwargs,
203 )
204 if seqlen_og is not None:
205 out = rearrange(out, "b l d -> (b l) d")

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py:930, in mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
911 def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
912 """
913 Argument:
914 zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
(...)
928 out: (batch, seqlen, dim)
929 """
--> 930 return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 "In order to use an autograd.Function with functorch transforms "
580 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
581 "staticmethod. For more details, please see "
582 "https://pytorch.org/docs/main/notes/extending.func.html"
583 )

TypeError: custom_fwd() takes from 0 to 1 positional arguments but 21 positional arguments (and 1 keyword-only argument) were given
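
For what it's worth, the failing call matches the PyTorch 2.4+ signature torch.amp.custom_fwd(fwd=None, *, device_type, cast_inputs=None): it accepts at most one positional argument plus keyword-only options, and the old torch.cuda.amp.custom_fwd is deprecated in its favor. If the decorator never actually wraps forward (so custom_fwd itself ends up registered as the Function's forward), autograd.Function.apply pushes all of the forward tensors straight into custom_fwd, which would produce exactly this kind of TypeError. A minimal sketch of the new decorator usage on a toy autograd.Function (my own example, not mamba_ssm's code; requires a CUDA device, like the repro above):

import torch
from torch.amp import custom_fwd, custom_bwd  # moved here in PyTorch >= 2.4

class Square(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # device_type is a required keyword-only argument in the new API
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(4, device="cuda", requires_grad=True)
Square.apply(x).sum().backward()

So this looks like a torch / mamba_ssm version mismatch around that custom_fwd API change rather than a problem with the snippet itself.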
