 import torch
-import torch.nn.functional as F
 from einops import rearrange
 from typing import Optional, Tuple
 
-import selective_scan_cuda
+from mamba_ssm.ops.selective_scan_interface import selective_scan_cuda
 
 
 @torch.library.custom_op(
     "custom_ops::selective_scan_fwd",
     device_types=["cuda"],
     mutates_args=(),
+    schema="(Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? delta_bias, bool delta_softplus, bool return_last_state) -> (Tensor, Tensor, Tensor, Tensor, bool, bool, bool)",
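+    # The explicit schema spells out the optional inputs (Tensor? D/z/delta_bias)
+    # and the mixed Tensor/bool outputs, which the stub below no longer annotates.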
 )
 def custom_selective_scan_fwd(
     u: torch.Tensor,
@@ -22,28 +22,33 @@ def custom_selective_scan_fwd(
     delta_bias: Optional[torch.Tensor],
     delta_softplus: bool,
     return_last_state: bool,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool, bool]:
+):
     pass
 
+
 @torch.library.register_fake("custom_ops::selective_scan_fwd")
 def custom_selective_scan_fwd_fake(
-    u,
-    delta,
-    A,
-    B,
-    C,
-    D,
-    z,
-    delta_bias,
-    delta_softplus,
-    return_last_state,
+    u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
 ):
-    final_out = torch.empty_like(u)
     dstate = A.size(1) * (2 if A.is_complex() else 1)
-    last_state_fake = u.new_empty((u.size(0), u.size(1), dstate)) if return_last_state else u.new_empty(0)
-    out_fake = torch.empty_like(u)
-    x_fake = u.new_empty((u.size(0), u.size(1), u.size(2), 2 * dstate))
-    return final_out, last_state_fake, out_fake, x_fake, False, False, z is not None
+    seqlen = u.size(2)
+    n_chunks = (seqlen + 2048 - 1) // 2048
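+    # The CUDA kernel scans the sequence in chunks of 2048 steps and saves
+    # intermediate states per chunk (2048 mirrors that kernel constant).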
+
+    squeeze_B = B.dim() == 3
+    squeeze_C = C.dim() == 3
+    has_z = z is not None
+
+    final_out = torch.empty_like(delta)
+    out_fake = torch.empty_like(delta)
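+    # out/final_out share u's (batch, dim, seqlen) shape; delta is used as the
+    # template, presumably to match the dtype the real kernel allocates.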
+    last_state_fake = (
+        u.new_empty((u.size(0), u.size(1), dstate))
+        if return_last_state
+        else u.new_empty(0)
+    )
+    x_fake = u.new_empty((u.size(0), u.size(1), n_chunks, 2 * A.size(1)), dtype=A.dtype)
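+    # x stores two saved values per state dim per chunk (hence 2 * A.size(1))
+    # and inherits A's dtype, which is complex when A is complex.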
+
+    return final_out, last_state_fake, out_fake, x_fake, squeeze_B, squeeze_C, has_z
+
 
 @torch.library.register_kernel("custom_ops::selective_scan_fwd", "cuda")
 def custom_selective_scan_fwd_cuda(
@@ -81,16 +86,23 @@ def custom_selective_scan_fwd_cuda(
         C = rearrange(C, "b dstate l -> b 1 dstate l").contiguous()
         squeeze_C = True
 
-    out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
+    out, x, *rest = selective_scan_cuda.fwd(
+        u, delta, A, B, C, D, z, delta_bias, delta_softplus
+    )
     has_z = z is not None
-    final_out = rest[0].clone() if has_z else out.clone()
+    if has_z:
+        final_out = rest[0].clone()
+    else:
+        final_out = out.clone()
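+    # rest[0] is out_z, the output with the z gating already applied;
+    # x[:, :, -1, 1::2] below slices the final chunk's running state.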
     last_state = x[:, :, -1, 1::2].clone() if return_last_state else u.new_empty(0)
     return final_out, last_state, out, x, squeeze_B, squeeze_C, has_z
 
+
 @torch.library.custom_op(
     "custom_ops::selective_scan_bwd",
     device_types=["cuda"],
     mutates_args=(),
+    schema="(Tensor dout, Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? delta_bias, bool delta_softplus, Tensor out, Tensor x, bool squeeze_B, bool squeeze_C, bool recompute_out_z) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, Tensor?)",
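+    # The Tensor? outputs let the op return None for gradients of the optional
+    # inputs (dD, ddelta_bias, dz) instead of placeholder empty tensors.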
 )
 def custom_selective_scan_bwd(
     dout: torch.Tensor,
@@ -107,9 +119,11 @@ def custom_selective_scan_bwd(
     x: torch.Tensor,
     squeeze_B: bool,
     squeeze_C: bool,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    recompute_out_z: bool,
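+    # recompute_out_z is forwarded to the CUDA kernel, which can rebuild out_z
+    # during backward instead of requiring it to be saved in forward.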
+):
     pass
 
+
 @torch.library.register_fake("custom_ops::selective_scan_bwd")
 def custom_selective_scan_bwd_fake(
     dout,
@@ -126,16 +140,33 @@ def custom_selective_scan_bwd_fake(
     x,
     squeeze_B,
     squeeze_C,
+    recompute_out_z,
 ):
+    # Here we just return shape-compatible fake tensors
    du = torch.empty_like(u)
     ddelta = torch.empty_like(delta)
     dA = torch.empty_like(A)
-    dB = torch.empty_like(B)
-    dC = torch.empty_like(C)
-    dD = torch.empty_like(D) if (D is not None and D.numel() > 0) else u.new_empty(0)
-    dz = torch.empty_like(z) if (z is not None and z.numel() > 0) else u.new_empty(0)
-    ddelta_bias = torch.empty_like(delta_bias) if (delta_bias is not None and delta_bias.numel() > 0) else u.new_empty(0)
-    return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias
+
+    # B/C are "variable" (input-dependent) when they carry batch and group
+    # dimensions, i.e. when they are 4-D.
+    is_variable_B = B.dim() > 3
+    is_variable_C = C.dim() > 3
+
+    # For variable B/C the real kernel returns float32 gradients; matching the
+    # input dtype here is still acceptable for a fake kernel.
+    dB = torch.empty_like(B, dtype=B.dtype)
+    dC = torch.empty_like(C, dtype=C.dtype)
+
+    dD = torch.empty_like(D) if (D is not None) else None
+    ddelta_bias_out = torch.empty_like(delta_bias) if (delta_bias is not None) else None
+    dz = torch.empty_like(z) if (z is not None) else None
+
+    if squeeze_B and dB.numel() > 0:
+        dB = dB.squeeze(1)
+    if squeeze_C and dC.numel() > 0:
+        dC = dC.squeeze(1)
+
+    return du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz
+
 
 @torch.library.register_kernel("custom_ops::selective_scan_bwd", "cuda")
 def custom_selective_scan_bwd_cuda(
@@ -153,68 +184,101 @@ def custom_selective_scan_bwd_cuda(
     x: torch.Tensor,
     squeeze_B: bool,
     squeeze_C: bool,
+    recompute_out_z: bool,
 ):
     if dout.stride(-1) != 1:
         dout = dout.contiguous()
-    B = B.contiguous()
-    C = C.contiguous()
 
     results = selective_scan_cuda.bwd(
-        u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, delta_softplus, False
+        u,
+        delta,
+        A,
+        B,
+        C,
+        D,
+        z,
+        delta_bias,
+        dout,
+        x,
+        out,
+        None,
+        delta_softplus,
+        recompute_out_z,
     )
+
     has_z = z is not None
     if has_z:
-        du, ddelta, dA, dB, dC, dD, ddelta_bias, dz = results
+        du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz = results
     else:
-        du, ddelta, dA, dB, dC, dD, ddelta_bias = results
-        dz = u.new_empty(0)
+        du, ddelta, dA, dB, dC, dD, ddelta_bias_out = results
+        dz = None
 
     if squeeze_B and dB.numel() > 0:
         dB = dB.squeeze(1)
     if squeeze_C and dC.numel() > 0:
         dC = dC.squeeze(1)
 
-    return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias
+    return du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz
+
 
 def custom_bridge(ctx, *grads):
     dout = grads[0] if grads else ctx.saved_tensors[0].new_empty(0)
     saved = ctx.saved_tensors
+
     if not ctx.has_z:
         u, delta, A, B, C, D, delta_bias, x, out = saved
         z = None
     else:
         u, delta, A, B, C, D, z, delta_bias, x, out = saved
 
-    du, ddelta, dA, dB, dC, dD, dz, ddelta_bias = torch.ops.custom_ops.selective_scan_bwd(
-        dout,
-        u,
-        delta,
-        A,
-        B,
-        C,
-        D,
-        z,
-        delta_bias,
-        ctx.delta_softplus,
-        out,
-        x,
-        ctx.squeeze_B,
-        ctx.squeeze_C
+    du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz = (
+        torch.ops.custom_ops.selective_scan_bwd(
+            dout,
+            u,
+            delta,
+            A,
+            B,
+            C,
+            D,
+            z,
+            delta_bias,
+            ctx.delta_softplus,
+            out,
+            x,
+            ctx.squeeze_B,
+            ctx.squeeze_C,
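+            # recompute_out_z=False: the recomputed out_z is not needed to
+            # compute gradients here.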
+            False,
+        )
     )
 
+    # For optional inputs, return None if not provided in forward
+    if D is None:
+        dD = None
+    if z is None:
+        dz = None
+    if delta_bias is None:
+        ddelta_bias_out = None
+
+    # Return gradients in the order of forward inputs:
+    # (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+    # `delta_softplus` and `return_last_state` are bools -> gradient = None
+    d_delta_softplus = None
+    d_return_last_state = None
+
     return (
         du,
         ddelta,
         dA,
         dB,
         dC,
-        dD if D is not None else None,
-        dz if z is not None else None,
-        ddelta_bias if delta_bias is not None else None,
-        None,
-        None,
+        dD,
+        dz,
+        ddelta_bias_out,
+        d_delta_softplus,
+        d_return_last_state,
     )
 
+
 def custom_setup_context(ctx, inputs, output):
     (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) = inputs
     (final_out, last_state, out, x, squeeze_B, squeeze_C, has_z) = output
@@ -236,10 +300,12 @@ def custom_setup_context(ctx, inputs, output):
     else:
         ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
 
+
 torch.library.register_autograd(
     "custom_ops::selective_scan_fwd", custom_bridge, setup_context=custom_setup_context
 )
 
+
 def selective_scan_fn_custom_op(
     u: torch.Tensor,
     delta: torch.Tensor,
@@ -252,20 +318,9 @@ def selective_scan_fn_custom_op(
     delta_softplus: bool,
     return_last_state: bool,
 ) -> torch.Tensor:
-    # Pass all arguments positionally, exactly in schema order:
     final_out, last_state, _, _, _, _, _ = torch.ops.custom_ops.selective_scan_fwd(
-        u,
-        delta,
-        A,
-        B,
-        C,
-        D,
-        z,
-        delta_bias,
-        delta_softplus,
-        return_last_state
+        u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
     )
-
     if return_last_state:
         return final_out, last_state
     else:
         return final_out
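
A minimal smoke test for the wrapper (a sketch, not part of the diff: the module name custom_scan, the shapes, and the compile settings are assumptions, and it needs a CUDA build of selective_scan_cuda):

    import torch
    from custom_scan import selective_scan_fn_custom_op  # hypothetical module name

    b, d, l, n = 2, 16, 4096, 8  # batch, channels, seqlen, state size
    u = torch.randn(b, d, l, device="cuda", requires_grad=True)
    delta = torch.randn(b, d, l, device="cuda", requires_grad=True)
    A = -torch.rand(d, n, device="cuda", requires_grad=True)
    B = torch.randn(b, n, l, device="cuda", requires_grad=True)  # 3-D -> squeeze_B path
    C = torch.randn(b, n, l, device="cuda", requires_grad=True)  # 3-D -> squeeze_C path
    D = torch.randn(d, device="cuda", requires_grad=True)

    # The schema and fake kernels registered above are what let this trace
    # through torch.compile without a graph break at the CUDA op.
    fn = torch.compile(selective_scan_fn_custom_op)
    out, last_state = fn(u, delta, A, B, C, D, None, None, True, True)
    out.sum().backward()  # routes through custom_bridge -> selective_scan_bwd

With return_last_state=True the wrapper returns (final_out, last_state); passing z=None and delta_bias=None exercises the optional-input branches of both the forward kernel and the backward bridge.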