@@ -79,7 +79,9 @@ void set_ssm_params_fwd(SSMParamsBase &params,
                         void *delta_bias_ptr,
                         void *x_ptr,
                         bool has_z,
-                        bool delta_softplus) {
+                        bool delta_softplus,
+                        void *cu_seqlens_ptr,
+                        const int cu_seqlens_size) {
 
     // Reset the parameters
     memset(&params, 0, sizeof(params));
@@ -109,6 +111,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
     params.x_ptr = x_ptr;
     params.z_ptr = has_z ? z.data_ptr() : nullptr;
     params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
+
+    params.cu_seqlens_ptr = cu_seqlens_ptr;
+    params.cu_seqlens_size = cu_seqlens_size;
+
     // All stride are in elements, not bytes.
     params.A_d_stride = A.stride(0);
     params.A_dstate_stride = A.stride(1);
@@ -173,15 +179,17 @@ void set_ssm_params_bwd(SSMParamsBwd &params,
                         void *ddelta_bias_ptr,
                         bool has_z,
                         bool delta_softplus,
-                        bool recompute_out_z) {
+                        bool recompute_out_z,
+                        void *cu_seqlens_ptr,
+                        const int cu_seqlens_size) {
     // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
     set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                        u, delta, A, B, C, has_z ? out : dout,
                        has_z ? z : dout,
                        // If not recompute_out_z, pass dout instead of out_z.
                        // This won't be used by the bwd kernel
                        recompute_out_z ? out_z : dout,
-                       D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
+                       D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, cu_seqlens_ptr, cu_seqlens_size);
     if (!recompute_out_z) { params.out_z_ptr = nullptr; }
 
     // Set the pointers and strides.
@@ -229,7 +237,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
                    const c10::optional<at::Tensor> &D_,
                    const c10::optional<at::Tensor> &z_,
                    const c10::optional<at::Tensor> &delta_bias_,
-                   bool delta_softplus) {
+                   bool delta_softplus,
+                   const c10::optional<at::Tensor> &cu_seqlens_) {
     auto input_type = u.scalar_type();
     auto weight_type = A.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -319,7 +328,9 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
                        delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
                        x.data_ptr(),
                        has_z,
-                       delta_softplus);
+                       delta_softplus,
+                       cu_seqlens_.has_value() ? cu_seqlens_.value().data_ptr() : nullptr,
+                       cu_seqlens_.has_value() ? cu_seqlens_.value().size(0) : 0);
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
@@ -346,7 +357,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
                    const c10::optional<at::Tensor> &out_,
                    c10::optional<at::Tensor> &dz_,
                    bool delta_softplus,
-                   bool recompute_out_z) {
+                   bool recompute_out_z,
+                   const c10::optional<at::Tensor> &cu_seqlens_) {
     auto input_type = u.scalar_type();
     auto weight_type = A.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -474,7 +486,9 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
                        dout, du, ddelta, dA, dB, dC, dz,
                        D_.has_value() ? dD.data_ptr() : nullptr,
                        delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
-                       has_z, delta_softplus, recompute_out_z);
+                       has_z, delta_softplus, recompute_out_z,
+                       cu_seqlens_.has_value() ? cu_seqlens_.value().data_ptr() : nullptr,
+                       cu_seqlens_.has_value() ? cu_seqlens_.value().size(0) : 0);
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
0 commit comments