@@ -112,7 +112,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
112112 weight_t *C = reinterpret_cast <weight_t *>(params.C_ptr ) + dim_id * kNRows * params.C_d_stride ;
113113 input_t *Cvar = reinterpret_cast <input_t *>(params.C_ptr ) + batch_id * params.C_batch_stride + group_id * params.C_group_stride ;
114114 scan_t *x = reinterpret_cast <scan_t *>(params.x_ptr ) + (batch_id * params.dim + dim_id * kNRows ) * params.n_chunks * params.dstate ;
115- long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride ;
115+
116+ // Load cu_seqlens into shared memory
117+ const int cu_seqlens_size = params.cu_seqlens_size ;
118+ long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr );
119+ __shared__ long smem_cu_seqlens[1024 ]; // Adjust size as needed
120+ for (int i = threadIdx .x ; i < cu_seqlens_size; i += blockDim .x ) {
121+ smem_cu_seqlens[i] = cu_seqlens[i];
122+ }
123+ __syncthreads ();
124+
116125
117126 float D_val[kNRows ] = {0 };
118127 if (params.D_ptr != nullptr ) {
@@ -224,15 +233,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
224233
225234 // Reset A bar for cumulative sequences (Real)
226235 int left = 1 ;
227- int right = params.cu_seqlens_size - 2 ;
236+ int right = cu_seqlens_size - 2 ;
237+ int idx = threadIdx .x * kNItems + i + chunk * kChunkSize ;
228238 while (left <= right) {
229- if (cu_seqlens[(left + right) >> 1 ] == threadIdx .x * kNItems + i + chunk * kChunkSize ) {
239+ int mid = (left + right) >> 1 ;
240+ if (smem_cu_seqlens[mid] == idx) {
230241 thread_data[i].x = 0 .f ;
231242 break ;
232- } else if (cu_seqlens[(left + right) >> 1 ] < threadIdx . x * kNItems + i + chunk * kChunkSize ) {
233- left = ((left + right) >> 1 ) + 1 ;
243+ } else if (smem_cu_seqlens[mid ] < idx ) {
244+ left = mid + 1 ;
234245 } else {
235- right = ((left + right) >> 1 ) - 1 ;
246+ right = mid - 1 ;
236247 }
237248 }
238249
@@ -249,19 +260,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
249260
250261 // Reset A bar for cumulative sequences (Complex)
251262 int left = 1 ;
252- int right = params.cu_seqlens_size - 2 ;
263+ int right = cu_seqlens_size - 2 ;
264+ int idx = threadIdx .x * kNItems + i + chunk * kChunkSize ;
253265 while (left <= right) {
254- if (cu_seqlens[(left + right) >> 1 ] == threadIdx .x * kNItems + i + chunk * kChunkSize ) {
266+ int mid = (left + right) >> 1 ;
267+ if (smem_cu_seqlens[mid] == idx) {
255268 thread_data[i].x = 0 .f ;
256269 thread_data[i].y = 0 .f ;
257270 break ;
258- } else if (cu_seqlens[(left + right) >> 1 ] < threadIdx . x * kNItems + i + chunk * kChunkSize ) {
259- left = ((left + right) >> 1 ) + 1 ;
271+ } else if (smem_cu_seqlens[mid ] < idx ) {
272+ left = mid + 1 ;
260273 } else {
261- right = ((left + right) >> 1 ) - 1 ;
274+ right = mid - 1 ;
262275 }
263276 }
264277
278+
265279 if constexpr (!Ktraits::kIsEvenLen ) { // So that the last state is correct
266280 if (threadIdx .x * kNItems + i >= params.seqlen - chunk * kChunkSize ) {
267281 thread_data[i] = make_float4 (1 .f , 0 .f , 0 .f , 0 .f );
0 commit comments