|
24 | 24 | except ImportError: |
25 | 25 | causal_conv1d_fn, causal_conv1d_cuda = None, None |
26 | 26 |
|
27 | | -from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd |
28 | | -from src.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd |
29 | | -from src.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db |
30 | | -from src.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable |
31 | | -from src.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref |
32 | | -from src.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd |
33 | | -from src.ops.triton.ssd_state_passing import state_passing, state_passing_ref |
34 | | -from src.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates |
35 | | -from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb |
36 | | -from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable |
37 | | -from src.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref |
38 | | -from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev |
39 | | -from src.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd |
40 | | -from src.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd |
| 27 | +from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd |
| 28 | +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd |
| 29 | +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db |
| 30 | +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable |
| 31 | +from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref |
| 32 | +from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd |
| 33 | +from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref |
| 34 | +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates |
| 35 | +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb |
| 36 | +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable |
| 37 | +from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref |
| 38 | +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev |
| 39 | +from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd |
| 40 | +from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd |
41 | 41 |
|
42 | 42 | TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') |
43 | 43 |
|
@@ -651,7 +651,7 @@ def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus |
651 | 651 | Return: |
652 | 652 | out: (batch, seqlen, nheads, headdim) |
653 | 653 | """ |
654 | | - from src.ops.selective_scan_interface import selective_scan_fn |
| 654 | + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn |
655 | 655 |
|
656 | 656 | batch, seqlen, nheads, headdim = x.shape |
657 | 657 | _, _, ngroups, dstate = B.shape |
|
0 commit comments