@@ -145,7 +145,13 @@ Tensor& binary_cross_entropy_backward_out_cuda(const Tensor& grad, const Tensor&
 // -----------------------------------
 namespace {
 
-constexpr int NLL_LOSS_THREADS = 32;
+int nll_loss_threads(int64_t nframe) {
+#if defined(USE_ROCM)
+  return std::clamp(1 << static_cast<int64_t>(std::round(std::log2(nframe / 16))), 32, 1024);
+#else
+  return 32;
+#endif
+}
 
 // NOTE(crcrpar): `Byte` support was added for https://github.com/pytorch/pytorch/issues/59765.
 #define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...) \
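
On ROCm, the new `nll_loss_threads` helper scales the block size with the batch dimension: it rounds `nframe / 16` to the nearest power of two and clamps the result to [32, 1024], while CUDA keeps the previous fixed 32 threads. A minimal host-side sketch of the heuristic (standalone and illustrative, mirroring the diff rather than quoting the merged file; assumes C++17 for `std::clamp`):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Nearest power of two of nframe/16, clamped between one warp/wavefront (32)
// and the usual per-block hardware limit (1024).
int nll_loss_threads(int64_t nframe) {
  return std::clamp(
      1 << static_cast<int64_t>(std::round(std::log2(nframe / 16))), 32, 1024);
}

int main() {
  for (int64_t nframe : {64LL, 512LL, 4096LL, 65536LL, 1LL << 20}) {
    std::printf("nframe=%-8lld threads=%d\n",
                static_cast<long long>(nframe), nll_loss_threads(nframe));
  }
}
```

With these sample sizes, 64 and 512 rows stay at the warp-sized minimum of 32, 4096 rows map to 256 threads, and 65536 or more saturate at 1024.
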
@@ -231,12 +237,13 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
     int64_t n_classes,
     int64_t ignore_index) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  __shared__ accscalar_t sh_inputs[NLL_LOSS_THREADS],
-      acc_weight[NLL_LOSS_THREADS];
+  extern __shared__ unsigned char shmem[];
+  accscalar_t* sh_inputs = reinterpret_cast<accscalar_t*>(shmem);
+  accscalar_t* acc_weight = reinterpret_cast<accscalar_t*>(shmem + blockDim.x * sizeof(accscalar_t));
 
   sh_inputs[threadIdx.x] = static_cast<accscalar_t>(0);
   acc_weight[threadIdx.x] = static_cast<accscalar_t>(0);
-  for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
+  for (int i = threadIdx.x; i < nframe; i += blockDim.x) {
     index_t t = target[i];
     if (t != ignore_index) {
       CHECK_INDEX_IN_CLASS(t, n_classes);
@@ -252,7 +259,7 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
   if (threadIdx.x == 0) {
     accscalar_t output_acc = 0;
     accscalar_t total_weight_acc = 0;
-    for (int i = 0; i < NLL_LOSS_THREADS; ++i) {
+    for (int i = 0; i < blockDim.x; ++i) {
       output_acc += sh_inputs[i];
       total_weight_acc += acc_weight[i];
     }
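
Because the thread count is no longer a compile-time constant, the two statically sized `__shared__` arrays become one dynamically sized `extern __shared__` byte buffer split in two. A self-contained sketch of that pattern (hypothetical kernel and names, not the PyTorch source), including the `__syncthreads()` that sits between the two hunks above:

```cuda
// One extern __shared__ byte buffer, carved into two blockDim.x-sized
// accumulator arrays; the layout must match the bytes requested at launch.
template <typename acc_t>
__global__ void block_reduce_two_arrays(const acc_t* in, acc_t* out, int n) {
  extern __shared__ unsigned char shmem[];
  acc_t* sums = reinterpret_cast<acc_t*>(shmem);
  acc_t* weights = reinterpret_cast<acc_t*>(shmem + blockDim.x * sizeof(acc_t));

  sums[threadIdx.x] = 0;
  weights[threadIdx.x] = 0;
  // Block-stride loop: bounding by blockDim.x adapts to any launch size.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    sums[threadIdx.x] += in[i];
    weights[threadIdx.x] += 1;
  }
  __syncthreads();  // all partial sums visible before thread 0 reduces them
  if (threadIdx.x == 0) {
    acc_t total = 0, weight = 0;
    for (int i = 0; i < blockDim.x; ++i) {
      total += sums[i];
      weight += weights[i];
    }
    out[0] = total;
    out[1] = weight;
  }
}
```

The two pointers alias disjoint halves of the same allocation, so the launch must request `2 * blockDim.x * sizeof(acc_t)` bytes; alignment is safe because the second array starts at a multiple of `sizeof(acc_t)`.
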
@@ -374,10 +381,11 @@ void nll_loss_forward_out_cuda_template(
374381 " nll_loss_forward_reduce_cuda_kernel_2d_index" ,
375382 [&] {
376383 using accscalar_t = at::acc_type<scalar_t , /* is_cuda*/ true >;
384+ int nthreads = nll_loss_threads (input.size (0 ));
377385 nll_loss_forward_reduce_cuda_kernel_2d<scalar_t , accscalar_t , index_t >
378386 <<<1 ,
379- NLL_LOSS_THREADS ,
380- 0 ,
387+ nthreads ,
388+ nthreads * sizeof ( accscalar_t ) * 2 ,
381389 at::cuda::getCurrentCUDAStream ()>>>(
382390 output.mutable_data_ptr<scalar_t >(),
383391 total_weight.mutable_data_ptr<scalar_t>(),
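
The forward launch therefore passes the buffer size as the third `<<<>>>` argument, matching the two arrays the kernel carves out. A hedged usage sketch against the hypothetical kernel above:

```cuda
// The dynamic shared-memory request (third launch parameter) must equal what
// the kernel partitions, just as the diff computes
// nthreads * sizeof(accscalar_t) * 2 for the real kernel.
int threads = nll_loss_threads(n);                 // 32 on CUDA; 32..1024 on ROCm
size_t shmem_bytes = 2 * threads * sizeof(float);  // sums + weights
block_reduce_two_arrays<float>
    <<<1, threads, shmem_bytes, 0>>>(d_in, d_out, n);  // default stream for brevity
```
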
@@ -456,7 +464,7 @@ __global__ void nll_loss_backward_reduce_cuda_kernel_2d(
   const auto grad = -(size_average ? *grad_output / *total_weight
                                    : *grad_output);
 
-  for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
+  for (int i = threadIdx.x; i < nframe; i += blockDim.x) {
     const index_t t = target[i];
     if (t != ignore_index) {
       CHECK_INDEX_IN_CLASS(t, n_classes);
@@ -560,7 +568,7 @@ void nll_loss_backward_out_cuda_template(
560568 " nll_loss_backward_reduce_cuda_kernel_2d_index" ,
561569 [&] {
562570 nll_loss_backward_reduce_cuda_kernel_2d<scalar_t , index_t >
563- <<<1 , NLL_LOSS_THREADS , 0 , at::cuda::getCurrentCUDAStream()>>> (
571+ <<<1 , nll_loss_threads(input.size( 0 )) , 0 , at::cuda::getCurrentCUDAStream()>>> (
564572 grad_input.mutable_data_ptr <scalar_t >(),
565573 grad_output.const_data_ptr <scalar_t >(),
566574 target.const_data_ptr <index_t >(),
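
The backward launch keeps a zero dynamic-shared-memory argument, since that kernel only runs the block-stride scatter and touches no `__shared__` state; only the block size changes. The clamp's 1024 upper bound lines up with `maxThreadsPerBlock` on current NVIDIA and AMD parts, which a caller could sanity-check at runtime (a sketch, not part of the PR):

```cuda
#include <cuda_runtime.h>
#include <cassert>

// Verify the heuristic never exceeds the device's per-block thread limit.
void check_block_limit(int threads) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, /*device=*/0);
  assert(threads >= 32 && threads <= prop.maxThreadsPerBlock);
}
```
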