Skip to content

Commit bb47653

Browse files
committed
added mean tests
1 parent a7beaca commit bb47653

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

test/forward.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,22 @@
6363
"fill_value": 1,
6464
"expected": [[0.25, 0.25], [0.125, 0.5]]
6565
},
66+
{
67+
"name": "mean",
68+
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
69+
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
70+
"dim": 1,
71+
"fill_value": 0,
72+
"expected": [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]]
73+
},
74+
{
75+
"name": "mean",
76+
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
77+
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
78+
"dim": 0,
79+
"fill_value": 0,
80+
"expected": [[3, 2.5], [3, 4]]
81+
},
6682
{
6783
"name": "max",
6884
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],

torch_scatter/src/generic/cpu.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *in
2626
int64_t i, idx;
2727
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, count, dim,
2828
for (i = 0; i < THLongTensor_size(index, dim); i++) {
29-
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
29+
idx = *(index_data + i * index_stride);
30+
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
3031
output_data[idx * output_stride] += *(input_data + i * input_stride);
31-
output_data[idx * count_stride]++;
32+
count_data[idx * count_stride]++;
3233
})
3334
}
3435

0 commit comments

Comments
 (0)