Skip to content

Commit e61f105

Browse files
committed
Change launch warp size back for CUDA
1 parent bed7e5d commit e61f105

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

csrc/selective_scan/selective_scan_bwd_kernel.cuh

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -543,11 +543,11 @@ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
543 543

544 544
if (warp_size == 32) {
545 545
if (params.seqlen <= 128) {
546-
selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
546+
selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
547 547
} else if (params.seqlen <= 256) {
548-
selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream);
548+
selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
549 549
} else if (params.seqlen <= 512) {
550-
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
550+
selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
551 551
} else if (params.seqlen <= 1024) {
552 552
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
553 553
} else {

csrc/selective_scan/selective_scan_fwd_kernel.cuh

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -358,11 +358,11 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
358 358

359 359
if (warp_size == 32) {
360 360
if (params.seqlen <= 128) {
361-
selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
361+
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
362 362
} else if (params.seqlen <= 256) {
363-
selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
363+
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
364 364
} else if (params.seqlen <= 512) {
365-
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
365+
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
366 366
} else if (params.seqlen <= 1024) {
367 367
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
368 368
} else {

0 commit comments

Comments
 (0)