Commit 3cf59da

add to sum, REDUCE to template
Parent: 7aa701b

6 files changed: +88, -83 lines
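
In short, the commit does two things: the canonical reduction name becomes "sum" (with "add" kept as an alias via a small reduce2REDUCE lookup map), and the reduction type moves from a runtime argument into a template parameter of the Reducer helper used by the segment kernels. A minimal before/after sketch of that second change, with simplified bodies and hypothetical OldReducer/NewReducer names (not the library source):

#include <limits>

enum ReductionType { SUM, MEAN, MIN, MAX };

// Before: the reduction is a runtime argument, so every call branches on it.
template <typename scalar_t> struct OldReducer {
  static inline scalar_t init(ReductionType REDUCE) {
    if (REDUCE == MIN)
      return std::numeric_limits<scalar_t>::max();
    else if (REDUCE == MAX)
      return std::numeric_limits<scalar_t>::lowest();
    else
      return (scalar_t)0;
  }
};

// After: the reduction is a template parameter, so each instantiation is
// specialized at compile time and the branches below can be folded away.
template <typename scalar_t, ReductionType REDUCE> struct NewReducer {
  static inline scalar_t init() {
    if (REDUCE == MIN)
      return std::numeric_limits<scalar_t>::max();
    else if (REDUCE == MAX)
      return std::numeric_limits<scalar_t>::lowest();
    else
      return (scalar_t)0;
  }
};

int main() {
  // Same value either way; only the dispatch point differs.
  return OldReducer<int>::init(SUM) + NewReducer<int, SUM>::init(); // 0 + 0
}

The remaining hunks are the mechanical consequences: call sites gain the extra template argument, string comparisons against "add" become lookups in reduce2REDUCE, and the Python-facing names switch from 'add' to 'sum'.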

benchmark/scatter_segment.py

Lines changed: 7 additions & 10 deletions

@@ -122,11 +122,11 @@ def timing(dataset):
     avg_row_len = row.size(0) / dim_size

     def sca_row(x):
-        op = getattr(torch_scatter, f'scatter_{args.reduce}')
+        op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
         return op(x, row, dim=0, dim_size=dim_size)

     def sca_col(x):
-        op = getattr(torch_scatter, f'scatter_{args.reduce}')
+        op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
         return op(x, row_perm, dim=0, dim_size=dim_size)

     def seg_coo(x):
@@ -136,10 +136,10 @@ def seg_csr(x):
         return segment_csr(x, rowptr, reduce=args.reduce)

     def dense1(x):
-        return getattr(torch, args.dense_reduce)(x, dim=-2)
+        return getattr(torch, args.reduce)(x, dim=-2)

     def dense2(x):
-        return getattr(torch, args.dense_reduce)(x, dim=-1)
+        return getattr(torch, args.reduce)(x, dim=-1)

     t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []

@@ -204,15 +204,12 @@ def dense2(x):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--reduce',
-        type=str,
-        required=True,
-        choices=['add', 'mean', 'min', 'max'])
+    parser.add_argument('--reduce', type=str, required=True,
+                        choices=['sum', 'mean', 'min', 'max'])
     parser.add_argument('--with_backward', action='store_true')
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
-    args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
+    args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
     iters = 1 if args.device == 'cpu' else 20
     sizes = [1, 16, 32, 64, 128, 256, 512]
     sizes = sizes[:3] if args.device == 'cpu' else sizes

cpu/segment.cpp

Lines changed: 32 additions & 33 deletions

@@ -7,28 +7,36 @@

 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

-enum ReductionType { ADD, MEAN, MIN, MAX };
+enum ReductionType { SUM, MEAN, MIN, MAX };
+
+const std::map<std::string, ReductionType> reduce2REDUCE = {
+    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
+};

 #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                    \
   [&] {                                                             \
-    ReductionType REDUCE = ADD;                                     \
-    if (reduce == "add") {                                          \
-      REDUCE = ADD;                                                 \
+    switch (reduce2REDUCE.at(reduce)) {                             \
+    case SUM: {                                                     \
+      const ReductionType REDUCE = SUM;                             \
       return __VA_ARGS__();                                         \
-    } else if (reduce == "mean") {                                  \
-      REDUCE = MEAN;                                                \
+    }                                                               \
+    case MEAN: {                                                    \
+      const ReductionType REDUCE = MEAN;                            \
       return __VA_ARGS__();                                         \
-    } else if (reduce == "min") {                                   \
-      REDUCE = MIN;                                                 \
+    }                                                               \
+    case MIN: {                                                     \
+      const ReductionType REDUCE = MIN;                             \
       return __VA_ARGS__();                                         \
-    } else if (reduce == "max") {                                   \
-      REDUCE = MAX;                                                 \
+    }                                                               \
+    case MAX: {                                                     \
+      const ReductionType REDUCE = MAX;                             \
       return __VA_ARGS__();                                         \
     }                                                               \
+    }                                                               \
   }()

-template <typename scalar_t> struct Reducer {
-  static inline scalar_t init(ReductionType REDUCE) {
+template <typename scalar_t, ReductionType REDUCE> struct Reducer {
+  static inline scalar_t init() {
     if (REDUCE == MIN) {
       return std::numeric_limits<scalar_t>::max();
     } else if (REDUCE == MAX) {
@@ -38,18 +46,9 @@ template <typename scalar_t> struct Reducer {
     }
   }

-  static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) {
-    if (REDUCE == ADD || REDUCE == MEAN) {
-      *val = *val + new_val;
-    } else if ((REDUCE == MIN && new_val < *val) ||
-               (REDUCE == MAX && new_val > *val)) {
-      *val = new_val;
-    }
-  }
-
-  static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg,
+  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                             int64_t new_arg) {
-    if (REDUCE == ADD || REDUCE == MEAN) {
+    if (REDUCE == SUM || REDUCE == MEAN) {
       *val = *val + new_val;
     } else if ((REDUCE == MIN && new_val < *val) ||
                (REDUCE == MAX && new_val > *val)) {
@@ -58,9 +57,9 @@ template <typename scalar_t> struct Reducer {
     }
   }

-  static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val,
+  static inline void write(scalar_t *address, scalar_t val,
                            int64_t *arg_address, int64_t arg, int count) {
-    if (REDUCE == ADD) {
+    if (REDUCE == SUM) {
      *address = val;
     } else if (REDUCE == MEAN) {
      *address = val / (count > 0 ? count : (scalar_t)1);
@@ -111,7 +110,7 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,

   at::optional<at::Tensor> arg_out = at::nullopt;
   int64_t *arg_out_data = nullptr;
-  if (reduce == "min" || reduce == "max") {
+  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
     arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -137,16 +136,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,

       offset = (n / (indptr.size(-1) - 1)) * E * K;
       for (int k = 0; k < K; k++) {
-        vals[k] = Reducer<scalar_t>::init(REDUCE);
+        vals[k] = Reducer<scalar_t, REDUCE>::init();
       }
       for (int64_t e = row_start; e < row_end; e++) {
         for (int k = 0; k < K; k++) {
-          Reducer<scalar_t>::update(REDUCE,
+          Reducer<scalar_t, REDUCE>::update(
               &vals[k], src_data[offset + e * K + k], &args[k], e);
         }
       }
       for (int k = 0; k < K; k++) {
-        Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
+        Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
                                  arg_out_data + n * K + k, args[k],
                                  row_end - row_start);
       }
@@ -183,7 +182,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,

   at::optional<at::Tensor> arg_out = at::nullopt;
   int64_t *arg_out_data = nullptr;
-  if (reduce == "min" || reduce == "max") {
+  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
     arg_out = at::full_like(out, src.size(reduce_dim), index.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -215,13 +214,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
       for (int e_2 = 0; e_2 < E_2; e_2++) {

         for (int k = 0; k < K; k++) {
-          Reducer<scalar_t>::update(REDUCE,
+          Reducer<scalar_t, REDUCE>::update(
               &vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
         }

         if (e_2 == E_2 - 1) {
           for (int k = 0; k < K; k++) {
-            Reducer<scalar_t>::write(REDUCE,
+            Reducer<scalar_t, REDUCE>::write(
                 out_data + e_1 * N * K + idx * K + k, vals[k],
                 arg_out_data + e_1 * N * K + idx * K + k, args[k],
                 e_2 + 1 - row_start);
@@ -232,7 +231,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,

       if (idx != next_idx) {
         for (int k = 0; k < K; k++) {
-          Reducer<scalar_t>::write(REDUCE,
+          Reducer<scalar_t, REDUCE>::write(
               out_data + e_1 * N * K + idx * K + k, vals[k],
               arg_out_data + e_1 * N * K + idx * K + k, args[k],
               e_2 + 1 - row_start);
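
Taken together, the hunks above replace the string-comparison chain with a switch over reduce2REDUCE and bind REDUCE as a compile-time constant inside each case, which is what lets call sites write Reducer<scalar_t, REDUCE>::update(...) without passing the reduction at runtime. A self-contained sketch of the same dispatch pattern outside the library (simplified Reducer, hypothetical DISPATCH_REDUCTION_TYPES name mirroring AT_DISPATCH_REDUCTION_TYPES):

#include <iostream>
#include <limits>
#include <map>
#include <string>

enum ReductionType { SUM, MEAN, MIN, MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};

// The string is resolved to a ReductionType once; each case then declares a
// constant REDUCE that the body in __VA_ARGS__ can use as a template argument
// (a const enum initialized from a constant is a constant expression).
#define DISPATCH_REDUCTION_TYPES(reduce, ...)                       \
  [&] {                                                             \
    switch (reduce2REDUCE.at(reduce)) {                             \
    case SUM: {                                                     \
      const ReductionType REDUCE = SUM;                             \
      return __VA_ARGS__();                                         \
    }                                                               \
    case MEAN: {                                                    \
      const ReductionType REDUCE = MEAN;                            \
      return __VA_ARGS__();                                         \
    }                                                               \
    case MIN: {                                                     \
      const ReductionType REDUCE = MIN;                             \
      return __VA_ARGS__();                                         \
    }                                                               \
    default: { /* MAX; keeps every path returning */                \
      const ReductionType REDUCE = MAX;                             \
      return __VA_ARGS__();                                         \
    }                                                               \
    }                                                               \
  }()

// Simplified Reducer: because REDUCE is a template parameter, the branches
// below are resolved per instantiation rather than on every call.
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline scalar_t init() {
    if (REDUCE == MIN)
      return std::numeric_limits<scalar_t>::max();
    if (REDUCE == MAX)
      return std::numeric_limits<scalar_t>::lowest();
    return (scalar_t)0;
  }
  static inline void update(scalar_t *val, scalar_t new_val) {
    if (REDUCE == SUM || REDUCE == MEAN)
      *val = *val + new_val;
    else if ((REDUCE == MIN && new_val < *val) ||
             (REDUCE == MAX && new_val > *val))
      *val = new_val;
  }
};

int main() {
  const float data[] = {1.0f, 5.0f, 3.0f};
  for (const std::string name : {"sum", "add", "max"}) {
    float result = DISPATCH_REDUCTION_TYPES(name, [&] {
      float val = Reducer<float, REDUCE>::init();
      for (float v : data)
        Reducer<float, REDUCE>::update(&val, v);
      return val;
    });
    std::cout << name << ": " << result << std::endl;  // sum: 9, add: 9, max: 5
  }
  return 0;
}

The trade-off is one extra Reducer instantiation per reduction and scalar type, in exchange for dropping the per-call REDUCE checks from the inner loops.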

cuda/segment_kernel.cu

Lines changed: 23 additions & 14 deletions

@@ -11,23 +11,32 @@
 #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
 #define FULL_MASK 0xffffffff

-enum ReductionType { ADD, MEAN, MIN, MAX };
+enum ReductionType { SUM, MEAN, MIN, MAX };
+
+const std::map<std::string, ReductionType> reduce2REDUCE = {
+    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
+};

 #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                    \
   [&] {                                                             \
-    if (reduce == "add") {                                          \
-      const ReductionType REDUCE = ADD;                             \
+    switch (reduce2REDUCE.at(reduce)) {                             \
+    case SUM: {                                                     \
+      const ReductionType REDUCE = SUM;                             \
       return __VA_ARGS__();                                         \
-    } else if (reduce == "mean") {                                  \
+    }                                                               \
+    case MEAN: {                                                    \
       const ReductionType REDUCE = MEAN;                            \
       return __VA_ARGS__();                                         \
-    } else if (reduce == "min") {                                   \
+    }                                                               \
+    case MIN: {                                                     \
       const ReductionType REDUCE = MIN;                             \
       return __VA_ARGS__();                                         \
-    } else if (reduce == "max") {                                   \
+    }                                                               \
+    case MAX: {                                                     \
       const ReductionType REDUCE = MAX;                             \
       return __VA_ARGS__();                                         \
     }                                                               \
+    }                                                               \
   }()

 template <typename scalar_t, ReductionType REDUCE> struct Reducer {
@@ -43,7 +52,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {

   static inline __host__ __device__ void update(scalar_t *val,
                                                 scalar_t new_val) {
-    if (REDUCE == ADD || REDUCE == MEAN) {
+    if (REDUCE == SUM || REDUCE == MEAN) {
       *val = *val + new_val;
     } else if ((REDUCE == MIN && new_val < *val) ||
                (REDUCE == MAX && new_val > *val)) {
@@ -53,7 +62,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {

   static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                 int64_t *arg, int64_t new_arg) {
-    if (REDUCE == ADD || REDUCE == MEAN) {
+    if (REDUCE == SUM || REDUCE == MEAN) {
       *val = *val + new_val;
     } else if ((REDUCE == MIN && new_val < *val) ||
                (REDUCE == MAX && new_val > *val)) {
@@ -65,7 +74,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
   static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
                                                int64_t *arg_address,
                                                int64_t arg, int count) {
-    if (REDUCE == ADD) {
+    if (REDUCE == SUM) {
       *address = val;
     } else if (REDUCE == MEAN) {
       *address = val / (scalar_t)max(count, 1);
@@ -80,7 +89,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
   }

   static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
-    if (REDUCE == ADD || REDUCE == MEAN) {
+    if (REDUCE == SUM || REDUCE == MEAN) {
       atomAdd(address, val);
     } else if (REDUCE == MIN && val < *address) {
       atomMin(address, val);
@@ -204,7 +213,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,

   at::optional<at::Tensor> arg_out = at::nullopt;
   int64_t *arg_out_data = nullptr;
-  if (reduce == "min" || reduce == "max") {
+  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
     arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -396,7 +405,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,

   at::optional<at::Tensor> arg_out = at::nullopt;
   int64_t *arg_out_data = nullptr;
-  if (reduce == "min" || reduce == "max") {
+  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
     arg_out = at::full_like(out, src.size(reduce_dim), index.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -455,14 +464,14 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     });
   });

-  if (reduce == "mean") {
+  if (reduce2REDUCE.at(reduce) == MEAN) {
     auto sizes = index.sizes().vec();
     sizes[reduce_dim] = out.size(reduce_dim);
     auto count = at::zeros(sizes, out.options());

     AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
       auto count_data = count.DATA_PTR<scalar_t>();
-      segment_coo_kernel<scalar_t, ADD, false>
+      segment_coo_kernel<scalar_t, SUM, false>
           <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                  count_data, E, N);
     });
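
One detail in the CUDA hunks worth calling out: for a mean reduction, the per-segment element count is obtained by reusing the SUM-specialized segment_coo_kernel with a nullptr src (so, presumably, each indexed element contributes one to count_data), and the accumulated sums are divided by that count afterwards (the division itself falls outside these hunks). A small CPU-side sketch of the same idea, using hypothetical helper names rather than the library kernels:

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical CPU analogue: reduce src into out by index; if src is null,
// every element contributes 1, which turns the SUM reduction into a counter.
static void segment_sum(const float *src, const std::vector<int> &index,
                        std::vector<float> &out) {
  for (size_t e = 0; e < index.size(); e++)
    out[index[e]] += (src != nullptr) ? src[e] : 1.0f;
}

int main() {
  std::vector<float> src = {1, 2, 3, 4, 5, 6};
  std::vector<int> index = {0, 0, 1, 1, 1, 3};

  std::vector<float> sum(4, 0.0f), count(4, 0.0f);
  segment_sum(src.data(), index, sum);  // per-segment sums
  segment_sum(nullptr, index, count);   // per-segment counts (SUM of ones)

  // mean = sum / max(count, 1), mirroring how the MEAN path divides afterwards.
  for (size_t i = 0; i < sum.size(); i++)
    std::printf("%zu: %g\n", i, sum[i] / std::max(count[i], 1.0f));
  return 0;
}

With the first test case from test/test_segment.py below (src = [1, 2, 3, 4, 5, 6], index = [0, 0, 1, 1, 1, 3]) this prints 1.5, 4, 0, 6, matching that case's expected 'mean' entry.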

test/test_segment.py

Lines changed: 9 additions & 9 deletions

@@ -7,15 +7,15 @@

 from .utils import tensor, dtypes, devices

-reductions = ['add', 'mean', 'min', 'max']
-grad_reductions = ['add', 'mean']
+reductions = ['sum', 'mean', 'min', 'max']
+grad_reductions = ['sum', 'mean']

 tests = [
     {
         'src': [1, 2, 3, 4, 5, 6],
         'index': [0, 0, 1, 1, 1, 3],
         'indptr': [0, 2, 5, 5, 6],
-        'add': [3, 12, 0, 6],
+        'sum': [3, 12, 0, 6],
         'mean': [1.5, 4, 0, 6],
         'min': [1, 3, 0, 6],
         'arg_min': [0, 2, 6, 5],
@@ -26,7 +26,7 @@
         'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
         'index': [0, 0, 1, 1, 1, 3],
         'indptr': [0, 2, 5, 5, 6],
-        'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
+        'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
         'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
         'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
         'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]],
@@ -37,7 +37,7 @@
         'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
         'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
         'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
-        'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
+        'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
         'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
         'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
         'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]],
@@ -48,7 +48,7 @@
         'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
         'index': [[0, 0, 1], [0, 2, 2]],
         'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
-        'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
+        'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
         'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
         'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
         'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]],
@@ -59,7 +59,7 @@
         'src': [[1, 3], [2, 4]],
         'index': [[0, 0], [0, 0]],
         'indptr': [[0, 2], [0, 2]],
-        'add': [[4], [6]],
+        'sum': [[4], [6]],
         'mean': [[2], [3]],
         'min': [[1], [2]],
         'arg_min': [[0], [0]],
@@ -70,7 +70,7 @@
         'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
         'index': [[0, 0], [0, 0]],
         'indptr': [[0, 2], [0, 2]],
-        'add': [[[4, 4]], [[6, 6]]],
+        'sum': [[[4, 4]], [[6, 6]]],
         'mean': [[[2, 2]], [[3, 3]]],
         'min': [[[1, 1]], [[2, 2]]],
         'arg_min': [[[0, 0]], [[0, 0]]],
@@ -134,7 +134,7 @@ def test_segment_out(test, reduce, dtype, device):

     segment_coo(src, index, out, reduce=reduce)

-    if reduce == 'add':
+    if reduce == 'sum':
         expected = expected - 2
     elif reduce == 'mean':
         expected = out  # We can not really test this here.

torch_scatter/gather.py

Lines changed: 2 additions & 2 deletions

@@ -31,7 +31,7 @@ def backward(ctx, grad_out):
         grad_src = None
         if ctx.needs_input_grad[0]:
             grad_src, _ = seg(grad_out.is_cuda).segment_coo(
-                grad_out, index, grad_out.new_zeros(src_size), 'add')
+                grad_out, index, grad_out.new_zeros(src_size), 'sum')

         return grad_src, None, None

@@ -53,7 +53,7 @@ def backward(ctx, grad_out):
         grad_src = None
         if ctx.needs_input_grad[0]:
             grad_src, _ = seg(grad_out.is_cuda).segment_csr(
-                grad_out, indptr, grad_out.new_empty(src_size), 'add')
+                grad_out, indptr, grad_out.new_empty(src_size), 'sum')

         return grad_src, None, None
