
Commit ae5cc19

ahmadsharif1 authored and pytorchmergebot committed
[pytorch][cuda] Improve softmax backward pass native CUDA implementation (pytorch#145866)
This PR is similar to pytorch#122970, but works on the softmax backward pass. Specifically, it uses shared memory to cache `gradOutput` when it fits in shared memory; before this PR we were reading `gradOutput` twice. On my H100 this improves the softmax backward pass performance by about 5% for problem sizes that fit within shared memory. (Note that this is not the only kernel that runs when you call the softmax backward pass -- there is an elementwise kernel that runs before it; optimizing that can be a separate PR.)

**Important note**: Currently the softmax backward pass consists of an [element-wise multiply operator](https://github.com/pytorch/pytorch/blob/7f65a208848205b38445423b7e2e93a2b4994e5e/aten/src/ATen/native/cuda/SoftMax.cu#L1216), followed by [this function](https://github.com/pytorch/pytorch/blob/7f65a208848205b38445423b7e2e93a2b4994e5e/aten/src/ATen/native/cuda/SoftMax.cu#L1062), which calls the `cunn_SoftMaxBackward` kernel. With my change the kernel time is reduced by about 12% (see screenshot below), while the total time (including the elementwise kernel) is reduced by about 5%.

```
    Baseline                                       This PR
    N size  FP32 bandwidth  FP16 bandwidth         N size  FP32 bandwidth  FP16 bandwidth   fp32 diff  fp16 diff
 0     256      134.340966       70.042039    0       256      133.70146        70.342753     -0.48%      0.43%
 1     512      233.501185      129.945803    1       512      234.057145      132.933066      0.24%      2.30%
 2    1024      340.667966      229.280464    2      1024      338.833265      226.441699     -0.54%     -1.24%
 3    2048      379.643726      337.452058    3      2048      399.559017      338.432284      5.25%      0.29%
 4    4096      416.597537      383.625364    4      4096      428.252403      396.137506      2.80%      3.26%
 5    6000      431.198241      384.384384    5      6000      457.744577      406.06275       6.16%      5.64%
 6    8192      462.811252      427.292573    6      8192      474.791032      428.281563      2.59%      0.23%
 7   10000      464.258731      429.050294    7     10000      483.7643        446.849381      4.20%      4.15%
 8   10013      465.199701      429.824179    8     10013      464.904407      428.72184      -0.06%     -0.26%
 9   10240      477.07359       428.853737    9     10240      485.317024      444.902586      1.73%      3.74%
10   11000      473.038785      430.778663   10     11000      488.161438      453.462162      3.20%      5.27%
11   12000      474.342475      432.594814   11     12000      490.532418      458.427653      3.41%      5.97%
12   16384      487.468854      473.611576   12     16384      488.154406      476.264631      0.14%      0.56%
13   20000      482.029793      465.666186   13     20000      482.147092      483.886193      0.02%      3.91%
14   24000      478.368093      474.159464   14     24000      478.364948      491.447921      0.00%      3.65%
15   32000      476.523796      473.18868    15     32000      476.523796      474.398962      0.00%      0.26%
16   32768      476.104723      477.493634   16     32768      476.704463      477.330606      0.13%     -0.03%
17   36864      477.900663      475.472787   17     36864      477.973279      475.728454      0.02%      0.05%
18   40960      477.707561      475.559064   18     40960      478.445017      476.088067      0.15%      0.11%
19   45056      479.169812      475.865134   19     45056      479.143266      475.878202     -0.01%      0.00%
20   49152      477.804907      475.382982   20     49152      477.868404      475.976377      0.01%      0.12%
21   65536      481.274125      478.171806   21     65536      481.537733      478.703926      0.05%      0.11%
22   66000      481.64652       480.095457   22     66000      481.856013      480.466388      0.04%      0.08%
23   68608      481.745774      479.034704   23     68608      481.917596      478.856209      0.04%     -0.04%
24   80000      483.409361      480.356529   24     80000      483.330481      480.375277     -0.02%      0.00%
25   98304      480.736301      481.396882   25     98304      480.789858      481.320143      0.01%     -0.02%
```

NCU profiler shows lower DRAM fetches with the new kernel:

![image](https://github.com/user-attachments/assets/f3606725-d8fc-4ea5-ae6d-9c188bf32d72)

NCU reports about a 12% elapsed-time reduction in this kernel alone compared to baseline (and, because of the other kernels that run, the overall backward pass time as seen by the user is reduced by about 5%).
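For readers who want to reproduce numbers of this kind, here is a minimal benchmark sketch. It is illustrative rather than the script that produced the table above: it assumes a CUDA device, calls `torch._softmax_backward_data` directly (which launches the elementwise multiply plus the backward kernel), and the `rows`/`iters` values and the 3x-tensor-size traffic estimate are assumptions.

```python
# Minimal, illustrative bandwidth probe (not the script that produced the table above).
# Assumes a CUDA device. torch._softmax_backward_data launches the elementwise
# multiply plus the cunn_SoftMaxBackward* kernel, so this times both together.
import torch

def softmax_backward_gbps(n, dtype, rows=8192, iters=50):
    x = torch.randn(rows, n, device="cuda", dtype=dtype)
    y = torch.softmax(x, dim=1)
    dy = torch.randn_like(y)

    for _ in range(5):  # warm-up
        torch._softmax_backward_data(dy, y, 1, y.dtype)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch._softmax_backward_data(dy, y, 1, y.dtype)
    end.record()
    torch.cuda.synchronize()

    seconds_per_call = start.elapsed_time(end) / 1e3 / iters
    # Rough traffic estimate: read output and gradOutput once each, write gradInput.
    bytes_moved = 3 * y.numel() * y.element_size()
    return bytes_moved / seconds_per_call / 1e9

for n in (256, 2048, 8192, 32768):
    print(n, round(softmax_backward_gbps(n, torch.float32), 1),
          round(softmax_backward_gbps(n, torch.half), 1))
```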
I compared the binary size increase by running `python setup.py develop` before and after and diffing the `.so` files:

![image](https://github.com/user-attachments/assets/8e6cee2e-3c7a-4fa4-8836-954047ce8ffc)

`libtorch_cuda.so` goes from 274,752,224 bytes to 274,787,072 bytes. The increase is about 34 kB, or roughly 0.01%.

I measured the compilation time for incremental development:

```
touch ./aten/src/ATen/native/cuda/SoftMax.cu
time python setup.py develop

real    0m10.083s
user    0m8.197s
sys     0m3.149s
```

Note that this uses `ccache` and does a bunch of copies, so it is not just measuring the `nvcc` time. I measured the `nvcc` time separately by capturing the `nvcc` command shown in [1] below and running it on the baseline and modified kernels:

```
# baseline nvcc time for SoftMax.cu
real    0m35.341s
user    0m33.801s
sys     0m1.289s

# this PR's nvcc time for SoftMax.cu
real    0m36.513s
user    0m34.722s
sys     0m1.408s
```

So the `nvcc` time increases by about 1 second, or ~3% of the baseline.

[1] The `nvcc` command is here:

```
# This is the nvcc command
/usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. -I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/torch/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/torch/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/SoftMax.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/SoftMax.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/SoftMax.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/SoftMax.cu.o
```

Pull Request resolved: pytorch#145866
Approved by: https://github.com/ngimel
1 parent 8c80c13 commit ae5cc19

2 files changed: +152 -15 lines changed


aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 93 additions & 14 deletions
```diff
@@ -877,6 +877,64 @@ cunn_SoftMaxBackward(scalar_t *gradInput, const outscalar_t *output, const outsc
   }
 }
 
+template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
+__global__ void
+cunn_SoftMaxBackwardSmem(scalar_t *gradInput, const outscalar_t *output, const outscalar_t *gradOutput, int64_t classes)
+{
+  // The first smem segment is used to cache input values and the last
+  // segment is used for thread block reductions
+  extern __shared__ unsigned char smem[];
+  auto smem_input_cache = reinterpret_cast<outscalar_t*>(smem);
+  auto smem_reduction_cache = reinterpret_cast<accscalar_t*>(smem +
+    classes * sizeof(outscalar_t));
+
+  gradInput += static_cast<int64_t>(blockIdx.x) * classes;
+  output += static_cast<int64_t>(blockIdx.x) * classes;
+  gradOutput += static_cast<int64_t>(blockIdx.x) * classes;
+
+  accscalar_t threadSum = 0;
+
+  using LoadT = at::native::memory::aligned_vector<outscalar_t, ILP>;
+  const LoadT* const gradOutput_vec_ptr = reinterpret_cast<const LoadT*>(gradOutput);
+  LoadT* const smem_gradOutput_cache_vec_ptr = reinterpret_cast<LoadT*>(smem_input_cache);
+
+  // Download inputs to shared memory while doing the first step
+  // in sum calculation
+  for (int32_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
+    LoadT crnt_vec = gradOutput_vec_ptr[offset];
+    smem_gradOutput_cache_vec_ptr[offset] = crnt_vec;
+
+    #pragma unroll
+    for (int i = 0; i < ILP; ++i) {
+      threadSum = threadSum + crnt_vec.val[i];
+    }
+  }
+
+  // We need a __syncthreads() here to be safe. However, blockReduceWarp's code
+  // calls a __syncthreads() before reading shared memory so we are safe.
+
+  accscalar_t sum_k = blockReduceWarp<Add, accscalar_t>(smem_reduction_cache, threadSum, Add<accscalar_t>(), accscalar_t(0));
+
+  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum_k);
+
+  // Use vectorized stores to save the output
+  using StoreT = at::native::memory::aligned_vector<scalar_t, ILP>;
+  StoreT* gradInput_vec_ptr = reinterpret_cast<StoreT*>(gradInput);
+  const LoadT* const output_vec_ptr = reinterpret_cast<const LoadT*>(output);
+  for (int32_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
+    LoadT crnt_vec = smem_gradOutput_cache_vec_ptr[offset];
+    LoadT crnt_out = output_vec_ptr[offset];
+    StoreT out_vec;
+
+    #pragma unroll
+    for (int i = 0; i < ILP; ++i) {
+      out_vec.val[i] = epilogue(crnt_vec.val[i], crnt_out.val[i]);
+    }
+
+    gradInput_vec_ptr[offset] = out_vec;
+  }
+}
+
 template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
 Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float, const Tensor& output){
   if (half_to_float) {
```
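End to end, the elementwise multiply that runs before this kernel plus the kernel above implement the standard softmax Jacobian, dx = y * (dy - sum(dy * y)) along the softmax dimension. A quick numerical cross-check against `torch._softmax_backward_data` (illustrative only, not part of the diff):

```python
# For y = softmax(x) and upstream gradient dy, the backward pass computes
# dx = y * (dy - sum(dy * y)) along the softmax dimension. Cross-check the CUDA
# op (elementwise multiply + kernel above) against that formula numerically.
import torch

torch.manual_seed(0)
y = torch.softmax(torch.randn(4, 2048, device="cuda", dtype=torch.float32), dim=1)
dy = torch.randn_like(y)

dx_op = torch._softmax_backward_data(dy, y, 1, y.dtype)
dx_ref = y * (dy - (dy * y).sum(dim=1, keepdim=True))
print(torch.allclose(dx_op, dx_ref, atol=1e-6))  # expected: True
```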
```diff
@@ -1058,6 +1116,39 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   return output;
 }
 
+template<typename input_t, typename output_t, typename accscalar_t, template<typename, typename, typename> class Epilogue>
+void dispatch_host_softmax_backward(int64_t dim_size, dim3 grid, Tensor &grad, Tensor &output, const Tensor &gI) {
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  constexpr int ILP = sizeof(float4) / sizeof(output_t);
+  dim3 block = SoftMax_getBlockSize(ILP, dim_size);
+
+  size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+  auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
+    smem_reduction_sz) / sizeof(output_t);
+  bool can_use_smem = dim_size < max_elements_per_smem;
+  can_use_smem &= (!(reinterpret_cast<const uintptr_t>(gI.const_data_ptr<input_t>()) % ALIGN_BYTES));
+  can_use_smem &= (!(reinterpret_cast<const uintptr_t>(output.const_data_ptr<output_t>()) % ALIGN_BYTES));
+  can_use_smem &= !(reinterpret_cast<const uintptr_t>(grad.const_data_ptr<output_t>()) % ALIGN_BYTES);
+  can_use_smem &= !(dim_size % ILP);
+  // This should not be needed on current generation GPUs because the size of shared memory is so low.
+  // But we add this check to be defensive and future-proof just in case shared memory size goes up
+  // to be so large as to requires 64-bits of addressing.
+  can_use_smem &= (dim_size < std::numeric_limits<int32_t>::max());
+
+  if (can_use_smem) {
+    size_t smem_sz = dim_size * sizeof(output_t) + smem_reduction_sz;
+    cunn_SoftMaxBackwardSmem<ILP, input_t, accscalar_t, output_t, Epilogue>
+      <<<grid, block, smem_sz, stream>>>(
+        gI.mutable_data_ptr<input_t>(), output.const_data_ptr<output_t>(), grad.const_data_ptr<output_t>(), dim_size);
+  } else {
+    cunn_SoftMaxBackward<ILP, input_t, accscalar_t, output_t, Epilogue>
+      <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
+        gI.mutable_data_ptr<input_t>(), output.const_data_ptr<output_t>(), grad.const_data_ptr<output_t>(), dim_size
+    );
+  }
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
 template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
 void host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t dim_, bool half_to_float, const Tensor &gI){
   int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
```
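To summarize the `can_use_smem` gate above: the shared-memory kernel is used only when the whole row fits in shared memory next to the reduction scratch space, all three tensors are 16-byte aligned, the row length is divisible by ILP, and the row length fits in 32 bits. Below is a small Python restatement of those conditions for readers following along; the helper name and the 48 KB shared-memory figure are assumptions, not PyTorch API.

```python
# Illustrative restatement of the can_use_smem gate in dispatch_host_softmax_backward.
# Neither the helper name nor the 48 KB figure is a PyTorch API; the real capacity comes
# from cudaDeviceProp::sharedMemPerBlock on the current device.
import torch

ALIGN_BYTES = 16            # alignment required for float4-wide vectorized access
SMEM_PER_BLOCK = 48 * 1024  # assumed per-block shared memory budget (device dependent)

def would_take_smem_path(grad_output, output, grad_input, dim_size):
    ilp = 16 // grad_output.element_size()   # sizeof(float4) / sizeof(dtype)
    smem_reduction = (1024 // 32) * 4        # upper bound: one float partial sum per warp
    max_elems = (SMEM_PER_BLOCK - smem_reduction) // grad_output.element_size()
    return (dim_size < max_elems
            and grad_input.data_ptr() % ALIGN_BYTES == 0
            and output.data_ptr() % ALIGN_BYTES == 0
            and grad_output.data_ptr() % ALIGN_BYTES == 0
            and dim_size % ilp == 0
            and dim_size < 2**31 - 1)

# A freshly allocated, contiguous fp16 row of length 2048 passes every check.
y = torch.softmax(torch.randn(2048, device="cuda", dtype=torch.float16), dim=0)
dy = torch.randn_like(y)
dx = torch.empty_like(y)
print(would_take_smem_path(dy, y, dx, y.numel()))  # expected: True on typical GPUs
```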
```diff
@@ -1099,13 +1190,7 @@ void host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t d
           remaining -= chunk_size;
         }
       } else {
-        constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
-        dim3 block = SoftMax_getBlockSize(ILP, dim_size);
-        cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
-          <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
-            gI.mutable_data_ptr<scalar_t>(), output.const_data_ptr<scalar_t>(), grad.const_data_ptr<scalar_t>(), dim_size
-        );
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        dispatch_host_softmax_backward<scalar_t, scalar_t, accscalar_t, Epilogue>(dim_size, grid, grad, output, gI);
       }
     } else {
       if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
@@ -1123,13 +1208,7 @@ void host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t d
           remaining -= chunk_size;
         }
       } else {
-        constexpr int ILP = sizeof(float4) / sizeof(accscalar_t);
-        dim3 block = SoftMax_getBlockSize(ILP, dim_size);
-        cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
-          <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
-            gI.mutable_data_ptr<scalar_t>(), output.const_data_ptr<accscalar_t>(), grad.const_data_ptr<accscalar_t>(), dim_size
-        );
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        dispatch_host_softmax_backward<scalar_t, accscalar_t, accscalar_t, Epilogue>(dim_size, grid, grad, output, gI);
       }
     }
   });
```

test/test_nn.py

Lines changed: 59 additions & 1 deletion
```diff
@@ -13010,11 +13010,69 @@ def test_softmax_forward_64bit_indexing(self, device):
     @largeTensorTest("20GB", "cuda")
     def test_softmax_backward_64bit_indexing(self, device):
         for numel in (2147483650, 2147483650 + 1):
-            x = torch.empty([1, 1, numel], device=device, dtype=torch.float16)
+            x = torch.ones([1, 1, numel], device=device, dtype=torch.float16)
             x.fill_(1.0 / numel)
             out = torch._softmax_backward_data(x, x, 2, x.dtype)
             self.assertEqual(out[0, 0, 0], 1 / numel)
 
+    @onlyCUDA
+    def test_softmax_backward_smem(self, device):
+        torch.manual_seed(0)
+        # We use smem for tensors that have > 1024 elements and size >= 4096 bytes.
+        numel = 2048
+        for dtype in [torch.half, torch.float32]:
+            output = torch.rand([numel], device=device, dtype=dtype)
+            grad_output = torch.rand([numel], device=device, dtype=dtype)
+            result = torch._softmax_backward_data(grad_output, output, 0, output.dtype)
+            expected_result = torch._softmax_backward_data(grad_output.cpu(), output.cpu(), 0, dtype)
+            self.assertEqual(expected_result, result)
+
+    @onlyCUDA
+    def test_softmax_backward_without_fully_vectorized(self, device):
+        torch.manual_seed(0)
+        # We don't use smem here because the size of the elements does not divide
+        # ILP cleanly. ILP is defined as sizeof(float4) / sizeof(dtype). Since ILP
+        # is 4 and numel is not divisible by 4, we don't use shared memory here.
+        numel = 2048 + 1
+        for dtype in [torch.half, torch.float32]:
+            output = torch.rand([numel], device=device, dtype=dtype)
+            grad_output = torch.rand([numel], device=device, dtype=dtype) * (1.0 / numel)
+            result = torch._softmax_backward_data(grad_output, output, 0, output.dtype)
+            expected_result = torch._softmax_backward_data(grad_output.cpu(), output.cpu(), 0, dtype)
+            self.assertEqual(expected_result, result)
+
+    def make_unaligned_1d_tensor_of_rand(self, numel, device, dtype):
+        # It's hard to get pytorch to return us a tensor that is not aligned to 16
+        # bytes. To work around that, we create an aligned tensor and create a
+        # slice of it that is not aligned.
+        output = torch.ones([numel + 1], device=device, dtype=dtype)
+        unaligned_output = output[1:]
+        self.assertNotEqual(unaligned_output.data_ptr() % 16, 0)
+        return unaligned_output
+
+    @onlyCUDA
+    def test_softmax_backward_unaligned_output(self, device):
+        torch.manual_seed(0)
+        # We don't use smem here because the output is not aligned to 16 bytes.
+        numel = 2048
+        for dtype in [torch.half, torch.float32]:
+            unaligned_output = self.make_unaligned_1d_tensor_of_rand(numel, device, dtype)
+            grad_output = torch.rand([numel], device=device, dtype=dtype) * (1.0 / numel)
+            result = torch._softmax_backward_data(grad_output, unaligned_output, 0, unaligned_output.dtype)
+            expected_result = torch._softmax_backward_data(grad_output.cpu(), unaligned_output.cpu(), 0, dtype)
+            self.assertEqual(expected_result, result)
+
+    @onlyCUDA
+    def test_softmax_backward_unaligned_grad_output(self, device):
+        torch.manual_seed(0)
+        numel = 2048
+        for dtype in [torch.half, torch.float32]:
+            output = torch.rand([numel], device=device, dtype=dtype)
+            unaligned_grad_output = self.make_unaligned_1d_tensor_of_rand(numel, device, dtype) * (1.0 / numel)
+            result = torch._softmax_backward_data(unaligned_grad_output, output, 0, output.dtype)
+            expected_result = torch._softmax_backward_data(unaligned_grad_output.cpu(), output.cpu(), 0, dtype)
+            self.assertEqual(expected_result, result)
+
     # reference issue: https://github.com/pytorch/pytorch/issues/68248
     @onlyCUDA
     def test_adaptiveavg_pool1d_shmem(self, device):
```
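The tests above compare the CUDA results against the CPU reference for both the shared-memory and fallback paths. To see which kernel was actually dispatched for a given shape, a profiler trace can be inspected; a minimal sketch (illustrative, not part of the PR):

```python
# Check which backward kernel actually ran for a given shape (illustrative).
import torch
from torch.profiler import profile, ProfilerActivity

y = torch.softmax(torch.randn(8, 2048, device="cuda", dtype=torch.float16), dim=1)
dy = torch.randn_like(y)

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    torch._softmax_backward_data(dy, y, 1, y.dtype)

# Kernel names containing "SoftMaxBackward" reveal whether the smem variant was used.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```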
