Skip to content

Commit 3994f3a

Browse files
committed
faster segment coo cpu implementation
1 parent 4a5379c commit 3994f3a

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

cpu/segment.cpp

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
180180
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
181181
}
182182

183+
auto E = index.numel();
183184
auto E_1 = index.numel() / src.size(reduce_dim);
184185
auto E_2 = src.size(reduce_dim);
185186
auto K = src.numel() / index.numel();
@@ -191,41 +192,48 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
191192
auto src_data = src.DATA_PTR<scalar_t>();
192193
auto out_data = out.DATA_PTR<scalar_t>();
193194

194-
scalar_t val;
195-
int64_t idx, next_idx, row_start, arg;
195+
scalar_t vals[K];
196+
int64_t idx, next_idx, row_start, args[K];
196197
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
197198
for (int e_1 = 0; e_1 < E_1; e_1++) {
198199
int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
200+
idx = index_info.data[offset];
201+
row_start = 0;
199202

200203
for (int k = 0; k < K; k++) {
201-
idx = index_info.data[offset];
202-
row_start = 0;
203-
val = out_data[e_1 * N * K + k];
204+
vals[k] = out_data[e_1 * N * K + k];
205+
}
204206

205-
for (int e_2 = 0; e_2 < E_2; e_2++) {
207+
for (int e_2 = 0; e_2 < E_2; e_2++) {
208+
209+
for (int k = 0; k < K; k++) {
206210
Reducer<scalar_t, REDUCE>::update(
207-
&val, src_data[e_1 * E_2 * K + e_2 * K + k], &arg, e_2);
211+
&vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
212+
}
208213

209-
if (e_2 == E_2 - 1) {
214+
if (e_2 == E_2 - 1) {
215+
for (int k = 0; k < K; k++) {
210216
Reducer<scalar_t, REDUCE>::write(
211-
out_data + e_1 * N * K + idx * K + k, val,
212-
arg_out_data + e_1 * N * K + idx * K + k, arg,
217+
out_data + e_1 * N * K + idx * K + k, vals[k],
218+
arg_out_data + e_1 * N * K + idx * K + k, args[k],
213219
e_2 + 1 - row_start);
214-
} else {
215-
next_idx = index_info.data[offset + (e_2 + 1) * stride];
220+
}
221+
} else {
222+
next_idx = index_info.data[offset + (e_2 + 1) * stride];
216223

217-
if (idx != next_idx) {
224+
if (idx != next_idx) {
225+
for (int k = 0; k < K; k++) {
218226
Reducer<scalar_t, REDUCE>::write(
219-
out_data + e_1 * N * K + idx * K + k, val,
220-
arg_out_data + e_1 * N * K + idx * K + k, arg,
227+
out_data + e_1 * N * K + idx * K + k, vals[k],
228+
arg_out_data + e_1 * N * K + idx * K + k, args[k],
221229
e_2 + 1 - row_start);
222230

223-
row_start = e_2 + 1;
224-
val = out_data[e_1 * N * K + next_idx * K + k];
231+
vals[k] = out_data[e_1 * N * K + next_idx * K + k];
225232
}
226-
227-
idx = next_idx;
233+
row_start = e_2 + 1;
228234
}
235+
236+
idx = next_idx;
229237
}
230238
}
231239
}

0 commit comments

Comments
 (0)