[pytorch][cuda] Improve softmax backward pass native CUDA implementation (pytorch#145866)
This PR is similar to pytorch#122970, but works on the softmax backward pass.
Specifically, it caches gradOutput in shared memory when the row fits there; before this PR, gradOutput was read from global memory twice.
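For context (roughly speaking, and using the standard softmax gradient), with $y = \mathrm{softmax}(x)$ the backward pass computes

$$
\frac{\partial L}{\partial x_i} \;=\; y_i\left(\frac{\partial L}{\partial y_i} \;-\; \sum_j y_j\,\frac{\partial L}{\partial y_j}\right),
$$

so gradOutput is needed once for the row-wise reduction and once more for the final element-wise expression; caching the row in shared memory replaces the second global-memory read with an on-chip read.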
On my H100 this seems to improve softmax backward pass performance by about 5% for problem sizes that fit within shared memory. (Note that this is not the only kernel that runs when you call the softmax backward pass -- there is an elementwise kernel that runs before it; optimizing that can be a separate PR.)
**Important Note**: Currently the softmax backward pass consists of an [element-wise multiply operator](https://github.com/pytorch/pytorch/blob/7f65a208848205b38445423b7e2e93a2b4994e5e/aten/src/ATen/native/cuda/SoftMax.cu#L1216), followed by [this function](https://github.com/pytorch/pytorch/blob/7f65a208848205b38445423b7e2e93a2b4994e5e/aten/src/ATen/native/cuda/SoftMax.cu#L1062), which calls the `cunn_SoftMaxBackward` kernel. With my change, the kernel time is reduced by about 12% (see screenshot below), while the total time (including the elementwise kernel) is reduced by about 5%.
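To make the caching idea concrete, here is a minimal sketch of a shared-memory-cached softmax backward kernel. This is **not** the actual PyTorch kernel (which uses vectorized loads, warp-level reductions, multiple dtypes, and a fallback for rows that don't fit); the kernel name, buffer names, and the fixed 256-thread block are illustrative assumptions only.

```cuda
// Sketch only: one block per row, blockDim.x == 256 (power of two),
// and the row of length `n` is assumed to fit in dynamic shared memory.
#include <cuda_runtime.h>

template <typename T>
__global__ void softmax_backward_smem_sketch(
    T* gradInput, const T* output, const T* gradOutput, int n) {
  extern __shared__ unsigned char smem_raw[];
  T* gO_cached = reinterpret_cast<T*>(smem_raw);

  const int row = blockIdx.x;
  const T* out_row = output     + static_cast<size_t>(row) * n;
  const T* gO_row  = gradOutput + static_cast<size_t>(row) * n;
  T*       gI_row  = gradInput  + static_cast<size_t>(row) * n;

  // Pass 1: read gradOutput from global memory exactly once, stash it in
  // shared memory, and accumulate the row-wise sum of gradOutput * output.
  T thread_sum = T(0);
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    T g = gO_row[i];
    gO_cached[i] = g;
    thread_sum += g * out_row[i];
  }

  // Block-wide tree reduction in (static) shared memory.
  __shared__ T red[256];
  red[threadIdx.x] = thread_sum;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) red[threadIdx.x] += red[threadIdx.x + s];
    __syncthreads();
  }
  const T row_sum = red[0];

  // Pass 2: gradInput = output * (gradOutput - row_sum); gradOutput now comes
  // from shared memory instead of a second trip to DRAM.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    gI_row[i] = out_row[i] * (gO_cached[i] - row_sum);
  }
}

// Illustrative launch (float rows of length n, n * sizeof(float) bytes of
// dynamic shared memory per block):
// softmax_backward_smem_sketch<float><<<rows, 256, n * sizeof(float)>>>(gI, y, gO, n);
```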
| N size | Baseline FP32 bandwidth | Baseline FP16 bandwidth | This PR FP32 bandwidth | This PR FP16 bandwidth | FP32 diff | FP16 diff |
|---|---|---|---|---|---|---|
| 256 | 134.340966 | 70.042039 | 133.70146 | 70.342753 | -0.48% | 0.43% |
| 512 | 233.501185 | 129.945803 | 234.057145 | 132.933066 | 0.24% | 2.30% |
| 1024 | 340.667966 | 229.280464 | 338.833265 | 226.441699 | -0.54% | -1.24% |
| 2048 | 379.643726 | 337.452058 | 399.559017 | 338.432284 | 5.25% | 0.29% |
| 4096 | 416.597537 | 383.625364 | 428.252403 | 396.137506 | 2.80% | 3.26% |
| 6000 | 431.198241 | 384.384384 | 457.744577 | 406.06275 | 6.16% | 5.64% |
| 8192 | 462.811252 | 427.292573 | 474.791032 | 428.281563 | 2.59% | 0.23% |
| 10000 | 464.258731 | 429.050294 | 483.7643 | 446.849381 | 4.20% | 4.15% |
| 10013 | 465.199701 | 429.824179 | 464.904407 | 428.72184 | -0.06% | -0.26% |
| 10240 | 477.07359 | 428.853737 | 485.317024 | 444.902586 | 1.73% | 3.74% |
| 11000 | 473.038785 | 430.778663 | 488.161438 | 453.462162 | 3.20% | 5.27% |
| 12000 | 474.342475 | 432.594814 | 490.532418 | 458.427653 | 3.41% | 5.97% |
| 16384 | 487.468854 | 473.611576 | 488.154406 | 476.264631 | 0.14% | 0.56% |
| 20000 | 482.029793 | 465.666186 | 482.147092 | 483.886193 | 0.02% | 3.91% |
| 24000 | 478.368093 | 474.159464 | 478.364948 | 491.447921 | 0.00% | 3.65% |
| 32000 | 476.523796 | 473.18868 | 476.523796 | 474.398962 | 0.00% | 0.26% |
| 32768 | 476.104723 | 477.493634 | 476.704463 | 477.330606 | 0.13% | -0.03% |
| 36864 | 477.900663 | 475.472787 | 477.973279 | 475.728454 | 0.02% | 0.05% |
| 40960 | 477.707561 | 475.559064 | 478.445017 | 476.088067 | 0.15% | 0.11% |
| 45056 | 479.169812 | 475.865134 | 479.143266 | 475.878202 | -0.01% | 0.00% |
| 49152 | 477.804907 | 475.382982 | 477.868404 | 475.976377 | 0.01% | 0.12% |
| 65536 | 481.274125 | 478.171806 | 481.537733 | 478.703926 | 0.05% | 0.11% |
| 66000 | 481.64652 | 480.095457 | 481.856013 | 480.466388 | 0.04% | 0.08% |
| 68608 | 481.745774 | 479.034704 | 481.917596 | 478.856209 | 0.04% | -0.04% |
| 80000 | 483.409361 | 480.356529 | 483.330481 | 480.375277 | -0.02% | 0.00% |
| 98304 | 480.736301 | 481.396882 | 480.789858 | 481.320143 | 0.01% | -0.02% |
The NCU profiler shows fewer DRAM fetches with the new kernel.

NCU reports about a 12% reduction in elapsed time for this kernel alone compared to the baseline (and because other kernels also run, the overall backward pass time as seen by the user is reduced by about 5%).
I measured the binary size increase by running `python setup.py develop` before and after this change and diffing the resulting .so files:

libtorch_cuda.so goes from 274,752,224 bytes to 274,787,072 bytes. The increase is about 34 kB, or roughly 0.01%.
I measured the compilation time for incremental development:
```
touch ./aten/src/ATen/native/cuda/SoftMax.cu
time python setup.py develop
real 0m10.083s
user 0m8.197s
sys 0m3.149s
```
Note that this uses `ccache` and performs a number of file copies, so it is not just measuring the `nvcc` time. I measured the `nvcc` time separately by capturing the `nvcc` command shown in [1] below and running it on the baseline and modified kernels:
```
# baseline nvcc time for SoftMax.cu
real 0m35.341s
user 0m33.801s
sys 0m1.289s
# this PR's nvcc time for SoftMax.cu
real 0m36.513s
user 0m34.722s
sys 0m1.408s
```
So the `nvcc` time increases by about 1 second, or ~3% of the baseline.
[1] The `nvcc` command is:
```
# This is the nvcc command
/usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. -I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/torch/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/torch/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -Wno-deprecated-gpu-targets --expt-extended-lambda 
-DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/SoftMax.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/SoftMax.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/SoftMax.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/SoftMax.cu.o
```
Pull Request resolved: pytorch#145866
Approved by: https://github.com/ngimel