@@ -113,15 +113,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
     scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
 
-    // Load cu_seqlens into shared memory
     const int cu_seqlens_size = params.cu_seqlens_size;
-    long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
-    __shared__ long smem_cu_seqlens[1024]; // Adjust size as needed
-    for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) {
-        smem_cu_seqlens[i] = cu_seqlens[i];
-    }
-    __syncthreads();
-
+    const long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
 
     float D_val[kNRows] = {0};
     if (params.D_ptr != nullptr) {
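The hunk above replaces the shared-memory staging of `cu_seqlens` (which was capped at a fixed 1024 entries and needed a block-wide `__syncthreads()`) with a plain `const` pointer into global memory; the hunks below switch the binary search over to that pointer. As a point of reference only, here is the same boundary test factored into a standalone device helper. This is a hypothetical sketch, not part of the commit: the `is_interior_boundary` name is made up, and it assumes the usual convention that `cu_seqlens[0] == 0` and the last entry equals the total number of packed tokens, so only the interior entries `1 .. cu_seqlens_size - 2` mark positions where the scan state must be reset.

```cuda
// Hypothetical sketch (not from the commit): binary search over cu_seqlens,
// read directly from global memory, to decide whether token index `idx`
// starts a new sequence inside the packed batch.
__device__ __forceinline__ bool is_interior_boundary(const long *cu_seqlens,
                                                     int cu_seqlens_size,
                                                     int idx) {
    int left = 1;                        // skip cu_seqlens[0] == 0
    int right = cu_seqlens_size - 2;     // skip the final total-length entry
    while (left <= right) {
        int mid = (left + right) >> 1;
        const long boundary = cu_seqlens[mid];  // read-only global load
        if (boundary == idx) {
            return true;                 // idx starts a new sequence: reset the carry
        } else if (boundary < idx) {
            left = mid + 1;
        } else {
            right = mid - 1;
        }
    }
    return false;
}
```

Reads of a small, read-only array like this are generally served well by the GPU caches, and dropping the staging buffer removes the hard 1024-entry cap on `cu_seqlens`.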
@@ -237,10 +230,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                 int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
                 while (left <= right) {
                     int mid = (left + right) >> 1;
-                    if (smem_cu_seqlens[mid] == idx) {
+                    if (cu_seqlens[mid] == idx) {
                         thread_data[i].x = 0.f;
                         break;
-                    } else if (smem_cu_seqlens[mid] < idx) {
+                    } else if (cu_seqlens[mid] < idx) {
                         left = mid + 1;
                     } else {
                         right = mid - 1;
@@ -264,11 +257,11 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                 int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
                 while (left <= right) {
                     int mid = (left + right) >> 1;
-                    if (smem_cu_seqlens[mid] == idx) {
+                    if (cu_seqlens[mid] == idx) {
                         thread_data[i].x = 0.f;
                         thread_data[i].y = 0.f;
                         break;
-                    } else if (smem_cu_seqlens[mid] < idx) {
+                    } else if (cu_seqlens[mid] < idx) {
                         left = mid + 1;
                     } else {
                         right = mid - 1;
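For context, `cu_seqlens` follows the usual varlen packing convention: it is the prefix sum of the individual sequence lengths in the flattened batch. A minimal host-side sketch of how such an array is typically built (hypothetical, not part of this commit; the `build_cu_seqlens` name is made up):

```cpp
// Hypothetical host-side sketch: cu_seqlens is the prefix sum of the
// per-sequence lengths, prepended with 0, so the last entry is the total
// number of packed tokens.
#include <vector>

std::vector<long> build_cu_seqlens(const std::vector<long> &seqlens) {
    std::vector<long> cu(seqlens.size() + 1, 0);
    for (size_t i = 0; i < seqlens.size(); ++i) {
        cu[i + 1] = cu[i] + seqlens[i];
    }
    return cu;
}

// Example: lengths {3, 5, 2} give cu_seqlens {0, 3, 8, 10}; the kernel resets
// the scan carry at the interior boundaries 3 and 8 so state from one
// sequence never leaks into the next.
```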