Skip to content

Commit cfc08ca

Browse files
apakbin authored and pytorchmergebot committed
[ROCm] NLLLoss (torch.nll_loss) Performance Tuning by Dynamically Selecting # of GPU threads (pytorch#149548)
Instead of fixing the number of GPU threads to 32 regardless of input size, this PR dynamically selects the number of threads based on the formula: clamp(2^round(log2(dim0/16)), min = 32, max = 1024). The experiments below were done on an MI300 machine for data type float32: ![nll_loss_threads_bests](https://github.com/user-attachments/assets/3be3d465-e3db-44ed-991a-fdfcab03baae) ![nll_loss_heauristic](https://github.com/user-attachments/assets/e82b9788-9b4d-4862-a180-8df7ad298182) Pull Request resolved: pytorch#149548 Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony
1 parent 0ed3421 commit cfc08ca

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

aten/src/ATen/native/cuda/Loss.cu

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,13 @@ Tensor& binary_cross_entropy_backward_out_cuda(const Tensor& grad, const Tensor&
145145
// -----------------------------------
146146
namespace {
147147

148-
constexpr int NLL_LOSS_THREADS = 32;
148+
// Pick the block size for the single-block NLL-loss reduction kernels.
//
// On ROCm, scale the thread count with the batch dimension following the
// tuned heuristic clamp(2^round(log2(nframe / 16)), 32, 1024); on CUDA,
// keep the historical fixed 32 threads.
//
// nframe: size of dim 0 of the input (number of rows being reduced).
// Returns a power of two in [32, 1024] (ROCm) or exactly 32 (CUDA).
int nll_loss_threads(int64_t nframe){
#if defined(USE_ROCM)
  // Guard nframe < 16: integer division would produce 0, std::log2(0) is
  // -inf, and shifting by the cast of -inf is undefined behavior. Flooring
  // the operand at 1 keeps the result at the 32-thread minimum.
  const double frames_per_thread =
      static_cast<double>(std::max<int64_t>(nframe / 16, 1));
  // Clamp the exponent itself to [5, 10] (i.e. 2^5 = 32 .. 2^10 = 1024) so
  // the shift is always well-defined even for huge nframe, instead of
  // clamping an already-overflowed shift result.
  const int shift = std::clamp(
      static_cast<int>(std::round(std::log2(frames_per_thread))), 5, 10);
  return 1 << shift;
#else
  return 32;
#endif
}
149155

150156
// NOTE(crcrpar): `Byte` support was added for https://github.com/pytorch/pytorch/issues/59765.
151157
#define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...) \
@@ -231,12 +237,13 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
231237
int64_t n_classes,
232238
int64_t ignore_index) {
233239
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
234-
__shared__ accscalar_t sh_inputs[NLL_LOSS_THREADS],
235-
acc_weight[NLL_LOSS_THREADS];
240+
extern __shared__ unsigned char shmem[];
241+
accscalar_t* sh_inputs = reinterpret_cast<accscalar_t*>(shmem);
242+
accscalar_t* acc_weight = reinterpret_cast<accscalar_t*>(shmem + blockDim.x * sizeof(accscalar_t));
236243

237244
sh_inputs[threadIdx.x] = static_cast<accscalar_t>(0);
238245
acc_weight[threadIdx.x] = static_cast<accscalar_t>(0);
239-
for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
246+
for (int i = threadIdx.x; i < nframe; i += blockDim.x) {
240247
index_t t = target[i];
241248
if (t != ignore_index) {
242249
CHECK_INDEX_IN_CLASS(t, n_classes);
@@ -252,7 +259,7 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
252259
if (threadIdx.x == 0) {
253260
accscalar_t output_acc = 0;
254261
accscalar_t total_weight_acc = 0;
255-
for (int i = 0; i < NLL_LOSS_THREADS; ++i) {
262+
for (int i = 0; i < blockDim.x; ++i) {
256263
output_acc += sh_inputs[i];
257264
total_weight_acc += acc_weight[i];
258265
}
@@ -374,10 +381,11 @@ void nll_loss_forward_out_cuda_template(
374381
"nll_loss_forward_reduce_cuda_kernel_2d_index",
375382
[&] {
376383
using accscalar_t = at::acc_type<scalar_t, /*is_cuda*/true>;
384+
int nthreads = nll_loss_threads(input.size(0));
377385
nll_loss_forward_reduce_cuda_kernel_2d<scalar_t, accscalar_t, index_t>
378386
<<<1,
379-
NLL_LOSS_THREADS,
380-
0,
387+
nthreads,
388+
nthreads * sizeof(accscalar_t) * 2,
381389
at::cuda::getCurrentCUDAStream()>>>(
382390
output.mutable_data_ptr<scalar_t>(),
383391
total_weight.mutable_data_ptr<scalar_t>(),
@@ -456,7 +464,7 @@ __global__ void nll_loss_backward_reduce_cuda_kernel_2d(
456464
const auto grad = -(size_average ? *grad_output / *total_weight
457465
: *grad_output);
458466
459-
for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
467+
for (int i = threadIdx.x; i < nframe; i += blockDim.x) {
460468
const index_t t = target[i];
461469
if (t != ignore_index) {
462470
CHECK_INDEX_IN_CLASS(t, n_classes);
@@ -560,7 +568,7 @@ void nll_loss_backward_out_cuda_template(
560568
"nll_loss_backward_reduce_cuda_kernel_2d_index",
561569
[&] {
562570
nll_loss_backward_reduce_cuda_kernel_2d<scalar_t, index_t>
563-
<<<1, NLL_LOSS_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
571+
<<<1, nll_loss_threads(input.size(0)), 0, at::cuda::getCurrentCUDAStream()>>>(
564572
grad_input.mutable_data_ptr<scalar_t>(),
565573
grad_output.const_data_ptr<scalar_t>(),
566574
target.const_data_ptr<index_t>(),

0 commit comments

Comments (0)