Skip to content

Commit a7beaca

Browse files
committed
bugfix with tensor strides
1 parent aeca775 commit a7beaca

File tree

5 files changed

+103
-29
lines changed

5 files changed

+103
-29
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import build # noqa
66

7-
__version__ = '0.1.3'
7+
__version__ = '0.2.0'
88
url = 'https://github.com/rusty1s/pytorch_scatter'
99

1010
install_requires = ['cffi']

test/backward.json

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]],
99
"expected": [[50, 60, 50, 30, 40], [15, 15, 35, 35, 25]]
1010
},
11+
{
12+
"name": "add",
13+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
14+
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
15+
"dim": 0,
16+
"fill_value": 0,
17+
"grad": [[10, 20], [15, 25]],
18+
"expected": [[10, 20], [15, 25], [15, 25], [10, 20]]
19+
},
1120
{
1221
"name": "mean",
1322
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
@@ -17,6 +26,15 @@
1726
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]],
1827
"expected": [[50, 60, 50, 30, 40], [15, 15, 35, 35, 25]]
1928
},
29+
{
30+
"name": "mean",
31+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
32+
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
33+
"dim": 0,
34+
"fill_value": 0,
35+
"grad": [[10, 20], [15, 25]],
36+
"expected": [[10, 20], [15, 25], [15, 25], [10, 20]]
37+
},
2038
{
2139
"name": "max",
2240
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
@@ -25,5 +43,14 @@
2543
"fill_value": 0,
2644
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]],
2745
"expected": [[50, 60, 0, 30, 40], [0, 15, 0, 35, 25]]
46+
},
47+
{
48+
"name": "max",
49+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
50+
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
51+
"dim": 0,
52+
"fill_value": 0,
53+
"grad": [[10, 20], [15, 25]],
54+
"expected": [[10, 0], [0, 25], [15, 0], [0, 20]]
2855
}
2956
]

test/forward.json

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
"fill_value": 0,
88
"expected": [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
99
},
10+
{
11+
"name": "add",
12+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
13+
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
14+
"dim": 0,
15+
"fill_value": 0,
16+
"expected": [[6, 5], [6, 8]]
17+
},
1018
{
1119
"name": "sub",
1220
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
@@ -15,6 +23,14 @@
1523
"fill_value": 9,
1624
"expected": [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]]
1725
},
26+
{
27+
"name": "sub",
28+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
29+
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
30+
"dim": 0,
31+
"fill_value": 9,
32+
"expected": [[3, 4], [3, 1]]
33+
},
1834
{
1935
"name": "mul",
2036
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
@@ -23,6 +39,14 @@
2339
"fill_value": 1,
2440
"expected": [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]]
2541
},
42+
{
43+
"name": "mul",
44+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
45+
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
46+
"dim": 0,
47+
"fill_value": 1,
48+
"expected": [[5, 6], [8, 15]]
49+
},
2650
{
2751
"name": "div",
2852
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
@@ -32,12 +56,12 @@
3256
"expected": [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]]
3357
},
3458
{
35-
"name": "mean",
36-
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
37-
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
38-
"dim": 1,
39-
"fill_value": 0,
40-
"expected": [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]]
59+
"name": "div",
60+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
61+
"input": [[4, 2], [2, 1], [4, 2], [1, 2]],
62+
"dim": 0,
63+
"fill_value": 1,
64+
"expected": [[0.25, 0.25], [0.125, 0.5]]
4165
},
4266
{
4367
"name": "max",
@@ -48,6 +72,15 @@
4872
"expected": [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
4973
"expected_arg": [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
5074
},
75+
{
76+
"name": "max",
77+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
78+
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
79+
"dim": 0,
80+
"fill_value": 0,
81+
"expected": [[5, 3], [4, 5]],
82+
"expected_arg": [[0, 3], [2, 1]]
83+
},
5184
{
5285
"name": "min",
5386
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
@@ -56,5 +89,14 @@
5689
"fill_value": 9,
5790
"expected": [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
5891
"expected_arg": [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]]
92+
},
93+
{
94+
"name": "min",
95+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
96+
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
97+
"dim": 0,
98+
"fill_value": 9,
99+
"expected": [[1, 2], [2, 3]],
100+
"expected_arg": [[3, 0], [1, 2]]
59101
}
60102
]

torch_scatter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .functions.max import scatter_max_, scatter_max
77
from .functions.min import scatter_min_, scatter_min
88

9-
__version__ = '0.1.3'
9+
__version__ = '0.2.0'
1010

1111
__all__ = [
1212
'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',

torch_scatter/src/generic/cpu.c

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,62 +3,67 @@
33
#else
44

55
void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
6-
int64_t i;
6+
int64_t i, idx;
77
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
88
for (i = 0; i < THLongTensor_size(index, dim); i++) {
9-
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
10-
output_data[index_data[i]] *= input_data[i];
9+
idx = *(index_data + i * index_stride);
10+
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
11+
output_data[idx * output_stride] *= *(input_data + i * input_stride);
1112
})
1213
}
1314

