Commit b60120d

Revert "[ATen][CUDA] Implement 128 bit vectorization v2 (pytorch#145746)"
This reverts commit 81685d8. Reverted pytorch#145746 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking in trunk. See functorch/test_ops.py::TestOperatorsCUDA::test_jvp_nn_functional_multi_head_attention_forward_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/13032483748/job/36358184032) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/81685d81eb86595d169f55a564da26eaafb2ddf5) ([comment](pytorch#145746 (comment)))
1 parent 5215885 commit b60120d

8 files changed: +21 -74 lines

8 files changed

+21
-74
lines changed

aten/src/ATen/native/cuda/CUDAJitLoops.cuh

Lines changed: 4 additions & 16 deletions
@@ -49,8 +49,8 @@ struct JittedVecKernelCache {
   at::cuda::jit::NvrtcFunction vec1;
   at::cuda::jit::NvrtcFunction vec2;
   at::cuda::jit::NvrtcFunction vec4;
-  at::cuda::jit::NvrtcFunction vec8;
 #ifdef USE_ROCM
+  at::cuda::jit::NvrtcFunction vec8;
   at::cuda::jit::NvrtcFunction vec16;
 #endif

@@ -131,30 +131,18 @@ void launch_jitted_vectorized_kernel(
   int vec_size = at::cuda::jit::can_vectorize_up_to(
       desc, c10::ArrayRef<char*>(data.data(), data.size()));

-#ifndef USE_ROCM
-  const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
-  const int optimal_vec_size = 16 / static_cast<int>(input_size);
-  vec_size = std::min<int>(optimal_vec_size, vec_size);
-  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
-  // that causes some numerical mismatches with uint8 on sm80 and sm90.
-  // TODO: Revisit this after CUDA 12.8 update.
-  if (input_size < 2) {
-    vec_size = std::min<int>(vec_size, 4);
-  }
-#endif
-
   // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
   // fn_ptr is set to the appropriate function based on the vec size and GPU used
   at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;

 #ifdef USE_ROCM
   if (vec_size == 16) {
     fn_ptr = &fn_cache.vec16;
+  } else if (vec_size == 8) {
+    fn_ptr = &fn_cache.vec8;
   } else
 #endif
-  if (vec_size == 8) {
-    fn_ptr = &fn_cache.vec8;
-  } else if (vec_size == 4) {
+  if (vec_size == 4) {
     fn_ptr = &fn_cache.vec4;
   } else if (vec_size == 2) {
     fn_ptr = &fn_cache.vec2;
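For orientation, the dispatch pattern this file returns to looks roughly like the standalone sketch below: a runtime vec_size chosen by can_vectorize_up_to selects one slot of a small per-kernel cache of compiled NVRTC functions, and the 8- and 16-wide slots exist only on ROCm builds. The Function, KernelCache, and pick_slot names are hypothetical simplifications, not the actual ATen types.

// Hypothetical stand-ins for at::cuda::jit::NvrtcFunction and JittedVecKernelCache.
struct Function { void* handle = nullptr; };

struct KernelCache {
  Function vec1, vec2, vec4;
#ifdef USE_ROCM
  Function vec8, vec16;  // wider variants are ROCm-only after this revert
#endif
};

// Map the runtime vector size to the cached compilation slot, mirroring the
// if/else chain in launch_jitted_vectorized_kernel above.
inline Function* pick_slot(KernelCache& cache, int vec_size) {
#ifdef USE_ROCM
  if (vec_size == 16) return &cache.vec16;
  if (vec_size == 8)  return &cache.vec8;
#endif
  if (vec_size == 4) return &cache.vec4;
  if (vec_size == 2) return &cache.vec2;
  return &cache.vec1;  // scalar fallback
}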

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 2 additions & 25 deletions
@@ -61,7 +61,6 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
   }
 }

-#ifdef USE_ROCM
 template <int io_sizes>
 constexpr auto elems_per_thread(){
   if constexpr (io_sizes == 1) {
@@ -72,16 +71,6 @@ constexpr auto elems_per_thread(){
     return 4;
   }
 }
-#else
-template <int io_sizes>
-constexpr auto elems_per_thread(){
-  if constexpr (io_sizes == 1) {
-    return 16;
-  } else {
-    return 8;
-  }
-}
-#endif

 template <int io_sizes>
 constexpr auto io_block_work_size() {
@@ -202,33 +191,21 @@ static inline void launch_vectorized_kernel(
   constexpr auto io_size = calc_io_size<func_t>();
   int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
-#ifdef USE_ROCM
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
-#else
-  using cpp_type = typename function_traits<func_t>::result_type;
-  const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
-  uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
-  vec_size = std::min<uint16_t>(vec_size, max_vec_size);
-  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
-  // that causes some numerical mismatches with uint8 on sm80 and sm90.
-  // TODO: Revisit this after CUDA 12.8 update.
-  if (sizeof(cpp_type) < 2) {
-    vec_size = std::min<uint16_t>(vec_size, 4);
-  }
-#endif
+
   switch (vec_size) {
 #ifdef USE_ROCM
     case 16:
       vectorized_elementwise_kernel<16, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
-#endif
     case 8:
       vectorized_elementwise_kernel<8, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
+#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
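In byte terms, the non-ROCm change in this file means each thread moves at most 4 elements per vectorized access rather than a fixed 16 bytes, so only 4-byte element types still reach 128-bit loads. A small illustrative helper (not part of ATen) to make that arithmetic concrete:

#include <cstdio>

// Bytes moved by one vectorized access of vec_size elements of type T.
// Illustrative helper only; the kernel itself works with aligned_vector loads.
template <typename T>
constexpr int bytes_per_access(int vec_size) {
  return vec_size * static_cast<int>(sizeof(T));
}

int main() {
  // With vec_size capped at 4 on the CUDA path after the revert:
  std::printf("4-byte x4 -> %d bytes\n", bytes_per_access<float>(4));          // 16 (128-bit)
  std::printf("2-byte x4 -> %d bytes\n", bytes_per_access<short>(4));          //  8
  std::printf("1-byte x4 -> %d bytes\n", bytes_per_access<unsigned char>(4));  //  4
  return 0;
}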

aten/src/ATen/native/cuda/Dropout.cu

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
   TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]");
 #else
   // make sure we don't break assumption that we can't have > 4 elements / thread
-  TORCH_INTERNAL_ASSERT(vec_size <= 8, "Value of VEC must be in [2, 4, 8]");
+  TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]");
 #endif
 }

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 1 addition & 5 deletions
@@ -351,19 +351,15 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
-  constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
 #ifdef USE_ROCM
+  constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
   constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
   constexpr int type_size = sizeof(scalar_t);
   if (type_size == 1 && (address % vec16_alignment == 0)) {
     return 16;
   } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
     return 8;
   } else
-#else
-  if (address % vec8_alignment == 0) {
-    return 8;
-  } else
 #endif
   if (address % vec4_alignment == 0) {
     return 4;
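The check above derives the usable vector width from pointer alignment against aligned_vector types. Below is a host-only sketch of the same idea, using a generic AlignedVec stand-in rather than ATen's aligned_vector and capping the result at 4, as the non-ROCm branch now does; the function name and struct are illustrative, not ATen API.

#include <cstdint>

// Hypothetical stand-in for ATen's aligned_vector<scalar_t, N>: N elements of T
// carrying the alignment of the full vector.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVec {
  T val[N];
};

// Largest vector width whose alignment the pointer satisfies, capped at 4
// (mirroring the branch kept on non-ROCm builds by this revert).
template <typename T>
int can_vectorize_up_to_sketch(const char* pointer) {
  const auto address = reinterpret_cast<std::uintptr_t>(pointer);
  if (address % alignof(AlignedVec<T, 4>) == 0) return 4;
  if (address % alignof(AlignedVec<T, 2>) == 0) return 2;
  return 1;
}

For a 1-byte type this yields 4 whenever the address is 4-byte aligned, which matches the updated expectations in cuda_vectorized_test.cu below.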

aten/src/ATen/native/cuda/jit_utils.cpp

Lines changed: 3 additions & 8 deletions
@@ -932,6 +932,7 @@ void initializeCudaContext() {
   }
 }

+#ifdef USE_ROCM
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
@@ -951,6 +952,7 @@ int calc_io_size(

   return 0;
 }
+#endif

 int calc_thread_work_size(
     const int nInputs,
@@ -969,14 +971,7 @@ int calc_thread_work_size(
   }
   return io_size;
 #else
-  auto io_size = at::cuda::jit::calc_io_size(nInputs, nOutputs, inputs_type, result_type);
-  TORCH_INTERNAL_ASSERT(io_size > 0);
-  if (io_size == 1) {
-    return 16;
-  } else {
-    return 8;
-  }
-  return io_size;
+  return JIT_THREAD_WORK_SIZE;
 #endif
 }

aten/src/ATen/native/cuda/jit_utils.h

Lines changed: 2 additions & 8 deletions
@@ -60,10 +60,6 @@ inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
   if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) {
     return 8;
   }
-#else
-  if (ip % (8 * default_alignment) == 0) {
-    return 8;
-  }
 #endif
   if (ip % (4 * default_alignment) == 0) {
     return 4;
@@ -92,17 +88,15 @@ inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*
 }

 //FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh
-#ifdef USE_ROCM
 #define JIT_THREAD_WORK_SIZE 4
-#else
-#define JIT_THREAD_WORK_SIZE 8
-#endif

+#ifdef USE_ROCM
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
     const c10::ScalarType& inputs_type,
     const c10::ScalarType& result_type);
+#endif

 int calc_thread_work_size(
     const int nInputs,

aten/src/ATen/native/cuda/thread_constants.h

Lines changed: 1 addition & 4 deletions
@@ -12,14 +12,11 @@
 constexpr int num_threads() {
   return 256;
 }
-
-constexpr int thread_work_size() { return 4; }
 #else
 constexpr uint32_t num_threads() {
   return C10_WARP_SIZE * 4;
 }
-
-constexpr int thread_work_size() { return 8; }
 #endif

+constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
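With thread_work_size() back to a uniform 4 and num_threads() at 256 on the CUDA path, block_work_size() evaluates to 1024 elements per block, which is what the grid-size division in CUDALoops.cuh consumes. A tiny worked example under those assumptions (constant names are illustrative mirrors, not the ATen ones):

#include <cstdint>
#include <cstdio>

// Mirror of the non-ROCm constants above: 256 threads, 4 elements per thread.
constexpr int kNumThreads = 256;
constexpr int kThreadWorkSize = 4;
constexpr int kBlockWorkSize = kNumThreads * kThreadWorkSize;  // 1024 elements per block

int main() {
  const int64_t N = 1'000'000;  // elementwise problem size
  // Same ceiling division as the grid computation in launch_vectorized_kernel.
  const int64_t grid = (N + kBlockWorkSize - 1) / kBlockWorkSize;
  std::printf("%lld elements -> %lld blocks of %d threads\n",
              static_cast<long long>(N), static_cast<long long>(grid), kNumThreads);
  // Prints: 1000000 elements -> 977 blocks of 256 threads
  return 0;
}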

aten/src/ATen/test/cuda_vectorized_test.cu

Lines changed: 7 additions & 7 deletions
@@ -47,11 +47,11 @@ TEST(TestLoops, HasSameArgTypes) {
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);

-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 4);

   ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
   ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -65,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);

-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
