Skip to content

Commit d3aabdf

Browse files
committed
fix negative dim in scatter_mean
1 parent ff3be8e commit d3aabdf

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

csrc/scatter.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class ScatterMean : public torch::autograd::Function<ScatterMean> {
7171
Variable index, int64_t dim,
7272
torch::optional<Variable> optional_out,
7373
torch::optional<int64_t> dim_size) {
74+
dim = dim < 0 ? src.dim() + dim : dim;
7475
ctx->saved_data["dim"] = dim;
7576
ctx->saved_data["src_shape"] = src.sizes();
7677

test/test_broadcasting.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
1+
from itertools import product
2+
13
import pytest
24
import torch
3-
from torch_scatter import scatter_add
5+
from torch_scatter import scatter
46

5-
from .utils import devices
7+
from .utils import reductions, devices
68

79

8-
@pytest.mark.parametrize('device', devices)
9-
def test_broadcasting(device):
10+
@pytest.mark.parametrize('reduce,device', product(reductions, devices))
11+
def test_broadcasting(reduce, device):
1012
B, C, H, W = (4, 3, 8, 8)
1113

14+
src = torch.randn((B, C, H, W), device=device)
15+
index = torch.randint(0, H, (H, )).to(device, torch.long)
16+
out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
17+
assert out.size() == (B, C, H, W)
18+
1219
src = torch.randn((B, C, H, W), device=device)
1320
index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
14-
out = scatter_add(src, index, dim=2, dim_size=H)
21+
out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
1522
assert out.size() == (B, C, H, W)
1623

1724
src = torch.randn((B, C, H, W), device=device)
1825
index = torch.randint(0, H, (H, )).to(device, torch.long)
19-
out = scatter_add(src, index, dim=2, dim_size=H)
26+
out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
2027
assert out.size() == (B, C, H, W)

0 commit comments

Comments
 (0)