@@ -180,6 +180,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
180180 arg_out_data = arg_out.value ().DATA_PTR <int64_t >();
181181 }
182182
183+ auto E = index.numel ();
183184 auto E_1 = index.numel () / src.size (reduce_dim);
184185 auto E_2 = src.size (reduce_dim);
185186 auto K = src.numel () / index.numel ();
@@ -191,41 +192,48 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
191192 auto src_data = src.DATA_PTR <scalar_t >();
192193 auto out_data = out.DATA_PTR <scalar_t >();
193194
194- scalar_t val ;
195- int64_t idx, next_idx, row_start, arg ;
195+ scalar_t vals[K] ;
196+ int64_t idx, next_idx, row_start, args[K] ;
196197 AT_DISPATCH_REDUCTION_TYPES (reduce, [&] {
197198 for (int e_1 = 0 ; e_1 < E_1; e_1++) {
198199 int offset = IndexToOffset<int64_t >::get (e_1 * E_2, index_info);
200+ idx = index_info.data [offset];
201+ row_start = 0 ;
199202
200203 for (int k = 0 ; k < K; k++) {
201- idx = index_info.data [offset];
202- row_start = 0 ;
203- val = out_data[e_1 * N * K + k];
204+ vals[k] = out_data[e_1 * N * K + k];
205+ }
204206
205- for (int e_2 = 0 ; e_2 < E_2; e_2++) {
207+ for (int e_2 = 0 ; e_2 < E_2; e_2++) {
208+
209+ for (int k = 0 ; k < K; k++) {
206210 Reducer<scalar_t , REDUCE>::update (
207- &val, src_data[e_1 * E_2 * K + e_2 * K + k], &arg, e_2);
211+ &vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
212+ }
208213
209- if (e_2 == E_2 - 1 ) {
214+ if (e_2 == E_2 - 1 ) {
215+ for (int k = 0 ; k < K; k++) {
210216 Reducer<scalar_t , REDUCE>::write (
211- out_data + e_1 * N * K + idx * K + k, val ,
212- arg_out_data + e_1 * N * K + idx * K + k, arg ,
217+ out_data + e_1 * N * K + idx * K + k, vals[k] ,
218+ arg_out_data + e_1 * N * K + idx * K + k, args[k] ,
213219 e_2 + 1 - row_start);
214- } else {
215- next_idx = index_info.data [offset + (e_2 + 1 ) * stride];
220+ }
221+ } else {
222+ next_idx = index_info.data [offset + (e_2 + 1 ) * stride];
216223
217- if (idx != next_idx) {
224+ if (idx != next_idx) {
225+ for (int k = 0 ; k < K; k++) {
218226 Reducer<scalar_t , REDUCE>::write (
219- out_data + e_1 * N * K + idx * K + k, val ,
220- arg_out_data + e_1 * N * K + idx * K + k, arg ,
227+ out_data + e_1 * N * K + idx * K + k, vals[k] ,
228+ arg_out_data + e_1 * N * K + idx * K + k, args[k] ,
221229 e_2 + 1 - row_start);
222230
223- row_start = e_2 + 1 ;
224- val = out_data[e_1 * N * K + next_idx * K + k];
231+ vals[k] = out_data[e_1 * N * K + next_idx * K + k];
225232 }
226-
227- idx = next_idx;
233+ row_start = e_2 + 1 ;
228234 }
235+
236+ idx = next_idx;
229237 }
230238 }
231239 }
0 commit comments