Skip to content

Commit 50a5dae

Browse files
committed
scatter min
1 parent d367c0b commit 50a5dae

File tree

16 files changed

+274
-899
lines changed

16 files changed

+274
-899
lines changed

test/backward.json

Lines changed: 0 additions & 56 deletions
This file was deleted.

test/forward.json

Lines changed: 0 additions & 118 deletions
This file was deleted.

test/test_backward.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch.autograd import gradcheck
66
import torch_scatter
77

8-
from .utils import devices
8+
from .utils import dtypes, devices, tensor
99

1010
funcs = ['add', 'sub', 'mul', 'div', 'mean']
1111
indices = [2, 0, 1, 1, 0]
@@ -20,3 +20,35 @@ def test_backward(func, device):
2020
op = getattr(torch_scatter, 'scatter_{}'.format(func))
2121
data = (src, index)
2222
assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True
23+
24+
25+
tests = [{
26+
'name': 'max',
27+
'src': [1, 2, 3, 4, 5],
28+
'index': [2, 0, 1, 1, 0],
29+
'dim': 0,
30+
'fill_value': 0,
31+
'grad': [4, 8, 6],
32+
'expected': [6, 0, 0, 8, 4]
33+
}, {
34+
'name': 'min',
35+
'src': [1, 2, 3, 4, 5],
36+
'index': [2, 0, 1, 1, 0],
37+
'dim': 0,
38+
'fill_value': 3,
39+
'grad': [4, 8, 6],
40+
'expected': [6, 4, 8, 0, 0]
41+
}]
42+
43+
44+
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
45+
def test_arg_backward(test, dtype, device):
46+
src = tensor(test['src'], dtype, device)
47+
src.requires_grad_()
48+
index = tensor(test['index'], torch.long, device)
49+
grad = tensor(test['grad'], dtype, device)
50+
51+
op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
52+
out, _ = op(src, index, test['dim'], fill_value=test['fill_value'])
53+
out.backward(grad)
54+
assert src.grad.tolist() == test['expected']

test/test_forward.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,38 @@
7676
'dim': 0,
7777
'fill_value': 0,
7878
'expected': [[3, 2.5], [3, 4]]
79+
}, {
80+
'name': 'max',
81+
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
82+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
83+
'dim': -1,
84+
'fill_value': 0,
85+
'expected': [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
86+
'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
87+
}, {
88+
'name': 'max',
89+
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
90+
'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
91+
'dim': 0,
92+
'fill_value': 0,
93+
'expected': [[5, 3], [4, 5]],
94+
'expected_arg': [[0, 3], [2, 1]]
95+
}, {
96+
'name': 'min',
97+
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
98+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
99+
'dim': -1,
100+
'fill_value': 9,
101+
'expected': [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
102+
'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]]
103+
}, {
104+
'name': 'min',
105+
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
106+
'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
107+
'dim': 0,
108+
'fill_value': 9,
109+
'expected': [[1, 2], [2, 3]],
110+
'expected_arg': [[3, 0], [1, 2]]
79111
}]
80112

81113

@@ -86,6 +118,10 @@ def test_forward(test, dtype, device):
86118
expected = tensor(test['expected'], dtype, device)
87119

88120
op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
89-
output = op(src, index, test['dim'], fill_value=test['fill_value'])
121+
out = op(src, index, test['dim'], fill_value=test['fill_value'])
90122

91-
assert output.tolist() == expected.tolist()
123+
if isinstance(out, tuple):
124+
assert out[0].tolist() == expected.tolist()
125+
assert out[1].tolist() == test['expected_arg']
126+
else:
127+
assert out.tolist() == expected.tolist()

torch_scatter/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from .mul import scatter_mul
44
from .div import scatter_div
55
from .mean import scatter_mean
6+
from .max import scatter_max
7+
from .min import scatter_min
68

79
__version__ = '1.0.0'
810

911
__all__ = [
1012
'scatter_add', 'scatter_sub', 'scatter_mul', 'scatter_div', 'scatter_mean',
11-
'__version__'
13+
'scatter_max', 'scatter_min', '__version__'
1214
]

torch_scatter/functions/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)