
Commit 64772d7

segment coo done
1 parent d0f5005 commit 64772d7

8 files changed: +394 -39 lines

csrc/cuda/segment_coo_cuda.cu

Lines changed: 358 additions & 2 deletions
@@ -1,13 +1,369 @@
 #include "segment_coo_cuda.h"
 
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/cuda/detail/TensorInfo.cuh>
+
+#include "reducer.cuh"
+#include "utils.cuh"
+
+#define THREADS 256
+#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
+#define FULL_MASK 0xffffffff
+
+template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
+__global__ void
+segment_coo_kernel(const scalar_t *src_data,
+                   const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+                   scalar_t *out_data, size_t E, size_t N) {
+
+  // Each thread processes exactly one entry. Within a warp, we perform a
+  // parallel reduction across equal indices, and write the intermediate
+  // result via atomics.
+
+  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int lane_idx = row_idx & (32 - 1);
+  int D = index_info.sizes[index_info.dims - 1];
+
+  if (row_idx < E) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int64_t idx = index_info.data[offset], next_idx;
+    int out_idx = (row_idx / D) * N + idx;
+
+    scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
+
+#pragma unroll
+    for (int i = 1; i < 32; i *= 2) {
+      // Parallel reduction inside a single warp.
+      tmp = __shfl_up_sync(FULL_MASK, val, i);
+      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
+      if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
+        assert(idx >= next_idx);
+        if (idx == next_idx)
+          Reducer<scalar_t, REDUCE>::update(&val, tmp);
+      }
+    }
+
+    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
+    if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
+        idx != next_idx)
+      Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
+  }
+}
+
+template <typename scalar_t>
+__global__ void segment_coo_arg_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {
+
+  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int D = index_info.sizes[index_info.dims - 1];
+
+  if (row_idx < E) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int64_t idx = index_info.data[offset];
+    int out_idx = (row_idx / D) * N + idx;
+
+    scalar_t val = __ldg(out_data + out_idx);
+    if (src_data[row_idx] == val)
+      arg_out_data[out_idx] = row_idx % D;
+  }
+}
+
+template <typename scalar_t, ReductionType REDUCE, int TB>
+__global__ void segment_coo_broadcast_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+    scalar_t *out_data, size_t E, size_t K, size_t N) {
+
+  // Each thread processes a single column and `TB` index entries. Coalesced
+  // read and write is performed in column-major order. The intermediate
+  // results are written via atomics.
+
+  int D = index_info.sizes[index_info.dims - 1];
+  int E_1 = E / D;
+  int E_2 = D + TB - (D % TB);
+
+  int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
+
+  int dim_start = (row_idx * TB) / E_2;
+  int row_start = (row_idx * TB) % E_2;
+
+  if (dim_start < E_1 && col_idx < K) {
+
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        dim_start * D + row_start, index_info);
+    int idx1 = __ldg(index_info.data + offset), idx2;
+
+    scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx];
+
+#pragma unroll
+    for (int i = 1; i < TB; i++) {
+      if (row_start + i >= D)
+        break;
+
+      idx2 = __ldg(index_info.data + offset +
+                   i * index_info.strides[index_info.dims - 1]);
+      assert(idx1 <= idx2);
+      if (idx1 == idx2) {
+        Reducer<scalar_t, REDUCE>::update(
+            &val, src_data[K * (dim_start * D + row_start + i) + col_idx]);
+      } else {
+        Reducer<scalar_t, REDUCE>::atomic_write(
+            out_data + (dim_start * N + idx1) * K + col_idx, val);
+        val = src_data[K * (dim_start * D + row_start + i) + col_idx];
+      }
+
+      idx1 = idx2;
+    }
+
+    Reducer<scalar_t, REDUCE>::atomic_write(
+        out_data + (dim_start * N + idx1) * K + col_idx, val);
+  }
+}
+
+template <typename scalar_t>
+__global__ void segment_coo_arg_broadcast_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int row_idx = thread_idx / K;
+  int col_idx = thread_idx % K;
+  int D = index_info.sizes[index_info.dims - 1];
+
+  if (row_idx < E && col_idx < K) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int idx = __ldg(index_info.data + offset);
+    int out_idx = ((row_idx / D) * N + idx) * K + col_idx;
+
+    scalar_t val = __ldg(out_data + out_idx);
+    if (src_data[thread_idx] == val)
+      arg_out_data[out_idx] = row_idx % D;
+  }
+}
+
 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 segment_coo_cuda(torch::Tensor src, torch::Tensor index,
                  torch::optional<torch::Tensor> optional_out,
                  torch::optional<int64_t> dim_size, std::string reduce) {
-  return std::make_tuple(src, optional_out);
+  CHECK_CUDA(src);
+  CHECK_CUDA(index);
+  if (optional_out.has_value())
+    CHECK_CUDA(optional_out.value());
+  cudaSetDevice(src.get_device());
+
+  CHECK_INPUT(src.dim() >= index.dim());
+
+  auto sizes = index.sizes().vec();
+  for (int i = 0; i < index.dim(); i++) {
+    sizes[i] = src.size(i);
+  }
+  index = index.expand(sizes);
+
+  auto dim = index.dim() - 1;
+
+  src = src.contiguous();
+
+  torch::Tensor out;
+  if (optional_out.has_value()) {
+    out = optional_out.value().contiguous();
+    for (int i = 0; i < out.dim(); i++)
+      if (i != dim)
+        CHECK_INPUT(src.size(i) == out.size(i));
+  } else {
+    sizes = src.sizes().vec();
+    if (dim_size.has_value())
+      sizes[dim] = dim_size.value();
+    else {
+      auto d_size = index.max().data_ptr<int64_t>();
+      auto h_size = (int64_t *)malloc(sizeof(int64_t));
+      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
+      sizes[dim] = 1 + *h_size;
+    }
+    out = torch::zeros(sizes, src.options());
+  }
+
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
+  int64_t *arg_out_data = nullptr;
+  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
+    arg_out = torch::full_like(out, src.size(dim), index.options());
+    arg_out_data = arg_out.value().data_ptr<int64_t>();
+  }
+
+  auto E = index.numel();
+  auto E_2 = index.size(dim);
+  auto E_1 = index.numel() / E_2;
+  auto K = src.numel() / E;
+  auto N = out.size(dim);
+  auto avg_len = (float)E_2 / (float)N;
+
+  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
+    auto src_data = src.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      if (!optional_out.has_value())
+        out.fill_(Reducer<scalar_t, REDUCE>::init());
+
+      if (K == 1)
+        segment_coo_kernel<scalar_t, REDUCE, true>
+            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
+                                                   out_data, E, N);
+      else if (avg_len <= 8)
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
+            <<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
+                                         N);
+      else if (avg_len <= 16)
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
+            <<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
+                                         N);
+      else if (avg_len <= 32)
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
+            <<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
+                                         N);
+      else
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
+            <<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
+                                         N);
+
+      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
+        out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
+
+      if (REDUCE == MIN || REDUCE == MAX) {
+        if (K == 1)
+          segment_coo_arg_kernel<scalar_t>
+              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
+                  src_data, index_info, out_data, arg_out_data, E, N);
+        else
+          segment_coo_arg_broadcast_kernel<scalar_t>
+              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
+                  src_data, index_info, out_data, arg_out_data, E, K, N);
+      }
+
+      if (REDUCE == MEAN) {
+        auto sizes = index.sizes().vec();
+        sizes[dim] = out.size(dim);
+        auto count = torch::zeros(sizes, out.options());
+        auto count_data = count.data_ptr<scalar_t>();
+        segment_coo_kernel<scalar_t, SUM, false>
+            <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
+                                                   count_data, E, N);
+        arg_out = count;
+        for (int i = dim + 1; i < out.dim(); i++)
+          count = count.unsqueeze(-1);
+        out.div_(count.clamp_(1));
+      }
+    });
+  });
+
+  return std::make_tuple(out, arg_out);
+}
+
+template <typename scalar_t>
+__global__ void
+gather_coo_kernel(const scalar_t *src_data,
+                  const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+                  scalar_t *out_data, size_t E, size_t N) {
+
+  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (row_idx < E) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int row = index_info.data[offset];
+
+    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N;
+    scalar_t val = __ldg(src_data + offset + row);
+
+    out_data[row_idx] = val;
+  }
+}
+
+template <typename scalar_t>
+__global__ void gather_coo_broadcast_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+    scalar_t *out_data, size_t E, size_t K, size_t N) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int row_idx = thread_idx / K;
+  int col_idx = thread_idx % K;
+
+  if (thread_idx < E * K) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int row = index_info.data[offset];
+
+    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K;
+    scalar_t val = __ldg(src_data + offset + K * row + col_idx);
+
+    out_data[thread_idx] = val;
+  }
 }
 
 torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
                               torch::optional<torch::Tensor> optional_out) {
-  return src;
+  CHECK_CUDA(src);
+  CHECK_CUDA(index);
+  if (optional_out.has_value())
+    CHECK_CUDA(optional_out.value());
+  cudaSetDevice(src.get_device());
+
+  CHECK_INPUT(src.dim() >= index.dim());
+
+  auto sizes = index.sizes().vec();
+  for (auto i = 0; i < index.dim() - 1; i++)
+    sizes[i] = src.size(i);
+  index = index.expand(sizes);
+
+  auto dim = index.dim() - 1;
+
+  src = src.contiguous();
+
+  torch::Tensor out;
+  if (optional_out.has_value()) {
+    out = optional_out.value().contiguous();
+    for (auto i = 0; i < src.dim(); i++)
+      if (i != dim)
+        CHECK_INPUT(src.size(i) == out.size(i));
+    CHECK_INPUT(index.size(dim) == out.size(dim));
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[dim] = index.size(dim);
+    out = torch::empty(sizes, src.options());
+  }
+
+  auto E = index.numel();
+  auto K = out.numel() / E;
+  auto N = src.size(dim);
+
+  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
+    auto src_data = src.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+
+    if (K == 1)
+      gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
+          src_data, index_info, out_data, E, N);
+    else
+      gather_coo_broadcast_kernel<scalar_t>
+          <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
+                                                     out_data, E, K, N);
+  });
+
+  return out;
 }
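
For reference, the reduction these kernels implement can be summarized by a short sequential sketch. This is only an illustrative CPU reference for the simplest case handled above (1-D src and index, K == 1, "sum" reduction, index sorted); the function name below is hypothetical and not part of this commit:

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative CPU reference: out[j] accumulates every src[e] whose
// (sorted) index[e] equals j, for j in [0, N). The CUDA kernels above
// compute the same result, but first combine equal indices inside a warp
// (segment_coo_kernel) or across TB consecutive entries per thread
// (segment_coo_broadcast_kernel) before writing to global memory via atomics.
std::vector<float> segment_coo_sum_reference(const std::vector<float> &src,
                                             const std::vector<int64_t> &index,
                                             int64_t N) {
  assert(src.size() == index.size());
  std::vector<float> out(N, 0.0f);
  for (size_t e = 0; e < src.size(); e++)
    out[index[e]] += src[e];
  return out;
}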

csrc/cuda/segment_csr_cuda.cu

Lines changed: 4 additions & 6 deletions
@@ -237,13 +237,11 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
   } else {
-    auto d_gather_size = indptr.flatten()[-1].data_ptr<int64_t>();
-    auto h_gather_size = (int64_t *)malloc(sizeof(int64_t));
-    cudaMemcpy(h_gather_size, d_gather_size, sizeof(int64_t),
-               cudaMemcpyDeviceToHost);
-
+    auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
+    auto h_size = (int64_t *)malloc(sizeof(int64_t));
+    cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
     auto sizes = src.sizes().vec();
-    sizes[dim] = *h_gather_size;
+    sizes[dim] = *h_size;
     out = torch::empty(sizes, src.options());
   }
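
The renamed copy above follows the same pattern the new segment_coo_cuda path uses when dim_size is not given: a single int64 scalar that lives on the GPU (the last element of indptr here, index.max() there) is brought to the host before allocating the output. A standalone sketch of that pattern, with a hypothetical helper name:

#include <cstdint>
#include <cuda_runtime.h>
#include <torch/extension.h>

// Hypothetical helper sketching the device-to-host size read used above:
// copy one int64 scalar from GPU memory into a host variable so it can be
// used to size the output tensor on the host side.
static int64_t read_int64_from_device(const torch::Tensor &scalar) {
  int64_t h_value = 0;
  cudaMemcpy(&h_value, scalar.data_ptr<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  return h_value;
}

The commit itself keeps the copy inline with a malloc'd host buffer, matching the existing style of the file.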
