1 | 1 | #include "segment_coo_cuda.h" |
2 | 2 |
| 3 | +#include <ATen/cuda/CUDAContext.h> |
| 4 | +#include <ATen/cuda/detail/IndexUtils.cuh> |
| 5 | +#include <ATen/cuda/detail/TensorInfo.cuh> |
| 6 | + |
| 7 | +#include "reducer.cuh" |
| 8 | +#include "utils.cuh" |
| 9 | + |
| 10 | +#define THREADS 256 |
| 11 | +#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS) |
| 12 | +#define FULL_MASK 0xffffffff |
| 13 | + |
| 14 | +template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL> |
| 15 | +__global__ void |
| 16 | +segment_coo_kernel(const scalar_t *src_data, |
| 17 | + const at::cuda::detail::TensorInfo<int64_t, int> index_info, |
| 18 | + scalar_t *out_data, size_t E, size_t N) { |
| 19 | + |
| 20 | + // Each thread processes exactly one entry. Within a warp, we perform a |
| 21 | + // parallel reduction across equal indices, and write the intermediate |
| 22 | + // result via atomics. |
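| | + // `index` must be sorted (non-decreasing) along its last dimension; the |
| | + // assert below checks exactly this. Only the last lane of each run of |
| | + // equal indices performs the atomic write, carrying the reduction of the |
| | + // whole run, e.g. if lanes 0-4 hold indices 0 0 1 1 2 and lane 5 holds 3, |
| | + // then lanes 1, 3 and 4 write the results for bins 0, 1 and 2. |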
| 23 | + |
| 24 | + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 25 | + int lane_idx = row_idx & (32 - 1); |
| 26 | + int D = index_info.sizes[index_info.dims - 1]; |
| 27 | + |
| 28 | + if (row_idx < E) { |
| 29 | + int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get( |
| 30 | + row_idx, index_info); |
| 31 | + int64_t idx = index_info.data[offset], next_idx; |
| 32 | + int out_idx = (row_idx / D) * N + idx; |
| 33 | + |
| 34 | + scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp; |
| 35 | + |
| 36 | +#pragma unroll |
| 37 | + for (int i = 1; i < 32; i *= 2) { |
| 38 | + // Parallel reduction inside a single warp. |
| 39 | + tmp = __shfl_up_sync(FULL_MASK, val, i); |
| 40 | + next_idx = __shfl_up_sync(FULL_MASK, idx, i); |
| 41 | + if (lane_idx >= i && row_idx / D == (row_idx - i) / D) { |
| 42 | + assert(idx >= next_idx); |
| 43 | + if (idx == next_idx) |
| 44 | + Reducer<scalar_t, REDUCE>::update(&val, tmp); |
| 45 | + } |
| 46 | + } |
| 47 | + |
| 48 | + next_idx = __shfl_down_sync(FULL_MASK, idx, 1); |
| 49 | + if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D || |
| 50 | + idx != next_idx) |
| 51 | + Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val); |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +template <typename scalar_t> |
| 56 | +__global__ void segment_coo_arg_kernel( |
| 57 | + const scalar_t *src_data, |
| 58 | + const at::cuda::detail::TensorInfo<int64_t, int> index_info, |
| 59 | + scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) { |
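| | + // Second pass for "min"/"max" reductions: an entry whose value equals the |
| | + // reduced result of its bin records its position along the reduction |
| | + // dimension (row_idx % D) as the argmin/argmax. |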
| 60 | + |
| 61 | + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 62 | + int D = index_info.sizes[index_info.dims - 1]; |
| 63 | + |
| 64 | + if (row_idx < E) { |
| 65 | + int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get( |
| 66 | + row_idx, index_info); |
| 67 | + int64_t idx = index_info.data[offset]; |
| 68 | + int out_idx = (row_idx / D) * N + idx; |
| 69 | + |
| 70 | + scalar_t val = __ldg(out_data + out_idx); |
| 71 | + if (src_data[row_idx] == val) |
| 72 | + arg_out_data[out_idx] = row_idx % D; |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +template <typename scalar_t, ReductionType REDUCE, int TB> |
| 77 | +__global__ void segment_coo_broadcast_kernel( |
| 78 | + const scalar_t *src_data, |
| 79 | + const at::cuda::detail::TensorInfo<int64_t, int> index_info, |
| 80 | + scalar_t *out_data, size_t E, size_t K, size_t N) { |
| 81 | + |
| 82 | + // Each thread processes a single column and `TB` index entries. Coalesced |
| 83 | + // read and write is performed in column-major order. The intermediate |
| 84 | + // results are written via atomics. |
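| | + // `E_2` below is D (= index.size(-1)) rounded up to a multiple of TB, so |
| | + // each thread owns TB consecutive entries of a single segment. Runs of |
| | + // equal indices are combined in registers and flushed with one atomic |
| | + // write per run. |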
| 85 | + |
| 86 | + int D = index_info.sizes[index_info.dims - 1]; |
| 87 | + int E_1 = E / D; |
| 88 | + int E_2 = (D - 1) + TB - ((D - 1) % TB); |
| 89 | + |
| 90 | + int row_idx = blockIdx.x * blockDim.y + threadIdx.y; |
| 91 | + int col_idx = blockIdx.y * blockDim.x + threadIdx.x; |
| 92 | + |
| 93 | + int dim_start = (row_idx * TB) / E_2; |
| 94 | + int row_start = (row_idx * TB) % E_2; |
| 95 | + |
| 96 | + if (dim_start < E_1 && col_idx < K) { |
| 97 | + |
| 98 | + int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get( |
| 99 | + dim_start * D + row_start, index_info); |
| 100 | + int idx1 = __ldg(index_info.data + offset), idx2; |
| 101 | + |
| 102 | + scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx]; |
| 103 | + |
| 104 | +#pragma unroll |
| 105 | + for (int i = 1; i < TB; i++) { |
| 106 | + if (row_start + i >= D) |
| 107 | + break; |
| 108 | + |
| 109 | + idx2 = __ldg(index_info.data + offset + |
| 110 | + i * index_info.strides[index_info.dims - 1]); |
| 111 | + assert(idx1 <= idx2); |
| 112 | + if (idx1 == idx2) { |
| 113 | + Reducer<scalar_t, REDUCE>::update( |
| 114 | + &val, src_data[K * (dim_start * D + row_start + i) + col_idx]); |
| 115 | + } else { |
| 116 | + Reducer<scalar_t, REDUCE>::atomic_write( |
| 117 | + out_data + (dim_start * N + idx1) * K + col_idx, val); |
| 118 | + val = src_data[K * (dim_start * D + row_start + i) + col_idx]; |
| 119 | + } |
| 120 | + |
| 121 | + idx1 = idx2; |
| 122 | + } |
| 123 | + |
| 124 | + Reducer<scalar_t, REDUCE>::atomic_write( |
| 125 | + out_data + (dim_start * N + idx1) * K + col_idx, val); |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +template <typename scalar_t> |
| 130 | +__global__ void segment_coo_arg_broadcast_kernel( |
| 131 | + const scalar_t *src_data, |
| 132 | + const at::cuda::detail::TensorInfo<int64_t, int> index_info, |
| 133 | + scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) { |
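| | + // Broadcast variant of the arg pass above: one thread per (entry, feature) |
| | + // pair, with the feature dimension of size K contiguous in memory. |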
| 134 | + |
| 135 | + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 136 | + int row_idx = thread_idx / K; |
| 137 | + int col_idx = thread_idx % K; |
| 138 | + int D = index_info.sizes[index_info.dims - 1]; |
| 139 | + |
| 140 | + if (row_idx < E && col_idx < K) { |
| 141 | + int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get( |
| 142 | + row_idx, index_info); |
| 143 | + int idx = __ldg(index_info.data + offset); |
| 144 | + int out_idx = ((row_idx / D) * N + idx) * K + col_idx; |
| 145 | + |
| 146 | + scalar_t val = __ldg(out_data + out_idx); |
| 147 | + if (src_data[thread_idx] == val) |
| 148 | + arg_out_data[out_idx] = row_idx % D; |
| 149 | + } |
| 150 | +} |
| 151 | + |
3 | 152 | std::tuple<torch::Tensor, torch::optional<torch::Tensor>> |
4 | 153 | segment_coo_cuda(torch::Tensor src, torch::Tensor index, |
5 | 154 | torch::optional<torch::Tensor> optional_out, |
6 | 155 | torch::optional<int64_t> dim_size, std::string reduce) { |
7 | | - return std::make_tuple(src, optional_out); |
| 156 | + CHECK_CUDA(src); |
| 157 | + CHECK_CUDA(index); |
| 158 | + if (optional_out.has_value()) |
| 159 | + CHECK_CUDA(optional_out.value()); |
| 160 | + cudaSetDevice(src.get_device()); |
| 161 | + |
| 162 | + CHECK_INPUT(src.dim() >= index.dim()); |
| 163 | + |
| 164 | + auto sizes = index.sizes().vec(); |
| 165 | + for (int i = 0; i < index.dim(); i++) { |
| 166 | + sizes[i] = src.size(i); |
| 167 | + } |
| 168 | + index = index.expand(sizes); |
| 169 | + |
| 170 | + auto dim = index.dim() - 1; |
| 171 | + |
| 172 | + src = src.contiguous(); |
| 173 | + |
| 174 | + torch::Tensor out; |
| 175 | + if (optional_out.has_value()) { |
| 176 | + out = optional_out.value().contiguous(); |
| 177 | + for (int i = 0; i < out.dim(); i++) |
| 178 | + if (i != dim) |
| 179 | + CHECK_INPUT(src.size(i) == out.size(i)); |
| 180 | + } else { |
| 181 | + sizes = src.sizes().vec(); |
| 182 | + if (dim_size.has_value()) |
| 183 | + sizes[dim] = dim_size.value(); |
| 184 | + else { |
| 185 | + auto d_size = index.max().data_ptr<int64_t>(); |
| 186 | + int64_t h_size; // copy the maximum index back to the host |
| 187 | + cudaMemcpy(&h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost); |
| 188 | + sizes[dim] = 1 + h_size; |
| 189 | + } |
| 190 | + out = torch::zeros(sizes, src.options()); |
| 191 | + } |
| 192 | + |
| 193 | + torch::optional<torch::Tensor> arg_out = torch::nullopt; |
| 194 | + int64_t *arg_out_data = nullptr; |
| 195 | + if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { |
| 196 | + arg_out = torch::full_like(out, src.size(dim), index.options()); |
| 197 | + arg_out_data = arg_out.value().data_ptr<int64_t>(); |
| 198 | + } |
| 199 | + |
| 200 | + auto E = index.numel(); |
| 201 | + auto E_2 = index.size(dim); |
| 202 | + auto E_1 = index.numel() / E_2; |
| 203 | + auto K = src.numel() / E; |
| 204 | + auto N = out.size(dim); |
| 205 | + auto avg_len = (float)E_2 / (float)N; |
| 206 | + |
| 207 | + auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index); |
| 208 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 209 | + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] { |
| 210 | + auto src_data = src.data_ptr<scalar_t>(); |
| 211 | + auto out_data = out.data_ptr<scalar_t>(); |
| 212 | + |
| 213 | + AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { |
| 214 | + if (!optional_out.has_value()) |
| 215 | + out.fill_(Reducer<scalar_t, REDUCE>::init()); |
| 216 | + |
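| | + // Kernel selection: K == 1 uses the one-entry-per-thread warp kernel; |
| | + // otherwise the per-thread tile size TB of the broadcast kernel is |
| | + // chosen from `avg_len`, the average number of entries per output bin. |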
| 217 | + if (K == 1) |
| 218 | + segment_coo_kernel<scalar_t, REDUCE, true> |
| 219 | + <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info, |
| 220 | + out_data, E, N); |
| 221 | + else if (avg_len <= 8) |
| 222 | + segment_coo_broadcast_kernel<scalar_t, REDUCE, 4> |
| 223 | + <<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32), |
| 224 | + dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K, |
| 225 | + N); |
| 226 | + else if (avg_len <= 16) |
| 227 | + segment_coo_broadcast_kernel<scalar_t, REDUCE, 8> |
| 228 | + <<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32), |
| 229 | + dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K, |
| 230 | + N); |
| 231 | + else if (avg_len <= 32) |
| 232 | + segment_coo_broadcast_kernel<scalar_t, REDUCE, 16> |
| 233 | + <<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32), |
| 234 | + dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K, |
| 235 | + N); |
| 236 | + else |
| 237 | + segment_coo_broadcast_kernel<scalar_t, REDUCE, 32> |
| 238 | + <<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32), |
| 239 | + dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K, |
| 240 | + N); |
| 241 | +
| 242 | + if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) |
| 243 | + out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0); |
| 244 | +
| 245 | + if (REDUCE == MIN || REDUCE == MAX) { |
| 246 | + if (K == 1) |
| 247 | + segment_coo_arg_kernel<scalar_t> |
| 248 | + <<<BLOCKS(1, E), THREADS, 0, stream>>>( |
| 249 | + src_data, index_info, out_data, arg_out_data, E, N); |
| 250 | + else |
| 251 | + segment_coo_arg_broadcast_kernel<scalar_t> |
| 252 | + <<<BLOCKS(1, E * K), THREADS, 0, stream>>>( |
| 253 | + src_data, index_info, out_data, arg_out_data, E, K, N); |
| 254 | + } |
| 255 | +
| 256 | + if (REDUCE == MEAN) { |
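| | + // For "mean": count the entries per bin by reducing ones with the same |
| | + // kernel (HAS_VAL = false), divide by the counts (clamped to 1 to avoid |
| | + // division by zero for empty bins), and return the counts as the second |
| | + // output. |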
| 257 | + auto sizes = index.sizes().vec(); |
| 258 | + sizes[dim] = out.size(dim); |
| 259 | + auto count = torch::zeros(sizes, out.options()); |
| 260 | + auto count_data = count.data_ptr<scalar_t>(); |
| 261 | + segment_coo_kernel<scalar_t, SUM, false> |
| 262 | + <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info, |
| 263 | + count_data, E, N); |
| 264 | + arg_out = count; |
| 265 | + for (int i = dim + 1; i < out.dim(); i++) |
| 266 | + count = count.unsqueeze(-1); |
| 267 | + out.div_(count.clamp_(1)); |
| 268 | + } |
| 269 | + }); |
| 270 | + }); |
| 271 | +
| 272 | + return std::make_tuple(out, arg_out); |
| 273 | +} |
| 274 | +
| 275 | +template <typename scalar_t> |
| 276 | +__global__ void |
| 277 | +gather_coo_kernel(const scalar_t *src_data, |
| 278 | + const at::cuda::detail::TensorInfo<int64_t, int> index_info, |
| 279 | + scalar_t *out_data, size_t E, size_t N) { |
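| | + // Inverse access pattern of segment_coo_kernel: each thread writes one |
| | + // output entry by reading the row of `src` selected by its index value. |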
| 280 | +
| 281 | + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 282 | +
| 283 | + if (row_idx < E) { |
| 284 | + int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get( |
| 285 | + row_idx, index_info); |
| 286 | + int row = index_info.data[offset]; |
| 287 | +
| 288 | + offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N; |
| 289 | + scalar_t val = __ldg(src_data + offset + row); |
| 290 | +
| 291 | + out_data[row_idx] = val; |
| 292 | + } |
| 293 | +} |
| 294 | +
| 295 | +template <typename scalar_t> |
| 296 | +__global__ void gather_coo_broadcast_kernel( |
| 297 | + const scalar_t *src_data, |
| 298 | + const at::cuda::detail::TensorInfo<int64_t, int> index_info, |
| 299 | + scalar_t *out_data, size_t E, size_t K, size_t N) { |
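| | + // Broadcast variant: one thread per (entry, feature) pair; the feature |
| | + // dimension of size K is contiguous in both `src` and `out`. |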
| 300 | +
| 301 | + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 302 | + int row_idx = thread_idx / K; |
| 303 | + int col_idx = thread_idx % K; |
| 304 | +
| 305 | + if (thread_idx < E * K) { |
| 306 | + int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get( |
| 307 | + row_idx, index_info); |
| 308 | + int row = index_info.data[offset]; |
| 309 | +
| 310 | + offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K; |
| 311 | + scalar_t val = __ldg(src_data + offset + K * row + col_idx); |
| 312 | +
| 313 | + out_data[thread_idx] = val; |
| 314 | + } |
8 | 315 | } |
9 | 316 |
10 | 317 | torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, |
11 | 318 | torch::optional<torch::Tensor> optional_out) { |
12 | | - return src; |
| 319 | + CHECK_CUDA(src); |
| 320 | + CHECK_CUDA(index); |
| 321 | + if (optional_out.has_value()) |
| 322 | + CHECK_CUDA(optional_out.value()); |
| 323 | + cudaSetDevice(src.get_device()); |
| 324 | +
| 325 | + CHECK_INPUT(src.dim() >= index.dim()); |
| 326 | +
| 327 | + auto sizes = index.sizes().vec(); |
| 328 | + for (auto i = 0; i < index.dim() - 1; i++) |
| 329 | + sizes[i] = src.size(i); |
| 330 | + index = index.expand(sizes); |
| 331 | +
| 332 | + auto dim = index.dim() - 1; |
| 333 | +
| 334 | + src = src.contiguous(); |
| 335 | +
| 336 | + torch::Tensor out; |
| 337 | + if (optional_out.has_value()) { |
| 338 | + out = optional_out.value().contiguous(); |
| 339 | + for (auto i = 0; i < src.dim(); i++) |
| 340 | + if (i != dim) |
| 341 | + CHECK_INPUT(src.size(i) == out.size(i)); |
| 342 | + CHECK_INPUT(index.size(dim) == out.size(dim)); |
| 343 | + } else { |
| 344 | + auto sizes = src.sizes().vec(); |
| 345 | + sizes[dim] = index.size(dim); |
| 346 | + out = torch::empty(sizes, src.options()); |
| 347 | + } |
| 348 | +
| 349 | + auto E = index.numel(); |
| 350 | + auto K = out.numel() / E; |
| 351 | + auto N = src.size(dim); |
| 352 | +
| 353 | + auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index); |
| 354 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 355 | + AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] { |
| 356 | + auto src_data = src.data_ptr<scalar_t>(); |
| 357 | + auto out_data = out.data_ptr<scalar_t>(); |
| 358 | +
| 359 | + if (K == 1) |
| 360 | + gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>( |
| 361 | + src_data, index_info, out_data, E, N); |
| 362 | + else |
| 363 | + gather_coo_broadcast_kernel<scalar_t> |
| 364 | + <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info, |
| 365 | + out_data, E, K, N); |
| 366 | + }); |
| 367 | +
| 368 | + return out; |
13 | 369 | } |
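
For orientation, here is a minimal host-side sketch of how the two entry points defined in this file could be exercised from C++. It assumes the extension is compiled and linked against libtorch and that "segment_coo_cuda.h" declares both functions; the `main` scaffolding and the concrete tensor values are illustrative only, not part of the library.

#include <iostream>
#include <torch/torch.h>

#include "segment_coo_cuda.h"

int main() {
  // `index` must be sorted along its last dimension (the kernels assert this).
  auto index = torch::tensor({0, 0, 1, 1, 3},
                             torch::dtype(torch::kLong).device(torch::kCUDA));
  auto src = torch::arange(5, torch::dtype(torch::kFloat).device(torch::kCUDA));

  // Sum `src` into the bins given by `index`; dim_size = 4 fixes out.size(-1).
  auto result = segment_coo_cuda(src, index, torch::nullopt,
                                 /*dim_size=*/4, "sum");
  auto out = std::get<0>(result); // [1, 5, 0, 4]; arg output is nullopt for "sum"

  // Gather the reduced values back to the shape of `index`.
  auto gathered = gather_coo_cuda(out, index, torch::nullopt); // [1, 1, 5, 5, 4]

  std::cout << out << std::endl << gathered << std::endl;
  return 0;
}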