@@ -123,8 +123,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
123123 auto src_data = src.DATA_PTR <scalar_t >();
124124 auto out_data = out.DATA_PTR <scalar_t >();
125125
126- scalar_t val ;
127- int64_t row_start, row_end, arg ;
126+ scalar_t vals[K] ;
127+ int64_t row_start, row_end, args[K] ;
128128 AT_DISPATCH_REDUCTION_TYPES (reduce, [&] {
129129 for (int n = 0 ; n < N; n++) {
130130 int offset = IndexPtrToOffset<int64_t >::get (n, indptr_info);
@@ -133,13 +133,17 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
133133
134134 offset = (n / (indptr.size (-1 ) - 1 )) * E * K;
135135 for (int k = 0 ; k < K; k++) {
136- val = Reducer<scalar_t , REDUCE>::init ();
137- for (int64_t e = row_start; e < row_end; e++) {
136+ vals[k] = Reducer<scalar_t , REDUCE>::init ();
137+ }
138+ for (int64_t e = row_start; e < row_end; e++) {
139+ for (int k = 0 ; k < K; k++) {
138140 Reducer<scalar_t , REDUCE>::update (
139- &val , src_data[offset + e * K + k], &arg , e);
141+ &vals[k] , src_data[offset + e * K + k], &args[k] , e);
140142 }
141- Reducer<scalar_t , REDUCE>::write (out_data + n * K + k, val,
142- arg_out_data + n * K + k, arg,
143+ }
144+ for (int k = 0 ; k < K; k++) {
145+ Reducer<scalar_t , REDUCE>::write (out_data + n * K + k, vals[k],
146+ arg_out_data + n * K + k, args[k],
143147 row_end - row_start);
144148 }
145149 }
0 commit comments