@@ -33,9 +33,9 @@ inline bool is_outer_reduction(const int64_t* strides) {
3333 strides[3 ] == sizeof (typename traits::arg2_t );
3434}
3535
36- template <typename func_t , typename vec_func_t >
36+ template <typename func_t , typename vec_func_t , bool reduce >
3737inline void vectorized_reduction (char ** data, int64_t n, int64_t stride,
38- func_t op, vec_func_t vop, bool reduce ) {
38+ func_t op [[maybe_unused]] , vec_func_t vop) {
3939 VEC_LOOP_HEADER (func_t , data)
4040 const char * in1_ptr = data[1 ];
4141 Vec acc[4 ];
@@ -49,7 +49,7 @@ inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
4949 acc[2 ] = vop (acc[2 ], Vec::loadu (ptr + (2 * Vec::size () * sizeof (scalar_t ))));
5050 acc[3 ] = vop (acc[3 ], Vec::loadu (ptr + (3 * Vec::size () * sizeof (scalar_t ))));
5151 }
52- if (reduce) {
52+ if constexpr (reduce) {
5353 scalar_t buffer[Vec::size ()];
5454 acc[0 ] = vop (vop (acc[0 ], acc[1 ]), vop (acc[2 ], acc[3 ]));
5555 acc[0 ].store (buffer);
@@ -83,7 +83,7 @@ inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_fu
8383 constexpr int64_t vector_stride = 4 * Vec::size () * sizeof (scalar_t );
8484 int64_t count = n / (4 * Vec::size ());
8585 if (count > 0 ) {
86- vectorized_reduction (data, count, vector_stride, op, vop, /* reduce= */ true );
86+ vectorized_reduction< func_t , vec_func_t , true > (data, count, vector_stride, op, vop);
8787 }
8888 char * ptrs[3 ] = { data[0 ], data[0 ], data[1 ] };
8989 int64_t strides[] = { 0 , 0 , sizeof (scalar_t ) };
@@ -99,7 +99,7 @@ inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_
9999 constexpr int64_t vector_stride = 4 * Vec::size () * sizeof (scalar_t );
100100 int64_t outer_stride[2 ] = { vector_stride, vector_stride };
101101 UNARY_OUTER_LOOP (data, outer_stride, size1 / (4 * Vec::size ()), [&] {
102- vectorized_reduction (data, size0, inner_stride, op, vop, /* reduce= */ false );
102+ vectorized_reduction< func_t , vec_func_t , false > (data, size0, inner_stride, op, vop);
103103 });
104104
105105 // reduce down the remaining columns
0 commit comments