1415
void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
15-
int64_t i;
16+
int64_t i, idx;
1617
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
1718
for (i = 0; i < THLongTensor_size(index, dim); i++) {
18-
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
19-
output_data[index_data[i]] /= input_data[i];
19+
idx = *(index_data + i * index_stride);
20+
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
21+
output_data[idx * output_stride] /= *(input_data + i * input_stride);
2022
})
2123
}
2224

2325
void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THTensor *count) {
24-
int64_t i;
26+
int64_t i, idx;
2527
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, count, dim,
2628
for (i = 0; i < THLongTensor_size(index, dim); i++) {
2729
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
28-
output_data[index_data[i]] += input_data[i];
29-
count_data[index_data[i]]++;
30+
output_data[idx * output_stride] += *(input_data + i * input_stride);
31+
output_data[idx * count_stride]++;
3032
})
3133
}
3234

3335
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
34-
int64_t i;
36+
int64_t i, idx;
3537
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
3638
for (i = 0; i < THLongTensor_size(index, dim); i++) {
37-
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
38-
if (input_data[i] >= output_data[index_data[i]]) {
39-
output_data[index_data[i]] = input_data[i];
40-
arg_data[index_data[i]] = i;
39+
idx = *(index_data + i * index_stride);
40+
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
41+
if (*(input_data + i * input_stride) >= *(output_data + idx * output_stride)) {
42+
output_data[idx * output_stride] = *(input_data + i * input_stride);
43+
arg_data[idx * arg_stride] = i;
4144
}
4245
})
4346
}
4447

4548
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
46-
int64_t i;
49+
int64_t i, idx;
4750
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
4851
for (i = 0; i < THLongTensor_size(index, dim); i++) {
49-
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
50-
if (input_data[i] <= output_data[index_data[i]]) {
51-
output_data[index_data[i]] = input_data[i];
52-
arg_data[index_data[i]] = i;
52+
idx = *(index_data + i * index_stride);
53+
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
54+
if (*(input_data + i * input_stride) <= *(output_data + idx * output_stride)) {
55+
output_data[idx * output_stride] = *(input_data + i * input_stride);
56+
arg_data[idx * arg_stride] = i;
5357
}
5458
})
5559
}
5660

5761
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *arg) {
58-
int64_t i;
62+
int64_t i, idx;
5963
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, arg, dim,
6064
for (i = 0; i < THLongTensor_size(index, dim); i++) {
61-
if (arg_data[index_data[i]] == i) output_data[i] = grad_data[index_data[i]];
65+
idx = *(index_data + i * index_stride);
66+
if (*(arg_data + idx * arg_stride) == i) output_data[i * output_stride] = *(grad_data + idx * grad_stride);
6267
})
6368
}
6469

0 commit comments

Comments
 (0)