Skip to content

Commit 6733f0e

Browse files
authored
Merge pull request #100 from Novare/fix_segreducer
fix: fix errors regarding Reducer functionalities in segment.cpp
2 parents d88411f + cd84568 commit 6733f0e

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

cpu/segment.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,24 @@ enum ReductionType { ADD, MEAN, MIN, MAX };
1111

1212
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
1313
[&] { \
14+
ReductionType REDUCE = ADD; \
1415
if (reduce == "add") { \
15-
const ReductionType REDUCE = ADD; \
16+
REDUCE = ADD; \
1617
return __VA_ARGS__(); \
1718
} else if (reduce == "mean") { \
18-
const ReductionType REDUCE = MEAN; \
19+
REDUCE = MEAN; \
1920
return __VA_ARGS__(); \
2021
} else if (reduce == "min") { \
21-
const ReductionType REDUCE = MIN; \
22+
REDUCE = MIN; \
2223
return __VA_ARGS__(); \
2324
} else if (reduce == "max") { \
24-
const ReductionType REDUCE = MAX; \
25+
REDUCE = MAX; \
2526
return __VA_ARGS__(); \
2627
} \
2728
}()
2829

29-
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
30-
static inline scalar_t init() {
30+
template <typename scalar_t> struct Reducer {
31+
static inline scalar_t init(ReductionType REDUCE) {
3132
if (REDUCE == MIN) {
3233
return std::numeric_limits<scalar_t>::max();
3334
} else if (REDUCE == MAX) {
@@ -37,7 +38,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
3738
}
3839
}
3940

40-
static inline void update(scalar_t *val, scalar_t new_val) {
41+
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) {
4142
if (REDUCE == ADD || REDUCE == MEAN) {
4243
*val = *val + new_val;
4344
} else if ((REDUCE == MIN && new_val < *val) ||
@@ -46,7 +47,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
4647
}
4748
}
4849

49-
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
50+
static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg,
5051
int64_t new_arg) {
5152
if (REDUCE == ADD || REDUCE == MEAN) {
5253
*val = *val + new_val;
@@ -57,7 +58,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
5758
}
5859
}
5960

60-
static inline void write(scalar_t *address, scalar_t val,
61+
static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val,
6162
int64_t *arg_address, int64_t arg, int count) {
6263
if (REDUCE == ADD) {
6364
*address = val;
@@ -136,16 +137,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
136137

137138
offset = (n / (indptr.size(-1) - 1)) * E * K;
138139
for (int k = 0; k < K; k++) {
139-
vals[k] = Reducer<scalar_t, REDUCE>::init();
140+
vals[k] = Reducer<scalar_t>::init(REDUCE);
140141
}
141142
for (int64_t e = row_start; e < row_end; e++) {
142143
for (int k = 0; k < K; k++) {
143-
Reducer<scalar_t, REDUCE>::update(
144+
Reducer<scalar_t>::update(REDUCE,
144145
&vals[k], src_data[offset + e * K + k], &args[k], e);
145146
}
146147
}
147148
for (int k = 0; k < K; k++) {
148-
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
149+
Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
149150
arg_out_data + n * K + k, args[k],
150151
row_end - row_start);
151152
}
@@ -214,13 +215,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
214215
for (int e_2 = 0; e_2 < E_2; e_2++) {
215216

216217
for (int k = 0; k < K; k++) {
217-
Reducer<scalar_t, REDUCE>::update(
218+
Reducer<scalar_t>::update(REDUCE,
218219
&vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
219220
}
220221

221222
if (e_2 == E_2 - 1) {
222223
for (int k = 0; k < K; k++) {
223-
Reducer<scalar_t, REDUCE>::write(
224+
Reducer<scalar_t>::write(REDUCE,
224225
out_data + e_1 * N * K + idx * K + k, vals[k],
225226
arg_out_data + e_1 * N * K + idx * K + k, args[k],
226227
e_2 + 1 - row_start);
@@ -231,7 +232,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
231232

232233
if (idx != next_idx) {
233234
for (int k = 0; k < K; k++) {
234-
Reducer<scalar_t, REDUCE>::write(
235+
Reducer<scalar_t>::write(REDUCE,
235236
out_data + e_1 * N * K + idx * K + k, vals[k],
236237
arg_out_data + e_1 * N * K + idx * K + k, args[k],
237238
e_2 + 1 - row_start);

0 commit comments

Comments
 (0)