
Commit 48024c1

added cpu segment implementation
1 parent 5817fb9 commit 48024c1

5 files changed: 287 additions & 13 deletions


benchmark/scatter_segment.py

Lines changed: 3 additions & 3 deletions
@@ -11,9 +11,6 @@
 from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
 from torch_scatter import segment_coo, segment_csr
 
-iters = 20
-sizes = [1, 16, 32, 64, 128, 256, 512]
-
 short_rows = [
     ('DIMACS10', 'citationCiteseer'),
     ('SNAP', 'web-Stanford'),
@@ -216,6 +213,9 @@ def dense2(x):
 parser.add_argument('--device', type=str, default='cuda')
 args = parser.parse_args()
 args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
+iters = 1 if args.device == 'cpu' else 20
+sizes = [1, 16, 32, 64, 128, 256, 512]
+sizes = sizes[:3] if args.device == 'cpu' else sizes
 
 for _ in range(10): # Warmup.
     torch.randn(100, 100, device=args.device).sum()

cpu/index_info.h

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+#pragma once
+
+#include <torch/extension.h>
+
+#include "compat.h"
+
+#define MAX_TENSORINFO_DIMS 25
+
+template <typename scalar_t> struct TensorInfo {
+  TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS],
+             int st[MAX_TENSORINFO_DIMS]) {
+    data = p;
+    dims = dim;
+    AT_ASSERT(dims < MAX_TENSORINFO_DIMS);
+
+    for (int i = 0; i < dim; ++i) {
+      sizes[i] = sz[i];
+      strides[i] = st[i];
+    }
+  }
+
+  scalar_t *data;
+  int dims;
+  int sizes[MAX_TENSORINFO_DIMS];
+  int strides[MAX_TENSORINFO_DIMS];
+};
+
+template <typename scalar_t>
+TensorInfo<scalar_t> getTensorInfo(const at::Tensor &tensor) {
+  int sizes[MAX_TENSORINFO_DIMS];
+  int strides[MAX_TENSORINFO_DIMS];
+
+  int dims = tensor.dim();
+  for (int i = 0; i < dims; ++i) {
+    sizes[i] = tensor.size(i);
+    strides[i] = tensor.stride(i);
+  }
+
+  return TensorInfo<scalar_t>(tensor.DATA_PTR<scalar_t>(), dims, sizes,
+                              strides);
+}
+
+template <typename scalar_t> struct IndexToOffset {
+  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
+    int offset = 0;
+    for (int i = info.dims - 1; i >= 0; --i) {
+      offset += (idx % info.sizes[i]) * info.strides[i];
+      idx /= info.sizes[i];
+    }
+    return offset;
+  }
+};
+
+template <typename scalar_t> struct IndexPtrToOffset {
+  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
+    int offset = idx % (info.sizes[info.dims - 1] - 1);
+    offset *= info.strides[info.dims - 1];
+    idx /= info.sizes[info.dims - 1] - 1;
+    for (int i = info.dims - 2; i >= 0; --i) {
+      offset += (idx % info.sizes[i]) * info.strides[i];
+      idx /= info.sizes[i];
+    }
+    return offset;
+  }
+};
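Both helpers above turn a flat element id into a storage offset from a tensor's sizes and strides, so the CPU kernels can walk non-contiguous index/indptr tensors; IndexPtrToOffset additionally accounts for the last dimension of indptr holding one more boundary than there are segments. A minimal Python sketch of the same arithmetic (illustration only, not part of this commit):

def index_to_offset(idx, sizes, strides):
    # Mirrors IndexToOffset::get: unravel idx over all dims, last dim fastest.
    offset = 0
    for size, stride in zip(reversed(sizes), reversed(strides)):
        offset += (idx % size) * stride
        idx //= size
    return offset

def indptr_to_offset(idx, sizes, strides):
    # Mirrors IndexPtrToOffset::get: the last dim stores N + 1 boundaries for
    # N segments, so segment n within a row maps to boundary n of that row.
    offset = (idx % (sizes[-1] - 1)) * strides[-1]
    idx //= sizes[-1] - 1
    for size, stride in zip(reversed(sizes[:-1]), reversed(strides[:-1])):
        offset += (idx % size) * stride
        idx //= size
    return offset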

cpu/segment.cpp

Lines changed: 216 additions & 4 deletions
@@ -1,16 +1,152 @@
 #include <torch/extension.h>
 
+#include "compat.h"
+#include "index_info.h"
+
 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
 
+enum ReductionType { ADD, MEAN, MIN, MAX };
+
+#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)  \
+  [&] {                                           \
+    if (reduce == "add") {                        \
+      const ReductionType REDUCE = ADD;           \
+      return __VA_ARGS__();                       \
+    } else if (reduce == "mean") {                \
+      const ReductionType REDUCE = MEAN;          \
+      return __VA_ARGS__();                       \
+    } else if (reduce == "min") {                 \
+      const ReductionType REDUCE = MIN;           \
+      return __VA_ARGS__();                       \
+    } else if (reduce == "max") {                 \
+      const ReductionType REDUCE = MAX;           \
+      return __VA_ARGS__();                       \
+    }                                             \
+  }()
+
+template <typename scalar_t, ReductionType REDUCE> struct Reducer {
+  static inline scalar_t init() {
+    if (REDUCE == MIN) {
+      return std::numeric_limits<scalar_t>::max();
+    } else if (REDUCE == MAX) {
+      return std::numeric_limits<scalar_t>::lowest();
+    } else {
+      return (scalar_t)0;
+    }
+  }
+
+  static inline void update(scalar_t *val, scalar_t new_val) {
+    if (REDUCE == ADD || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
+               (REDUCE == MAX && new_val > *val)) {
+      *val = new_val;
+    }
+  }
+
+  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
+                            int64_t new_arg) {
+    if (REDUCE == ADD || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
+               (REDUCE == MAX && new_val > *val)) {
+      *val = new_val;
+      *arg = new_arg;
+    }
+  }
+
+  static inline void write(scalar_t *address, scalar_t val,
+                           int64_t *arg_address, int64_t arg, int count) {
+    if (REDUCE == ADD) {
+      *address = val;
+    } else if (REDUCE == MEAN) {
+      *address = val / (count > 0 ? count : (scalar_t)1);
+    } else if (REDUCE == MIN || REDUCE == MAX) {
+      if (count > 0) {
+        *address = val;
+        *arg_address = arg;
+      } else {
+        *address = (scalar_t)0;
+      }
+    }
+  }
+};
+
 std::tuple<at::Tensor, at::optional<at::Tensor>>
 segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
             std::string reduce) {
   CHECK_CPU(src);
   CHECK_CPU(indptr);
   if (out_opt.has_value())
     CHECK_CPU(out_opt.value());
-  AT_ASSERTM(false, "Not yet implemented");
-  return std::make_tuple(src, at::nullopt);
+
+  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
+
+  // Broadcasting `indptr` via `expand`.
+  auto sizes = indptr.sizes().vec();
+  for (int i = 0; i < indptr.dim() - 1; i++) {
+    sizes[i] = src.size(i);
+  }
+  indptr = indptr.expand(sizes);
+
+  src = src.contiguous();
+  auto reduce_dim = indptr.dim() - 1;
+
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value().contiguous();
+    for (int i = 0; i < out.dim(); i++)
+      if (i != reduce_dim)
+        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
+    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
+               "Input mismatch");
+  } else {
+    sizes = src.sizes().vec();
+    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
+    out = at::empty(sizes, src.options());
+  }
+
+  at::optional<at::Tensor> arg_out = at::nullopt;
+  int64_t *arg_out_data = nullptr;
+  if (reduce == "min" || reduce == "max") {
+    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
+    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+  }
+
+  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
+  auto K = out.numel() / N;
+  auto E = src.size(reduce_dim);
+
+  auto indptr_info = getTensorInfo<int64_t>(indptr);
+  auto stride = indptr_info.strides[indptr_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    scalar_t val;
+    int64_t row_start, row_end, arg;
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      for (int n = 0; n < N; n++) {
+        int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
+        row_start = indptr_info.data[offset];
+        row_end = indptr_info.data[offset + stride];
+
+        offset = (n / (indptr.size(-1) - 1)) * E * K;
+        for (int k = 0; k < K; k++) {
+          val = Reducer<scalar_t, REDUCE>::init();
+          for (int64_t e = row_start; e < row_end; e++) {
+            Reducer<scalar_t, REDUCE>::update(
+                &val, src_data[offset + e * K + k], &arg, e);
+          }
+          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, val,
+                                           arg_out_data + n * K + k, arg,
+                                           row_end - row_start);
+        }
+      }
+    });
+  });
+
+  return std::make_tuple(out, arg_out);
 }
 
 std::tuple<at::Tensor, at::optional<at::Tensor>>
@@ -19,8 +155,84 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
   CHECK_CPU(src);
   CHECK_CPU(index);
   CHECK_CPU(out);
-  AT_ASSERTM(false, "Not yet implemented");
-  return std::make_tuple(src, at::nullopt);
+
+  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
+
+  // Broadcasting `index` via `expand`.
+  auto sizes = index.sizes().vec();
+  for (int i = 0; i < index.dim(); i++) {
+    sizes[i] = src.size(i);
+  }
+  index = index.expand(sizes);
+
+  src = src.contiguous();
+  out = out.contiguous();
+  auto reduce_dim = index.dim() - 1;
+
+  for (int i = 0; i < out.dim(); i++)
+    if (i != reduce_dim)
+      AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
+
+  at::optional<at::Tensor> arg_out = at::nullopt;
+  int64_t *arg_out_data = nullptr;
+  if (reduce == "min" || reduce == "max") {
+    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+  }
+
+  auto E_1 = index.numel() / src.size(reduce_dim);
+  auto E_2 = src.size(reduce_dim);
+  auto K = src.numel() / index.numel();
+  auto N = out.size(reduce_dim);
+
+  auto index_info = getTensorInfo<int64_t>(index);
+  auto stride = index_info.strides[index_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    scalar_t val;
+    int64_t idx, next_idx, row_start, arg;
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      for (int e_1 = 0; e_1 < E_1; e_1++) {
+        int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
+
+        for (int k = 0; k < K; k++) {
+          idx = index_info.data[offset];
+          row_start = 0;
+          val = out_data[e_1 * N * K + k];
+
+          for (int e_2 = 0; e_2 < E_2; e_2++) {
+            Reducer<scalar_t, REDUCE>::update(
+                &val, src_data[e_1 * E_2 * K + e_2 * K + k], &arg, e_2);
+
+            if (e_2 == E_2 - 1) {
+              Reducer<scalar_t, REDUCE>::write(
+                  out_data + e_1 * N * K + idx * K + k, val,
+                  arg_out_data + e_1 * N * K + idx * K + k, arg,
+                  e_2 + 1 - row_start);
+            } else {
+              next_idx = index_info.data[offset + (e_2 + 1) * stride];
+
+              if (idx != next_idx) {
+                Reducer<scalar_t, REDUCE>::write(
+                    out_data + e_1 * N * K + idx * K + k, val,
+                    arg_out_data + e_1 * N * K + idx * K + k, arg,
+                    e_2 + 1 - row_start);
+
+                row_start = e_2 + 1;
+                val = out_data[e_1 * N * K + next_idx * K + k];
+              }
+
+              idx = next_idx;
+            }
+          }
+        }
+      }
+    });
+  });
+
+  return std::make_tuple(out, arg_out);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
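Usage sketch (illustration only, not part of this commit): with the CPU kernels in place, segment_csr should be callable from Python the way test/test_segment.py calls it, i.e. segment_csr(src, indptr, out, reduce); the values below are made up, and the sketch assumes the wrapper returns the reduced tensor for reduce='add'.

import torch
from torch_scatter import segment_csr

src = torch.tensor([1., 2., 3., 4., 5., 6.])
indptr = torch.tensor([0, 2, 5, 6])

# Reduces src[indptr[i]:indptr[i + 1]] into out[i] along the last indptr dim.
out = segment_csr(src, indptr, None, 'add')
print(out)  # expected: tensor([ 3., 12.,  6.])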

cuda/segment_kernel.cu

Lines changed: 2 additions & 2 deletions
@@ -178,7 +178,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
 
   AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
 
-  // Broadcasting across `index` via `expand`.
+  // Broadcasting `indptr` via `expand`.
   auto sizes = indptr.sizes().vec();
   for (int i = 0; i < indptr.dim() - 1; i++) {
     sizes[i] = src.size(i);
@@ -379,7 +379,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
 
   AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
 
-  // Broadcasting across `index` via `expand`.
+  // Broadcasting `index` via `expand`.
   auto sizes = index.sizes().vec();
   for (int i = 0; i < index.dim(); i++) {
     sizes[i] = src.size(i);

test/test_segment.py

Lines changed: 1 addition & 4 deletions
@@ -10,7 +10,7 @@
 reductions = ['add', 'mean', 'min', 'max']
 grad_reductions = ['add', 'mean']
 
-devices = [torch.device('cuda')]
+devices = [torch.device('cpu')]
 
 tests = [
     {
@@ -82,7 +82,6 @@
 ]
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,reduce,dtype,device',
                          product(tests, reductions, dtypes, devices))
 def test_forward(test, reduce, dtype, device):
@@ -119,7 +118,6 @@ def test_backward(test, reduce, device):
     assert gradcheck(segment_csr, (src, indptr, None, reduce)) is True
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,reduce,dtype,device',
                          product(tests, reductions, dtypes, devices))
 def test_segment_out(test, reduce, dtype, device):
@@ -153,7 +151,6 @@ def test_segment_out(test, reduce, dtype, device):
     assert torch.all(out == expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,reduce,dtype,device',
                          product(tests, reductions, dtypes, devices))
 def test_non_contiguous_segment(test, reduce, dtype, device):
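For reference, the reduction these tests now exercise on the CPU matches, for reduce='add', a plain scatter-add over a sorted index. A dense check of segment_coo's 'add' semantics could look like the following sketch (made-up values, illustration only, not part of this commit):

import torch

src = torch.tensor([1., 2., 3., 4., 5.])
index = torch.tensor([0, 0, 1, 1, 2])  # sorted, as segment_coo expects

# Values sharing an index are accumulated into the same output slot.
expected = torch.zeros(3).scatter_add_(0, index, src)
print(expected)  # tensor([3., 7., 5.])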
