@@ -11,23 +11,24 @@ enum ReductionType { ADD, MEAN, MIN, MAX };
 
 #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
   [&] { \
+    ReductionType REDUCE = ADD; \
     if (reduce == "add") { \
-      const ReductionType REDUCE = ADD; \
+      REDUCE = ADD; \
       return __VA_ARGS__(); \
     } else if (reduce == "mean") { \
-      const ReductionType REDUCE = MEAN; \
+      REDUCE = MEAN; \
       return __VA_ARGS__(); \
     } else if (reduce == "min") { \
-      const ReductionType REDUCE = MIN; \
+      REDUCE = MIN; \
       return __VA_ARGS__(); \
     } else if (reduce == "max") { \
-      const ReductionType REDUCE = MAX; \
+      REDUCE = MAX; \
       return __VA_ARGS__(); \
     } \
   }()
 
-template <typename scalar_t, ReductionType REDUCE> struct Reducer {
-  static inline scalar_t init() {
+template <typename scalar_t> struct Reducer {
+  static inline scalar_t init(ReductionType REDUCE) {
     if (REDUCE == MIN) {
       return std::numeric_limits<scalar_t>::max();
     } else if (REDUCE == MAX) {
@@ -37,7 +38,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     }
   }
 
-  static inline void update(scalar_t *val, scalar_t new_val) {
+  static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) {
     if (REDUCE == ADD || REDUCE == MEAN) {
       *val = *val + new_val;
     } else if ((REDUCE == MIN && new_val < *val) ||
@@ -46,7 +47,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     }
   }
 
-  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
+  static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg,
                             int64_t new_arg) {
     if (REDUCE == ADD || REDUCE == MEAN) {
       *val = *val + new_val;
@@ -57,7 +58,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     }
   }
 
-  static inline void write(scalar_t *address, scalar_t val,
+  static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val,
                            int64_t *arg_address, int64_t arg, int count) {
     if (REDUCE == ADD) {
       *address = val;
@@ -136,16 +137,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
 
     offset = (n / (indptr.size(-1) - 1)) * E * K;
     for (int k = 0; k < K; k++) {
-      vals[k] = Reducer<scalar_t, REDUCE>::init();
+      vals[k] = Reducer<scalar_t>::init(REDUCE);
     }
     for (int64_t e = row_start; e < row_end; e++) {
       for (int k = 0; k < K; k++) {
-        Reducer<scalar_t, REDUCE>::update(
+        Reducer<scalar_t>::update(REDUCE,
             &vals[k], src_data[offset + e * K + k], &args[k], e);
       }
     }
     for (int k = 0; k < K; k++) {
-      Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
+      Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
                                        arg_out_data + n * K + k, args[k],
                                        row_end - row_start);
     }
@@ -214,13 +215,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     for (int e_2 = 0; e_2 < E_2; e_2++) {
 
       for (int k = 0; k < K; k++) {
-        Reducer<scalar_t, REDUCE>::update(
+        Reducer<scalar_t>::update(REDUCE,
             &vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
       }
 
       if (e_2 == E_2 - 1) {
         for (int k = 0; k < K; k++) {
-          Reducer<scalar_t, REDUCE>::write(
+          Reducer<scalar_t>::write(REDUCE,
               out_data + e_1 * N * K + idx * K + k, vals[k],
               arg_out_data + e_1 * N * K + idx * K + k, args[k],
               e_2 + 1 - row_start);
@@ -231,7 +232,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
 
       if (idx != next_idx) {
         for (int k = 0; k < K; k++) {
-          Reducer<scalar_t, REDUCE>::write(
+          Reducer<scalar_t>::write(REDUCE,
               out_data + e_1 * N * K + idx * K + k, vals[k],
               arg_out_data + e_1 * N * K + idx * K + k, args[k],
               e_2 + 1 - row_start);
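
Taken together, the hunks above turn `REDUCE` from a compile-time template parameter into a value chosen at runtime inside the dispatch macro, so `Reducer` only needs one instantiation per `scalar_t`. The following self-contained sketch is not code from this commit; it merely illustrates that pattern, and the init values for the ADD/MEAN and MAX branches are assumptions, since the diff only shows the MIN branch.

```cpp
// Standalone sketch of the runtime-dispatch pattern (assumption: ADD/MEAN
// start from 0 and MAX starts from the lowest representable value; the
// commit only shows the MIN branch of init()).
#include <initializer_list>
#include <iostream>
#include <limits>

enum ReductionType { ADD, MEAN, MIN, MAX };

template <typename scalar_t> struct Reducer {
  // REDUCE is an ordinary function argument instead of a template parameter.
  static inline scalar_t init(ReductionType REDUCE) {
    if (REDUCE == MIN)
      return std::numeric_limits<scalar_t>::max();
    else if (REDUCE == MAX)
      return std::numeric_limits<scalar_t>::lowest();
    else
      return static_cast<scalar_t>(0);
  }

  static inline void update(ReductionType REDUCE, scalar_t *val,
                            scalar_t new_val) {
    if (REDUCE == ADD || REDUCE == MEAN)
      *val = *val + new_val;
    else if ((REDUCE == MIN && new_val < *val) ||
             (REDUCE == MAX && new_val > *val))
      *val = new_val;
  }
};

int main() {
  const float data[] = {3.f, 1.f, 4.f};
  for (ReductionType reduce : {ADD, MIN, MAX}) {
    // A single Reducer<float> instantiation handles every reduction type.
    float acc = Reducer<float>::init(reduce);
    for (float v : data)
      Reducer<float>::update(reduce, &acc, v);
    std::cout << acc << '\n'; // prints 8, then 1, then 4
  }
  return 0;
}
```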