Skip to content

Commit ba26dfb

Browse files
committed
faster
1 parent bb47653 commit ba26dfb

File tree

1 file changed

+18
-12
lines changed
  • torch_scatter/src/generic

1 file changed

+18
-12
lines changed

torch_scatter/src/generic/cpu.c

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,32 @@
33
#else
44

55
void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
6-
int64_t i, idx;
6+
int64_t n, i, idx;
7+
n = THLongTensor_size(index, dim);
78
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
8-
for (i = 0; i < THLongTensor_size(index, dim); i++) {
9+
for (i = 0; i < n; i++) {
910
idx = *(index_data + i * index_stride);
1011
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
1112
output_data[idx * output_stride] *= *(input_data + i * input_stride);
1213
})
1314
}
1415

1516
void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
16-
int64_t i, idx;
17+
int64_t n, i, idx;
18+
n = THLongTensor_size(index, dim);
1719
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
18-
for (i = 0; i < THLongTensor_size(index, dim); i++) {
20+
for (i = 0; i < n; i++) {
1921
idx = *(index_data + i * index_stride);
2022
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
2123
output_data[idx * output_stride] /= *(input_data + i * input_stride);
2224
})
2325
}
2426

2527
void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THTensor *count) {
26-
int64_t i, idx;
28+
int64_t n, i, idx;
29+
n = THLongTensor_size(index, dim);
2730
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, count, dim,
28-
for (i = 0; i < THLongTensor_size(index, dim); i++) {
31+
for (i = 0; i < n; i++) {
2932
idx = *(index_data + i * index_stride);
3033
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
3134
output_data[idx * output_stride] += *(input_data + i * input_stride);
@@ -34,9 +37,10 @@ void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *in
3437
}
3538

3639
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
37-
int64_t i, idx;
40+
int64_t n, i, idx;
41+
n = THLongTensor_size(index, dim);
3842
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
39-
for (i = 0; i < THLongTensor_size(index, dim); i++) {
43+
for (i = 0; i < n; i++) {
4044
idx = *(index_data + i * index_stride);
4145
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
4246
if (*(input_data + i * input_stride) >= *(output_data + idx * output_stride)) {
@@ -47,9 +51,10 @@ void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
4751
}
4852

4953
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
50-
int64_t i, idx;
54+
int64_t n, i, idx;
55+
n = THLongTensor_size(index, dim);
5156
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
52-
for (i = 0; i < THLongTensor_size(index, dim); i++) {
57+
for (i = 0; i < n; i++) {
5358
idx = *(index_data + i * index_stride);
5459
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
5560
if (*(input_data + i * input_stride) <= *(output_data + idx * output_stride)) {
@@ -60,9 +65,10 @@ void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
6065
}
6166

6267
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *arg) {
63-
int64_t i, idx;
68+
int64_t n, i, idx;
69+
n = THLongTensor_size(index, dim);
6470
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, arg, dim,
65-
for (i = 0; i < THLongTensor_size(index, dim); i++) {
71+
for (i = 0; i < n; i++) {
6672
idx = *(index_data + i * index_stride);
6773
if (*(arg_data + idx * arg_stride) == i) output_data[i * output_stride] = *(grad_data + idx * grad_stride);
6874
})

0 commit comments

Comments
 (0)