
Commit aa69d73

hongxiayang authored and pytorchmergebot committed
[ROCm] fix torch.layer_norm invalid configuration problem when input is large tensor (pytorch#144007)
Fixes pytorch#136291

This PR fixes the `invalid configuration argument` error that occurs on ROCm when `torch.layer_norm` is called on a large input tensor:

```
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/functional.py", line 2573, in layer_norm
    return torch.layer_norm
RuntimeError: HIP error: invalid configuration argument
```

After investigation, the cause is that the AMD compute language runtime checks whether `gridDim.x * blockDim.x` exceeds `std::numeric_limits<uint32_t>::max()`; if it does, the launch fails with the "invalid configuration argument" message. The fix splits the work into several chunks so that no single launch hits the failure condition, preserving correctness and completeness given the current implementation of `vectorized_layer_norm_kernel`.

Also adds a large-tensor layer_norm unit test, `test_layer_norm_large_tensor`, using the same shape `[16, 3000, 3000, 16]` as pytorch#136291, so the test can check the expected output values and confirm correctness. Future work may include layer_norm performance optimization and CK layer_norm integration.

Pull Request resolved: pytorch#144007
Approved by: https://github.com/eqy
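For a sense of scale, here is a back-of-the-envelope sketch (not part of the commit) of why a single launch overflows the 32-bit limit for the shape reported in pytorch#136291. The 64-wide wavefront is an assumption about the AMD GPUs involved; the actual kernel queries it at runtime via `at::cuda::warp_size()`.

```python
# Rough arithmetic behind the failure, assuming a 64-wide wavefront
# (typical for AMD GPUs; the kernel reads the real value at runtime).
# For input shape [16, 3000, 3000, 16] with normalized_shape=[16],
# the kernel sees M rows of N elements each.
M = 16 * 3000 * 3000        # 144,000,000 rows -> blocks.x
N = 16                      # elements per row
warp_size = 64              # threads.x on ROCm (assumed wavefront width)

launch_extent = M * warp_size          # gridDim.x * blockDim.x
uint32_max = 2**32 - 1

print(f"{launch_extent:,} > {uint32_max:,} ? {launch_extent > uint32_max}")
# 9,216,000,000 > 4,294,967,295 ? True  -> "invalid configuration argument"
```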
1 parent 6c54963 commit aa69d73

File tree

2 files changed: +56 -1 lines changed

aten/src/ATen/native/cuda/layer_norm_kernel.cu

Lines changed: 38 additions & 1 deletion
@@ -745,12 +745,49 @@ void launch_vectorized_layer_norm_kernel(
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   const int warp_size = at::cuda::warp_size();
   const dim3 threads(warp_size, num_threads() / warp_size, 1);
-  const dim3 blocks(M);
+  dim3 blocks(M);
+
+#ifdef USE_ROCM
+  uint64_t workgroupSize = static_cast<uint64_t>(blocks.x) * static_cast<uint64_t>(threads.x);
+  // A launch this large causes the invalid configuration problem
+  if (workgroupSize > std::numeric_limits<uint32_t>::max()) {
+    // Fix invalid configuration https://github.com/pytorch/pytorch/issues/136291
+    blocks.x = std::numeric_limits<uint32_t>::max() / threads.x;
+  }
+#endif
+
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1);
   int nshared = threads.y > 1 ? threads.y * 3/2 *sizeof(T_ACC) : 0;
   vectorized_layer_norm_kernel<<<blocks, threads, nshared, stream>>>(N, eps, X_data,
     gamma_data, beta_data, mean_data, rstd_data, Y_data);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+#ifdef USE_ROCM
+  // blocks.x now holds the max grid x dimension that avoids the invalid configuration error
+  // Fix invalid configuration https://github.com/pytorch/pytorch/issues/136291
+  // Ensure all elements are processed: prepare for the next round
+  int64_t remaining = M - blocks.x;
+  const T* X_data2 = X_data;
+  T_ACC* mean_data2 = mean_data;
+  T_ACC* rstd_data2 = rstd_data;
+  T* Y_data2 = Y_data;
+
+  while (remaining > 0) {
+    X_data2 += N * blocks.x;
+    mean_data2 += blocks.x;
+    rstd_data2 += blocks.x;
+    Y_data2 += N * blocks.x;
+
+    blocks.x = (remaining > blocks.x) ? blocks.x : remaining;
+
+    vectorized_layer_norm_kernel<<<blocks, threads, nshared, stream>>>(N, eps, X_data2,
+      gamma_data, beta_data, mean_data2, rstd_data2, Y_data2);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    remaining -= blocks.x;
+  }
+#endif
+
 }
 
 template <typename T, typename T_ACC>
test/test_nn.py

Lines changed: 18 additions & 0 deletions
@@ -7139,6 +7139,24 @@ def test_layer_norm_eps(self):
         ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
         self.assertEqual(ln.forward(x), torch.zeros_like(x))
 
+    @largeTensorTest("40GB", device="cuda")
+    def test_layer_norm_large_tensor(self):
+        # test for https://github.com/pytorch/pytorch/issues/136291
+        device = torch.device("cuda")
+        b, n, dp = 16, 3000, 16
+        pairwise_repr = torch.randn(b, n, n, dp)
+
+        attn_bias_norm = nn.LayerNorm(dp).to(device=device)
+        pairwise_repr = pairwise_repr.to(dtype=torch.float32, device=device)
+        # we want a smaller copy to compare the results
+        pairwise_small = pairwise_repr[-1, -1, -1].detach().clone()
+        norm = attn_bias_norm(pairwise_repr)
+        norm_small = attn_bias_norm(pairwise_small)
+
+        self.assertEqual(norm.shape, torch.Size([16, 3000, 3000, 16]))
+        # Check output to make sure it is correct.
+        torch.testing.assert_close(norm_small, norm[-1, -1, -1])
+
     def test_padding_list(self):
         # Padding can be a list, or tuple (regression test for gh-54452)
         x = torch.randn(4, 8, 32, 32)

0 commit comments
