@@ -14,34 +14,34 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
1414 [&] { \
1515 switch (reduce2REDUCE.at (reduce)) { \
1616 case SUM: { \
17- const ReductionType REDUCE = SUM; \
17+ static constexpr ReductionType REDUCE = SUM; \
1818 return __VA_ARGS__ (); \
1919 } \
2020 case MEAN: { \
21- const ReductionType REDUCE = MEAN; \
21+ static constexpr ReductionType REDUCE = MEAN; \
2222 return __VA_ARGS__ (); \
2323 } \
2424 case MUL: { \
25- const ReductionType REDUCE = MUL; \
25+ static constexpr ReductionType REDUCE = MUL; \
2626 return __VA_ARGS__ (); \
2727 } \
2828 case DIV: { \
29- const ReductionType REDUCE = DIV; \
29+ static constexpr ReductionType REDUCE = DIV; \
3030 return __VA_ARGS__ (); \
3131 } \
3232 case MIN: { \
33- const ReductionType REDUCE = MIN; \
33+ static constexpr ReductionType REDUCE = MIN; \
3434 return __VA_ARGS__ (); \
3535 } \
3636 case MAX: { \
37- const ReductionType REDUCE = MAX; \
37+ static constexpr ReductionType REDUCE = MAX; \
3838 return __VA_ARGS__ (); \
3939 } \
4040 } \
4141 }()
4242
43- template <typename scalar_t > struct Reducer {
44- static inline scalar_t init (ReductionType REDUCE ) {
43+ template <typename scalar_t , ReductionType REDUCE > struct Reducer {
44+ static inline scalar_t init () {
4545 if (REDUCE == MUL || REDUCE == DIV)
4646 return (scalar_t )1 ;
4747 else if (REDUCE == MIN)
@@ -52,8 +52,8 @@ template <typename scalar_t> struct Reducer {
5252 return (scalar_t )0 ;
5353 }
5454
55- static inline void update (ReductionType REDUCE , scalar_t *val ,
56- scalar_t new_val, int64_t *arg, int64_t new_arg) {
55+ static inline void update (scalar_t *val , scalar_t new_val, int64_t *arg ,
56+ int64_t new_arg) {
5757 if (REDUCE == SUM || REDUCE == MEAN)
5858 *val = *val + new_val;
5959 else if (REDUCE == MUL)
@@ -67,9 +67,8 @@ template <typename scalar_t> struct Reducer {
6767 }
6868 }
6969
70- static inline void write (ReductionType REDUCE, scalar_t *address,
71- scalar_t val, int64_t *arg_address, int64_t arg,
72- int count) {
70+ static inline void write (scalar_t *address, scalar_t val,
71+ int64_t *arg_address, int64_t arg, int count) {
7372 if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
7473 *address = val;
7574 else if (REDUCE == MEAN)
0 commit comments