@@ -6,8 +6,13 @@
 
 from einops import rearrange, repeat
 
-from causal_conv1d import causal_conv1d_fn
-import causal_conv1d_cuda
+try:
+    from causal_conv1d import causal_conv1d_fn
+    import causal_conv1d_cuda
+except ImportError:
+    causal_conv1d_fn = None
+    causal_conv1d_cuda = None
+
 import selective_scan_cuda
 
 
@@ -168,6 +173,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
168173 """
169174 xz: (batch, dim, seqlen)
170175 """
176+ assert causal_conv1d_cuda is not None , "causal_conv1d_cuda is not available. Please install causal-conv1d."
171177 assert checkpoint_lvl in [0 , 1 ]
172178 L = xz .shape [- 1 ]
173179 delta_rank = delta_proj_weight .shape [1 ]
@@ -196,7 +202,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
             assert x.shape[2] == (d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2]
 
         conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
 
         # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
         if cu_seqlens is not None:
@@ -262,6 +268,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
     @custom_bwd
     def backward(ctx, dout):
         # dout: (batch, seqlen, dim)
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
          conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens) = ctx.saved_tensors
         L = xz.shape[-1]
@@ -285,7 +292,7 @@ def backward(ctx, dout):
             x = padded_x
             assert x.shape[2] == (d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2]
 
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
 
         # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
         if cu_seqlens is not None:
@@ -345,8 +352,8 @@ def backward(ctx, dout):
         dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
         # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
+        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
         )
         dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
         dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
@@ -374,11 +381,12 @@ def mamba_inner_ref(
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
     C_proj_bias=None, delta_softplus=True
 ):
+    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
     L = xz.shape[-1]
     delta_rank = delta_proj_weight.shape[1]
     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
     x, z = xz.chunk(2, dim=1)
-    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
+    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
     # We're being very careful here about the layout, to avoid extra transposes.
     # We want delta to have d as the slowest moving dimension
     # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
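For context, the pattern this commit applies is: import causal-conv1d behind a try/except so the module still loads without it, then assert availability at each entry point that actually needs the fused kernels, and pass the extra None/False arguments expected by newer causal-conv1d releases. Below is a minimal standalone sketch of that guard pattern; it is not part of the patch, and the helper name fused_conv_available is hypothetical. Only calls that appear in the diff above are used.

import torch
from einops import rearrange

# Optional dependency: mirror the guarded import added at the top of the module.
try:
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None


def fused_conv_available() -> bool:
    # Hypothetical helper: report whether the optional fused kernels can be used.
    return causal_conv1d_fn is not None


def run_causal_conv(x, conv1d_weight, conv1d_bias):
    # Fail with an actionable message, like the asserts added in this commit,
    # instead of a NameError when the package is missing.
    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
    # Newer causal-conv1d takes the activation as a keyword argument.
    return causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")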