Skip to content

Commit 4a1164b

Browse files
committed
template reducer
1 parent ea9d68b commit 4a1164b

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

csrc/cpu/reducer.h

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,34 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
1414
[&] { \
1515
switch (reduce2REDUCE.at(reduce)) { \
1616
case SUM: { \
17-
const ReductionType REDUCE = SUM; \
17+
static constexpr ReductionType REDUCE = SUM; \
1818
return __VA_ARGS__(); \
1919
} \
2020
case MEAN: { \
21-
const ReductionType REDUCE = MEAN; \
21+
static constexpr ReductionType REDUCE = MEAN; \
2222
return __VA_ARGS__(); \
2323
} \
2424
case MUL: { \
25-
const ReductionType REDUCE = MUL; \
25+
static constexpr ReductionType REDUCE = MUL; \
2626
return __VA_ARGS__(); \
2727
} \
2828
case DIV: { \
29-
const ReductionType REDUCE = DIV; \
29+
static constexpr ReductionType REDUCE = DIV; \
3030
return __VA_ARGS__(); \
3131
} \
3232
case MIN: { \
33-
const ReductionType REDUCE = MIN; \
33+
static constexpr ReductionType REDUCE = MIN; \
3434
return __VA_ARGS__(); \
3535
} \
3636
case MAX: { \
37-
const ReductionType REDUCE = MAX; \
37+
static constexpr ReductionType REDUCE = MAX; \
3838
return __VA_ARGS__(); \
3939
} \
4040
} \
4141
}()
4242

43-
template <typename scalar_t> struct Reducer {
44-
static inline scalar_t init(ReductionType REDUCE) {
43+
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
44+
static inline scalar_t init() {
4545
if (REDUCE == MUL || REDUCE == DIV)
4646
return (scalar_t)1;
4747
else if (REDUCE == MIN)
@@ -52,8 +52,8 @@ template <typename scalar_t> struct Reducer {
5252
return (scalar_t)0;
5353
}
5454

55-
static inline void update(ReductionType REDUCE, scalar_t *val,
56-
scalar_t new_val, int64_t *arg, int64_t new_arg) {
55+
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
56+
int64_t new_arg) {
5757
if (REDUCE == SUM || REDUCE == MEAN)
5858
*val = *val + new_val;
5959
else if (REDUCE == MUL)
@@ -67,9 +67,8 @@ template <typename scalar_t> struct Reducer {
6767
}
6868
}
6969

70-
static inline void write(ReductionType REDUCE, scalar_t *address,
71-
scalar_t val, int64_t *arg_address, int64_t arg,
72-
int count) {
70+
static inline void write(scalar_t *address, scalar_t val,
71+
int64_t *arg_address, int64_t arg, int count) {
7372
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
7473
*address = val;
7574
else if (REDUCE == MEAN)

csrc/cpu/spmm_cpu.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
6363
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
6464

6565
for (auto k = 0; k < K; k++)
66-
vals[k] = Reducer<scalar_t>::init(REDUCE);
66+
vals[k] = Reducer<scalar_t, REDUCE>::init();
6767

6868
auto offset = b * N * K;
6969
for (auto e = row_start; e < row_end; e++) {
@@ -72,20 +72,19 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
7272
val = value_data[e];
7373
for (auto k = 0; k < K; k++) {
7474
if (HAS_VALUE)
75-
Reducer<scalar_t>::update(REDUCE, &vals[k],
76-
val * mat_data[offset + c * K + k],
77-
&args[k], e);
75+
Reducer<scalar_t, REDUCE>::update(
76+
&vals[k], val * mat_data[offset + c * K + k], &args[k],
77+
e);
7878
else
79-
Reducer<scalar_t>::update(REDUCE, &vals[k],
80-
mat_data[offset + c * K + k],
81-
&args[k], e);
79+
Reducer<scalar_t, REDUCE>::update(
80+
&vals[k], mat_data[offset + c * K + k], &args[k], e);
8281
}
8382
}
8483
offset = b * M * K + m * K;
8584
for (auto k = 0; k < K; k++)
86-
Reducer<scalar_t>::write(REDUCE, out_data + offset + k, vals[k],
87-
arg_out_data + offset + k, args[k],
88-
row_end - row_start);
85+
Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
86+
arg_out_data + offset + k,
87+
args[k], row_end - row_start);
8988
}
9089
}
9190
});

0 commit comments

Comments
 (0)