99#include < c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
1010#include < ATen/cuda/Atomic.cuh> // For atomicAdd on complex
1111
12- #include < cub/block/block_load.cuh>
13- #include < cub/block/block_store.cuh>
14- #include < cub/block/block_scan.cuh>
15- #include < cub/block/block_reduce.cuh>
12+ #ifndef USE_ROCM
13+ #include < cub/block/block_load.cuh>
14+ #include < cub/block/block_store.cuh>
15+ #include < cub/block/block_scan.cuh>
16+ #include < cub/block/block_reduce.cuh>
17+ #else
18+ #include < hipcub/hipcub.hpp>
19+ namespace cub = hipcub;
20+ #endif
1621
1722#include " selective_scan.h"
1823#include " selective_scan_common.h"
@@ -33,7 +38,7 @@ struct Selective_Scan_bwd_kernel_traits {
3338 static constexpr int kNItems = kNItems_ ;
3439 static constexpr int kNBytes = sizeof (input_t );
3540 static_assert (kNBytes == 2 || kNBytes == 4 );
36- static constexpr int kNElts = kNBytes == 4 ? 4 : std::min (8 , kNItems );
41+ static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min (8 , kNItems );
3742 static_assert (kNItems % kNElts == 0 );
3843 static constexpr int kNLoads = kNItems / kNElts ;
3944 static constexpr bool kIsComplex = std::is_same_v<weight_t , complex_t >;
@@ -61,12 +66,13 @@ struct Selective_Scan_bwd_kernel_traits {
6166 using BlockReduceFloatT = cub::BlockReduce<float , kNThreads >;
6267 using BlockReduceComplexT = cub::BlockReduce<complex_t , kNThreads >;
6368 using BlockExchangeT = cub::BlockExchange<float , kNThreads , !kIsComplex ? kNItems : kNItems * 2 >;
64- static constexpr int kSmemIOSize = std::max({sizeof (typename BlockLoadT::TempStorage),
65- sizeof (typename BlockLoadVecT::TempStorage),
66- (int (kIsVariableB ) + int (kIsVariableC )) * sizeof (typename BlockLoadWeightT::TempStorage),
67- (int (kIsVariableB ) + int (kIsVariableC )) * sizeof (typename BlockLoadWeightVecT::TempStorage),
68- sizeof (typename BlockStoreT::TempStorage),
69- sizeof (typename BlockStoreVecT::TempStorage)});
69+
70+ static constexpr int kSmemIOSize = custom_max({sizeof (typename BlockLoadT::TempStorage),
71+ sizeof (typename BlockLoadVecT::TempStorage),
72+ (int (kIsVariableB ) + int (kIsVariableC )) * sizeof (typename BlockLoadWeightT::TempStorage),
73+ (int (kIsVariableB ) + int (kIsVariableC )) * sizeof (typename BlockLoadWeightVecT::TempStorage),
74+ sizeof (typename BlockStoreT::TempStorage),
75+ sizeof (typename BlockStoreVecT::TempStorage)});
7076 static constexpr int kSmemExchangeSize = (int (kIsVariableB ) + int (kIsVariableC )) * sizeof (typename BlockExchangeT::TempStorage);
7177 static constexpr int kSmemReduceSize = sizeof (typename BlockReduceT::TempStorage);
7278 static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof (typename BlockScanT::TempStorage) + sizeof (typename BlockReverseScanT::TempStorage);
@@ -263,12 +269,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
263269 // Initialize running total
264270 scan_t running_prefix = chunk > 0 && threadIdx .x % 32 == 0 ? x[(chunk - 1 ) * params.dstate + state_idx] : make_float2 (1 .f , 0 .f );
265271 SSMScanPrefixCallbackOp<weight_t > prefix_op (running_prefix);
266- Ktraits::BlockScanT (smem_scan).InclusiveScan (
272+ typename Ktraits::BlockScanT (smem_scan).InclusiveScan (
267273 thread_data, thread_data, SSMScanOp<weight_t >(), prefix_op
268274 );
269275 scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx .x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2 (1 .f , 0 .f );
270276 SSMScanPrefixCallbackOp<weight_t > postfix_op (running_postfix);
271- Ktraits::BlockReverseScanT (smem_reverse_scan).InclusiveReverseScan (
277+ typename Ktraits::BlockReverseScanT (smem_reverse_scan).InclusiveReverseScan (
272278 thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t >(), postfix_op
273279 );
274280 if (threadIdx .x == 0 ) { smem_running_postfix[state_idx] = postfix_op.running_prefix ; }
@@ -297,11 +303,11 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
297303 // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
298304 if constexpr (kIsVariableB || kIsVariableC ) {
299305 if constexpr (kIsVariableB ) {
300- Ktraits::BlockExchangeT (smem_exchange).BlockedToStriped (dB_vals, dB_vals);
306+ typename Ktraits::BlockExchangeT (smem_exchange).BlockedToStriped (dB_vals, dB_vals);
301307 }
302308 if constexpr (kIsVariableC ) {
303309 auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
304- Ktraits::BlockExchangeT (smem_exchange_C).BlockedToStriped (dC_vals, dC_vals);
310+ typename Ktraits::BlockExchangeT (smem_exchange_C).BlockedToStriped (dC_vals, dC_vals);
305311 }
306312 const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx .x ;
307313 weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx .x ;
@@ -316,13 +322,13 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
316322 }
317323 if constexpr (!kIsVariableB || !kIsVariableC ) {
318324 float2 dA_dBC_val = make_float2 (dA_val, dBC_val);
319- dA_dBC_val = Ktraits::BlockReduceT (smem_reduce).Sum (dA_dBC_val);
325+ dA_dBC_val = typename Ktraits::BlockReduceT (smem_reduce).Sum (dA_dBC_val);
320326 dA_val = dA_dBC_val.x ;
321327 if (threadIdx .x == 0 ) {
322328 smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
323329 }
324330 } else {
325- dA_val = Ktraits::BlockReduceFloatT (smem_reduce_float).Sum (dA_val);
331+ dA_val = typename Ktraits::BlockReduceFloatT (smem_reduce_float).Sum (dA_val);
326332 }
327333 if (threadIdx .x == 0 ) {
328334 smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
@@ -356,12 +362,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
356362 // Initialize running total
357363 scan_t running_prefix = chunk > 0 && threadIdx .x % 32 == 0 ? x[(chunk - 1 ) * params.dstate + state_idx] : make_float4 (1 .f , 0 .f , 0 .f , 0 .f );
358364 SSMScanPrefixCallbackOp<weight_t > prefix_op (running_prefix);
359- Ktraits::BlockScanT (smem_scan).InclusiveScan (
365+ typename Ktraits::BlockScanT (smem_scan).InclusiveScan (
360366 thread_data, thread_data, SSMScanOp<weight_t >(), prefix_op
361367 );
362368 scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx .x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4 (1 .f , 0 .f , 0 .f , 0 .f );
363369 SSMScanPrefixCallbackOp<weight_t > postfix_op (running_postfix);
364- Ktraits::BlockReverseScanT (smem_reverse_scan).InclusiveReverseScan (
370+ typename Ktraits::BlockReverseScanT (smem_reverse_scan).InclusiveReverseScan (
365371 thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t >(), postfix_op
366372 );
367373 if (threadIdx .x == 0 ) { smem_running_postfix[state_idx] = postfix_op.running_prefix ; }
@@ -397,7 +403,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
397403 dB_vals_f[i * 2 ] = dB_vals[i].real_ ;
398404 dB_vals_f[i * 2 + 1 ] = dB_vals[i].imag_ ;
399405 }
400- Ktraits::BlockExchangeT (smem_exchange).BlockedToStriped (dB_vals_f, dB_vals_f);
406+ typename Ktraits::BlockExchangeT (smem_exchange).BlockedToStriped (dB_vals_f, dB_vals_f);
401407 }
402408 if constexpr (kIsVariableC ) {
403409 #pragma unroll
@@ -406,7 +412,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
406412 dC_vals_f[i * 2 + 1 ] = dC_vals[i].imag_ ;
407413 }
408414 auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
409- Ktraits::BlockExchangeT (smem_exchange_C).BlockedToStriped (dC_vals_f, dC_vals_f);
415+ typename Ktraits::BlockExchangeT (smem_exchange_C).BlockedToStriped (dC_vals_f, dC_vals_f);
410416 }
411417 const int seqlen_remaining = (params.seqlen - chunk * kChunkSize ) * 2 - threadIdx .x ;
412418 float *dB_cur = reinterpret_cast <float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx .x ;
@@ -421,14 +427,14 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
421427 }
422428 if constexpr (!kIsVariableB || !kIsVariableC ) {
423429 float4 dA_dBC_val = make_float4 (dA_val.real_ , dA_val.imag_ , dBC_val.real_ , dBC_val.imag_ );
424- dA_dBC_val = Ktraits::BlockReduceT (smem_reduce).Sum (dA_dBC_val);
430+ dA_dBC_val = typename Ktraits::BlockReduceT (smem_reduce).Sum (dA_dBC_val);
425431 dA_val = complex_t (dA_dBC_val.x , dA_dBC_val.y );
426432 dBC_val = complex_t (dA_dBC_val.z , dA_dBC_val.w );
427433 if (threadIdx .x == 0 ) {
428434 smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
429435 }
430436 } else {
431- dA_val = Ktraits::BlockReduceComplexT (smem_reduce_complex).Sum (dA_val);
437+ dA_val = typename Ktraits::BlockReduceComplexT (smem_reduce_complex).Sum (dA_val);
432438 }
433439 if (threadIdx .x == 0 ) {
434440 smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
@@ -465,12 +471,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
465471 Cvar -= kChunkSize * (!kIsComplex ? 1 : 2 );
466472 }
467473 if (params.dD_ptr != nullptr ) {
468- dD_val = Ktraits::BlockReduceFloatT (smem_reduce_float).Sum (dD_val);
474+ dD_val = typename Ktraits::BlockReduceFloatT (smem_reduce_float).Sum (dD_val);
469475 if (threadIdx .x == 0 ) { gpuAtomicAdd (dD, dD_val); }
470476 }
471477 if (params.ddelta_bias_ptr != nullptr ) {
472478 __syncthreads ();
473- ddelta_bias_val = Ktraits::BlockReduceFloatT (smem_reduce_float).Sum (ddelta_bias_val);
479+ ddelta_bias_val = typename Ktraits::BlockReduceFloatT (smem_reduce_float).Sum (ddelta_bias_val);
474480 if (threadIdx .x == 0 ) { gpuAtomicAdd (ddelta_bias, ddelta_bias_val); }
475481 }
476482 for (int state_idx = threadIdx .x ; state_idx < params.dstate ; state_idx += blockDim .x ) {
@@ -499,13 +505,24 @@ void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) {
499505 // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
500506 // TODO: check this
501507 constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof (typename Ktraits::scan_t ) + (kNThreads + 4 * MAX_DSTATE) * sizeof (typename Ktraits::weight_t );
502- // printf("smem_size = %d\n", kSmemSize);
508+
503509 dim3 grid (params.batch , params.dim );
510+
504511 auto kernel = &selective_scan_bwd_kernel<Ktraits>;
512+
505513 if (kSmemSize >= 48 * 1024 ) {
514+
515+ #ifndef USE_ROCM
506516 C10_CUDA_CHECK (cudaFuncSetAttribute (
507517 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize ));
518+ #else
519+ C10_CUDA_CHECK (cudaFuncSetAttribute (
520+ (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize ));
521+ std::cerr << " Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n " << std::endl;
522+ #endif
523+
508524 }
525+
509526 kernel<<<grid, Ktraits::kNThreads , kSmemSize , stream>>> (params);
510527 C10_CUDA_KERNEL_LAUNCH_CHECK ();
511528 });
@@ -517,15 +534,37 @@ void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) {
517534
518535template <typename input_t , typename weight_t >
519536void selective_scan_bwd_cuda (SSMParamsBwd ¶ms, cudaStream_t stream) {
520- if (params.seqlen <= 128 ) {
521- selective_scan_bwd_launch<32 , 4 , input_t , weight_t >(params, stream);
522- } else if (params.seqlen <= 256 ) {
523- selective_scan_bwd_launch<32 , 8 , input_t , weight_t >(params, stream);
524- } else if (params.seqlen <= 512 ) {
525- selective_scan_bwd_launch<32 , 16 , input_t , weight_t >(params, stream);
526- } else if (params.seqlen <= 1024 ) {
527- selective_scan_bwd_launch<64 , 16 , input_t , weight_t >(params, stream);
528- } else {
529- selective_scan_bwd_launch<128 , 16 , input_t , weight_t >(params, stream);
537+
538+ #ifndef USE_ROCM
539+ constexpr int warp_size = 32 ;
540+ #else
541+ constexpr int warp_size = rocprim::warp_size ();
542+ #endif
543+
544+ if (warp_size == 32 ) {
545+ if (params.seqlen <= 128 ) {
546+ selective_scan_bwd_launch<64 , 4 , input_t , weight_t >(params, stream);
547+ } else if (params.seqlen <= 256 ) {
548+ selective_scan_bwd_launch<64 , 8 , input_t , weight_t >(params, stream);
549+ } else if (params.seqlen <= 512 ) {
550+ selective_scan_bwd_launch<64 , 16 , input_t , weight_t >(params, stream);
551+ } else if (params.seqlen <= 1024 ) {
552+ selective_scan_bwd_launch<64 , 16 , input_t , weight_t >(params, stream);
553+ } else {
554+ selective_scan_bwd_launch<128 , 16 , input_t , weight_t >(params, stream);
555+ }
556+ }
557+ #ifdef USE_ROCM
558+ else {
559+ if (params.seqlen <= 256 ) {
560+ selective_scan_bwd_launch<64 , 4 , input_t , weight_t >(params, stream);
561+ } else if (params.seqlen <= 512 ) {
562+ selective_scan_bwd_launch<64 , 8 , input_t , weight_t >(params, stream);
563+ } else if (params.seqlen <= 1024 ) {
564+ selective_scan_bwd_launch<64 , 16 , input_t , weight_t >(params, stream);
565+ } else {
566+ selective_scan_bwd_launch<128 , 16 , input_t , weight_t >(params, stream);
567+ }
530568 }
569+ #endif
531570}
0 commit comments