Commit 99ae7d4

ngimel authored and pytorchmergebot committed
Reland fast gather and index implementation (pytorch#151917)
This PR reapplies pytorch#151490 and pytorch#151753 together, and adds some checks that were previously missing when deciding whether to apply the fast path:

1) The indexing path has the stride in the indexed dimension in bytes, while the gather path has the stride in the indexed dimension in elements. When checking whether the fast path is applicable, I didn't take this difference into account and still multiplied the indexing stride by the element size. Fixed, and a test added.

2) We want to take the fast path only when we are copying contiguous, equally spaced slices of the input and all the necessary alignment requirements are met. The effective tensor shape should be 2D (after all possible flattening is applied), the index stride in the last dimension should be 0, and, since the kernel does not apply non-indexing-related offsets to the src tensor, the src tensor stride in the second dimension should also be 0. This happens automatically for gather with dim=0, so I didn't put in an explicit condition for it. Sometimes, however, all conditions except the zero "effective" stride in the first dimension are satisfied for gather on a non-zero dim, when the index size in the indexing dimension is 1 and that dimension is therefore collapsed (dimensions of size 1 are always collapsed), e.g.

```
# test gather along the 1st dim that can accidentally trigger the fast path,
# because the index dimension in the gather dim being 1 causes
# an unexpected squashing in TensorIterator
src = make_tensor((16, 2, 16), device=device, dtype=dtype)
ind = torch.randint(2, (16, 1), device=device).view(16, 1, 1).expand(16, 1, 16)
res = torch.gather(src, dim=1, index=ind)
if res.device.type == "cuda":
    ref_cpu = torch.gather(src.cpu(), dim=1, index=ind.cpu())
    self.assertEqual(res.cpu(), ref_cpu, atol=0, rtol=0)
```

Note that if the index size here were (16, 2, 16) instead of (16, 1, 16), the middle dimension could not be collapsed and we wouldn't end up incorrectly taking the fast path. We could either update the kernel to take this stride into account when computing offsets into the src tensor, or specifically disallow a non-zero stride on the first dimension. I took the second option for now.

Pull Request resolved: pytorch#151917
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/Skylion007
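For context, a minimal sketch of the case the fast path targets: gather along dim=0 with the index expanded along the last dimension, so each output row is one contiguous, equally spaced slice of the input. Shapes and dtypes here are arbitrary illustrations; whether the vectorized kernel actually fires is an internal dispatch detail not observable from Python, and the result must match the CPU reference exactly either way.

```python
import torch

# Illustration only: 64 * 4 = 256 bytes per row, a multiple of 16.
if torch.cuda.is_available():
    src = torch.randn(32, 64, device="cuda")
    # expanded index -> stride 0 in the last dim, i.e. whole rows are copied
    ind = torch.randint(0, 32, (16, 1), device="cuda").expand(16, 64)
    res = torch.gather(src, dim=0, index=ind)
    ref = torch.gather(src.cpu(), dim=0, index=ind.cpu())
    torch.testing.assert_close(res.cpu(), ref, rtol=0, atol=0)
```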
1 parent 69e41ce commit 99ae7d4

File tree: 8 files changed (+326 −129 lines)


aten/src/ATen/native/cuda/IndexKernel.cu

Lines changed: 27 additions & 7 deletions
@@ -14,6 +14,8 @@
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/KernelUtils.cuh>
 #include <ATen/native/quantized/IndexKernel.h>
+#include <ATen/native/cuda/MemoryAccess.cuh>
+#include <ATen/native/cuda/IndexKernelUtils.h>
 
 #include <c10/core/Scalar.h>
 
@@ -52,7 +54,7 @@ static void launch_kernel(const int64_t N, const func_t& f) {
 }
 
 template <typename func_t>
-void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const func_t& f) {
+void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const func_t& f, const bool is_gather_like) {
   const auto num_indices = index_size.size();
   AT_ASSERT(num_indices == index_stride.size());
   AT_ASSERT(static_cast<int64_t>(num_indices) == iter.ntensors() - 2);
@@ -63,11 +65,31 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
 
   if (!iter.can_use_32bit_indexing()) {
     for (auto& sub_iter : iter.with_32bit_indexing()) {
-      gpu_index_kernel(sub_iter, index_size, index_stride, f);
+      gpu_index_kernel(sub_iter, index_size, index_stride, f, is_gather_like);
     }
     return;
   }
 
+
+  char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
+  char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
+
+  if (is_gather_like && num_indices==1) {
+    const size_t element_size = iter.element_size(0);
+    constexpr size_t alignment = 16;
+    if (at::native::fast_gather_kernel_eligible<alignment>(iter, out_ptr, in_ptr, index_stride[0], element_size)) {
+      auto slice_size = iter.shape()[0] * element_size;
+      auto num_ind = iter.shape()[1];
+      auto ind_dim_size = index_size[0];
+      auto inp_stride_bytes = index_stride[0];
+      auto out_stride_bytes = iter.strides(0)[1];
+      if (iter.numel() == 0) return;
+      at::native::vectorized_gather_kernel_launch<alignment>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
+                                                             slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
+      return;
+    }
+  }
+
   auto sizes = std::array<int64_t, MAX_DIMS>{};
   auto strides = std::array<int64_t, MAX_DIMS>{};
   auto index_ptrs = std::array<char*, MAX_DIMS>{};
@@ -77,8 +99,6 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
     index_ptrs[i] = (char*)iter.data_ptr(i + 2);
   }
 
-  char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
-  char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
 
   auto offset_calc = make_offset_calculator<3>(iter);
   launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), [=]__device__(int idx) {
@@ -183,14 +203,14 @@ template <typename scalar_t>
 void index_kernel_impl(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
   gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
     *reinterpret_cast<scalar_t*>(out_data) = *reinterpret_cast<const scalar_t*>(in_data + offset);
-  });
+  }, true);
 }
 
 template <typename scalar_t>
 void index_put_kernel_impl(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
   gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
     *reinterpret_cast<scalar_t*>(out_data + offset) = *reinterpret_cast<const scalar_t*>(in_data);
-  });
+  }, false);
 }
 
 static void index_kernel(
@@ -280,7 +300,7 @@ void index_put_kernel_quantized_cuda(TensorIterator& iter, const IntArrayRef ind
     // The replacement should generate the same PTX as std::clamp. See https://godbolt.org/z/Wde9KW3v4
     qvalue = (qvalue < qmin) ? qmin : (qmax < qvalue) ? qmax : qvalue;
     *(scalar_t*)(out_data + offset) = static_cast<scalar_t>(qvalue);
-  });
+  }, false);
   });
 }
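For illustration, this is the kind of user-level call that reaches gpu_index_kernel with is_gather_like=true and a single index tensor (advanced indexing with one int64 tensor), next to the scatter-like index_put_ counterpart that passes is_gather_like=false. Shapes are arbitrary; whether the vectorized fast path is actually taken depends on the alignment checks above and is not observable from Python.

```python
import torch

if torch.cuda.is_available():
    src = torch.randn(1024, 256, device="cuda")           # 256 * 4 = 1024-byte contiguous rows
    idx = torch.randint(0, 1024, (4096,), device="cuda")  # a single int64 index tensor
    out = src[idx]                                         # gather-like path (index kernel)
    ref = src.cpu()[idx.cpu()]
    torch.testing.assert_close(out.cpu(), ref, rtol=0, atol=0)
    src[idx] = 0.0                                         # scatter-like path (index_put kernel), no fast path
```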

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/native/cuda/MemoryAccess.cuh>
+
+#include <c10/macros/Macros.h>
+#include <c10/util/Exception.h>
+#include <ATen/native/cuda/Loops.cuh>
+#include <ATen/ceil_div.h>
+
+namespace at::native {
+template <int Alignment>
+__global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
+  int64_t ind = idx[blockIdx.x];
+  if (allow_neg_indices) {
+    ind = (ind < 0) ? ind + ind_dim_size : ind;
+  }
+  CUDA_KERNEL_ASSERT(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
+  int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
+  if (off >= slice_size) return;
+  auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
+  at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
+}
+
+
+
+template <int64_t Alignment>
+void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+                                     int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices){
+
+  constexpr int64_t max_num_threads=256;
+  auto num_threads = at::round_up(
+      at::ceil_div(slice_size_in_bytes, Alignment),
+      static_cast<int64_t>(C10_WARP_SIZE));
+  dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
+  auto block = std::min(max_num_threads, num_threads);
+  vectorized_gather_kernel<Alignment><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
+                                        ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+// explicit template instantiation
+template void vectorized_gather_kernel_launch<16>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
+    int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
+
+}
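To make the launch geometry concrete, here is a small Python sketch of the same arithmetic for hypothetical sizes (4096 gathered slices of 1000 float32 elements, i.e. 4000 bytes each). The warp size of 32 is an assumption; C10_WARP_SIZE is 64 on some ROCm devices.

```python
ALIGNMENT = 16           # bytes moved per thread (one vectorized load/store)
MAX_NUM_THREADS = 256
WARP_SIZE = 32           # assumption; 64 on some AMD GPUs

def ceil_div(a, b):
    return -(-a // b)

def round_up(a, b):
    return ceil_div(a, b) * b

slice_size_in_bytes = 1000 * 4   # 4000 bytes per slice
num_ind = 4096                   # one block row per gathered index

num_threads = round_up(ceil_div(slice_size_in_bytes, ALIGNMENT), WARP_SIZE)   # 256
block = min(MAX_NUM_THREADS, num_threads)                                     # 256 threads
grid = (num_ind, ceil_div(slice_size_in_bytes, MAX_NUM_THREADS * ALIGNMENT))  # (4096, 1)
# Threads whose byte offset lands past the 4000-byte slice simply return in the kernel.
print(block, grid)
```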
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+
+#include <cstdint>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/MemoryAccess.cuh>
+
+namespace at::native {
+
+template<int alignment>
+inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const out_ptr, char * const in_ptr, const size_t index_stride_bytes, const size_t element_size) {
+  using at::native::memory::get_alignment;
+  const auto index_element_size = iter.element_size(2);
+  //TensorIterator strides and sizes are ordered fastest moving to slowest moving,
+  //in contrast to regular sizes
+  // we need contiguous source and dst slices and aligned pointers and strides and slice size to do vectorized loads
+  // also we need idx to be expanded in the last dimension so we can copy entire slices
+  // and we need the src tensor to keep 0 stride from restriding
+  // (it could have been deleted by dimension collapse, in this case iterator would still be 2d
+  // but we cannot use fast path)
+
+  return iter.ndim() == 2 && iter.strides(2)[0]==0 && iter.strides(2)[1]==index_element_size &&
+         static_cast<size_t>(iter.strides(0)[0])==element_size &&
+         static_cast<size_t>(iter.strides(1)[0])==element_size && static_cast<size_t>(iter.strides(1)[1] == 0) &&
+         get_alignment(out_ptr) == alignment && get_alignment(in_ptr) == alignment &&
+         get_alignment(static_cast<size_t>(iter.shape()[0] * element_size)) == alignment &&
+         get_alignment(static_cast<size_t>(index_stride_bytes)) == alignment &&
+         get_alignment(static_cast<size_t>(iter.strides(0)[1])) == alignment;
+}
+
+template <int64_t Alignment>
+void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+    int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes,
+    bool allow_neg_indices=false);
+
+
+}
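As a rough illustration only (not the internal TensorIterator-based check), the alignment part of this eligibility test can be pictured at the Python level for a dim-0 gather on 2D tensors: the source and destination pointers, the slice size in bytes, and both row strides in bytes all have to be 16-byte aligned. The helper below is hypothetical.

```python
import torch

def looks_16B_aligned(src: torch.Tensor, out: torch.Tensor, alignment: int = 16) -> bool:
    # hypothetical helper: assumes 2D tensors, gather along dim 0, contiguous rows
    elem = src.element_size()
    checks = [
        src.data_ptr(), out.data_ptr(),
        src.size(1) * elem,        # slice size in bytes
        src.stride(0) * elem,      # input stride along the indexed dim, in bytes
        out.stride(0) * elem,      # output row stride in bytes
    ]
    return all(v % alignment == 0 for v in checks)

src = torch.randn(128, 64)         # 64 * 4 = 256-byte rows
out = torch.empty(32, 64)
print(looks_16B_aligned(src, out))  # usually True: allocations are at least 16-byte aligned
```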

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 119 additions & 0 deletions
@@ -536,4 +536,123 @@ inline int can_vectorize_up_to(array_t pointers) {
   return result;
 }
 
+
+
+template <typename T>
+__inline__ size_t get_alignment(T ptr_or_size) {
+  auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
+  if (val % 16 == 0) {
+    return 16;
+  } else if (val % 8 == 0) {
+    return 8;
+  } else if (val % 4 == 0) {
+    return 4;
+  } else if (val % 2 == 0) {
+    return 2;
+  } else {
+    return 1;
+  }
+}
+
+template <>
+__inline__ size_t get_alignment<size_t>(size_t size) {
+  return get_alignment(reinterpret_cast<void*>(size));
+}
+
+template <bool Value, class... Args>
+inline constexpr bool dependent_bool_value = Value;
+
+template <class... Args>
+inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;
+
+template <int Size>
+union Vec;
+
+template <>
+union Vec<4> {
+  uint16_t u16[2];
+  uint32_t u32, as_scalar;
+  float f32;
+};
+
+template <>
+union Vec<8> {
+  uint16_t u16[4];
+  uint32_t u32[2];
+  uint64_t u64, as_scalar;
+  float f32[2];
+};
+
+template <>
+union alignas(16) Vec<16> {
+  uint16_t u16[8];
+  uint32_t u32[4];
+  uint64_t u64[2];
+  uint4 u128, as_scalar;
+  float f32[4];
+};
+
+template <int Alignment, typename T>
+__device__ __inline__ Vec<Alignment> ld_vec(const T* addr) {
+  Vec<Alignment> vec;
+  if constexpr (Alignment == 16) {
+#if defined(USE_ROCM)
+    vec.u128 = *reinterpret_cast<const uint4*>(addr);
+  } else if constexpr (Alignment == 8) {
+    vec.u64 = *reinterpret_cast<const uint64_t*>(addr);
+  } else if constexpr (Alignment == 4) {
+    vec.u32 = *reinterpret_cast<const uint32_t*>(addr);
+#else
+    asm("ld.global.v4.u32 {%0,%1,%2,%3}, [%4];"
+        : "=r"(vec.u32[0]), "=r"(vec.u32[1]), "=r"(vec.u32[2]), "=r"(vec.u32[3])
+        : "l"(addr)
+        : "memory");
+  } else if constexpr (Alignment == 8) {
+    asm("ld.global.v2.u32 {%0,%1}, [%2];"
+        : "=r"(vec.u32[0]), "=r"(vec.u32[1])
+        : "l"(addr)
+        : "memory");
+  } else if constexpr (Alignment == 4) {
+    asm("ld.global.u32 %0, [%1];" : "=r"(vec.u32) : "l"(addr) : "memory");
+#endif
+  } else {
+    static_assert(dependent_false<T>);
+  }
+  return vec;
+}
+
+template <int Alignment, typename T>
+__device__ __inline__ void st_vec(T* addr, const Vec<Alignment>& vec) {
+  if constexpr (Alignment == 16) {
+#if defined(USE_ROCM)
+    reinterpret_cast<uint64_t*>(addr)[0] = vec.u64[0];
+    reinterpret_cast<uint64_t*>(addr)[1] = vec.u64[1];
+  } else if constexpr (Alignment == 8) {
+    *reinterpret_cast<uint64_t*>(addr) = vec.u64;
+  } else if constexpr (Alignment == 4) {
+    *reinterpret_cast<uint32_t*>(addr) = vec.u32;
+#else
+    asm("st.global.v4.u32 [%0], {%1,%2,%3,%4};"
+        :
+        : "l"(addr),
+          "r"(vec.u32[0]),
+          "r"(vec.u32[1]),
+          "r"(vec.u32[2]),
+          "r"(vec.u32[3])
+        : "memory");
+  } else if constexpr (Alignment == 8) {
+    asm("st.global.v2.u32 [%0], {%1,%2};"
+        :
+        : "l"(addr), "r"(vec.u32[0]), "r"(vec.u32[1])
+        : "memory");
+  } else if constexpr (Alignment == 4) {
+    asm("st.global.u32 [%0], %1;" : : "l"(addr), "r"(vec.u32) : "memory");
+#endif
+  } else {
+    static_assert(dependent_false<T>);
+  }
+}
+
+
+
 } // namespace at::native::memory
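For reference, a tiny Python analogue (illustrative only) of the get_alignment helper added above: it reports the largest supported alignment, 16, 8, 4, 2, or 1 bytes, that evenly divides a pointer value or size.

```python
def get_alignment(val: int) -> int:
    # mirrors the C++ helper: try 16, then 8, 4, 2, falling back to 1
    for alignment in (16, 8, 4, 2):
        if val % alignment == 0:
            return alignment
    return 1

assert get_alignment(4096) == 16
assert get_alignment(4000) == 16   # 4000 = 250 * 16
assert get_alignment(100) == 4
assert get_alignment(6) == 2
```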

aten/src/ATen/native/cuda/ScatterGatherKernel.cu

Lines changed: 21 additions & 5 deletions
@@ -1,16 +1,16 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/native/TensorAdvancedIndexing.h>
-
 #include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
+#include <ATen/ceil_div.h>
 #include <ATen/MemoryOverlap.h>
 
 #include <ATen/native/ScatterGatherChecks.h>
 #include <ATen/native/ReduceOpsUtils.h>
-#include <ATen/native/TensorIterator.h>
-
+#include <ATen/native/cuda/IndexKernelUtils.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/KernelUtils.cuh>
+#include <ATen/native/cuda/MemoryAccess.cuh>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/cuda/Atomic.cuh>
 #include <ATen/cuda/CUDAContext.h>
@@ -116,7 +116,6 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
-
 template <bool is_scatter_like, typename scalar_t>
 struct _cuda_scatter_gather_internal_kernel {
   template <typename func_t>
@@ -140,13 +139,29 @@
     char* src_ptr = (char*)iter.data_ptr(1);
     char* index_ptr = (char*)iter.data_ptr(2);
 
+    if constexpr (!is_scatter_like) {
+      // we can go to faster path if we are indexing on the first dim
+      // the dst and src are contiguous and all the dims and pts are multiple of 16
+      constexpr size_t element_size = sizeof(scalar_t);
+      constexpr size_t alignment = 16;
+      if (at::native::fast_gather_kernel_eligible<alignment>(iter, self_ptr, src_ptr, index_stride * element_size, element_size)) {
+        auto slice_size = iter.shape()[0] * element_size;
+        auto num_ind = iter.shape()[1];
+        auto ind_dim_size = index_size;
+        auto inp_stride_bytes = index_stride * element_size;
+        auto out_stride_bytes = iter.strides(0)[1];
+        if (iter.numel() == 0) return;
+        at::native::vectorized_gather_kernel_launch<alignment>(self_ptr, src_ptr, (int64_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
+        return;
+      }
+    }
     auto offset_calc = make_offset_calculator<3>(iter);
     auto loop = [=]C10_DEVICE(int i) {
       auto offsets = offset_calc.get(i);
 
       int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
-                         && "index out of bounds");
+                         && "scatter gather kernel index out of bounds");
 
       f(
         (scalar_t*)(self_ptr + offsets[0]),
@@ -157,6 +172,7 @@
    };
 
    _launch_scatter_gather_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
+
   }
 }; // struct _cuda_scatter_gather_internal_kernel
 
