Skip to content

Commit 8b8df25

Browse files
committed
new pytorch 0.4.0 format
1 parent 367b0af commit 8b8df25

File tree

8 files changed

+179
-233
lines changed

8 files changed

+179
-233
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from setuptools import setup, find_packages
44

5-
__version__ = '0.3.0'
5+
__version__ = '1.0.0'
66
url = 'https://github.com/rusty1s/pytorch_scatter'
77

88
install_requires = ['cffi']

test/test_backward.py

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,22 @@
1-
from os import path as osp
21
from itertools import product
32

43
import pytest
5-
import json
64
import torch
7-
from torch.autograd import Variable as V
5+
from torch.autograd import gradcheck
86
import torch_scatter
97

10-
from .utils import tensors, Tensor
8+
from .utils import devices
119

12-
f = open(osp.join(osp.dirname(__file__), 'backward.json'), 'r')
13-
data = json.load(f)
14-
f.close()
10+
funcs = ['add']
11+
indices = [2, 0, 1, 1, 0]
1512

1613

17-
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
18-
def test_backward_cpu(tensor, i):
19-
name = data[i]['name']
20-
index = V(torch.LongTensor(data[i]['index']))
21-
input = V(Tensor(tensor, data[i]['input']), requires_grad=True)
22-
dim = data[i]['dim']
23-
fill_value = data[i]['fill_value']
24-
grad = Tensor(tensor, data[i]['grad'])
25-
output = V(grad.new(grad.size()).fill_(fill_value))
26-
expected = Tensor(tensor, data[i]['expected'])
14+
@pytest.mark.parametrize('func,device', product(funcs, devices))
15+
def test_backward(func, device):
16+
index = torch.tensor(indices, dtype=torch.long, device=device)
17+
src = torch.rand(index.size(), dtype=torch.double, device=device)
18+
src.requires_grad_()
2719

28-
func = getattr(torch_scatter, 'scatter_{}_'.format(name))
29-
func(output, index, input, dim)
30-
output.backward(grad)
31-
assert input.grad.data.tolist() == expected.tolist()
32-
33-
34-
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
35-
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
36-
def test_backward_gpu(tensor, i): # pragma: no cover
37-
name = data[i]['name']
38-
index = V(torch.cuda.LongTensor(data[i]['index']))
39-
input = V(Tensor(tensor, data[i]['input']).cuda(), requires_grad=True)
40-
dim = data[i]['dim']
41-
fill_value = data[i]['fill_value']
42-
grad = Tensor(tensor, data[i]['grad']).cuda()
43-
output = V(grad.new(grad.size()).fill_(fill_value).cuda())
44-
expected = Tensor(tensor, data[i]['expected'])
45-
46-
func = getattr(torch_scatter, 'scatter_{}_'.format(name))
47-
func(output, index, input, dim)
48-
output.backward(grad)
49-
assert input.grad.data.cpu().tolist() == expected.tolist()
20+
op = getattr(torch_scatter, 'scatter_{}'.format(func))
21+
data = (src, index)
22+
assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True

test/test_forward.py

Lines changed: 43 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,49 @@
1-
from os import path as osp
21
from itertools import product
32

43
import pytest
5-
import json
64
import torch
75
import torch_scatter
86

9-
from .utils import tensors, Tensor
10-
11-
f = open(osp.join(osp.dirname(__file__), 'forward.json'), 'r')
12-
data = json.load(f)
13-
f.close()
14-
15-
16-
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
17-
def test_forward_cpu(tensor, i):
18-
name = data[i]['name']
19-
index = torch.LongTensor(data[i]['index'])
20-
input = Tensor(tensor, data[i]['input'])
21-
dim = data[i]['dim']
22-
fill_value = data[i]['fill_value']
23-
expected = torch.FloatTensor(data[i]['expected']).type_as(input)
24-
output = expected.new(expected.size()).fill_(fill_value)
25-
26-
func = getattr(torch_scatter, 'scatter_{}_'.format(name))
27-
result = func(output, index, input, dim)
28-
assert output.tolist() == expected.tolist()
29-
if 'expected_arg' in data[i]:
30-
expected_arg = torch.LongTensor(data[i]['expected_arg'])
31-
assert result[1].tolist() == expected_arg.tolist()
32-
33-
func = getattr(torch_scatter, 'scatter_{}'.format(name))
34-
result = func(index, input, dim, fill_value=fill_value)
35-
if 'expected_arg' not in data[i]:
36-
assert result.tolist() == expected.tolist()
37-
else:
38-
expected_arg = torch.LongTensor(data[i]['expected_arg'])
39-
assert result[0].tolist() == expected.tolist()
40-
assert result[1].tolist() == expected_arg.tolist()
41-
42-
43-
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
44-
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
45-
def test_forward_gpu(tensor, i): # pragma: no cover
46-
name = data[i]['name']
47-
index = torch.cuda.LongTensor(data[i]['index'])
48-
input = Tensor(tensor, data[i]['input']).cuda()
49-
dim = data[i]['dim']
50-
fill_value = data[i]['fill_value']
51-
expected = torch.FloatTensor(data[i]['expected']).type_as(input)
52-
output = expected.new(expected.size()).fill_(fill_value).cuda()
53-
54-
func = getattr(torch_scatter, 'scatter_{}_'.format(name))
55-
result = func(output, index, input, dim)
56-
assert output.cpu().tolist() == expected.tolist()
57-
if 'expected_arg' in data[i]:
58-
expected_arg = torch.LongTensor(data[i]['expected_arg'])
59-
assert result[1].cpu().tolist() == expected_arg.tolist()
60-
func = getattr(torch_scatter, 'scatter_{}'.format(name))
61-
result = func(index, input, dim, fill_value=fill_value)
62-
if 'expected_arg' not in data[i]:
63-
assert result.cpu().tolist() == expected.tolist()
64-
else:
65-
expected_arg = torch.LongTensor(data[i]['expected_arg'])
66-
assert result[0].cpu().tolist() == expected.tolist()
67-
assert result[1].cpu().tolist() == expected_arg.tolist()
7+
from .utils import dtypes, devices, tensor
8+
9+
tests = [{
10+
'name': 'add',
11+
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
12+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
13+
'fill_value': 0,
14+
'expected': [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
15+
}]
16+
17+
18+
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
19+
def test_forward(test, dtype, device):
20+
src = tensor(test['src'], dtype, device)
21+
index = tensor(test['index'], torch.long, device)
22+
23+
op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
24+
output = op(src, index, fill_value=test['fill_value'])
25+
26+
assert output.tolist() == test['expected']
27+
# name = data[i]['name']
28+
# index = torch.LongTensor(data[i]['index'])
29+
# input = Tensor(tensor, data[i]['input'])
30+
# dim = data[i]['dim']
31+
# fill_value = data[i]['fill_value']
32+
# expected = torch.FloatTensor(data[i]['expected']).type_as(input)
33+
# output = expected.new(expected.size()).fill_(fill_value)
34+
35+
# func = getattr(torch_scatter, 'scatter_{}_'.format(name))
36+
# result = func(output, index, input, dim)
37+
# assert output.tolist() == expected.tolist()
38+
# if 'expected_arg' in data[i]:
39+
# expected_arg = torch.LongTensor(data[i]['expected_arg'])
40+
# assert result[1].tolist() == expected_arg.tolist()
41+
42+
# func = getattr(torch_scatter, 'scatter_{}'.format(name))
43+
# result = func(index, input, dim, fill_value=fill_value)
44+
# if 'expected_arg' not in data[i]:
45+
# assert result.tolist() == expected.tolist()
46+
# else:
47+
# expected_arg = torch.LongTensor(data[i]['expected_arg'])
48+
# assert result[0].tolist() == expected.tolist()
49+
# assert result[1].tolist() == expected_arg.tolist()

test/utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import torch
2-
from torch._tensor_docs import tensor_classes
2+
from torch.testing import get_all_dtypes
33

4-
tensors = [t[:-4] for t in tensor_classes]
5-
tensors.remove('ShortTensor') # TODO: PyTorch `atomicAdd` bug with short type.
6-
tensors.remove('ByteTensor') # We cannot properly test unsigned values.
7-
tensors.remove('CharTensor') # Overflow on gradient computations :(
4+
dtypes = get_all_dtypes()
5+
dtypes.remove(torch.half)
6+
dtypes.remove(torch.short) # TODO: PyTorch `atomicAdd` bug with short type.
7+
dtypes.remove(torch.uint8) # We cannot properly test unsigned values.
8+
dtypes.remove(torch.int8) # Overflow on gradient computations :(
89

10+
devices = [torch.device('cpu')]
11+
if torch.cuda.is_available(): # pragma: no cover
12+
devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]
913

10-
def Tensor(str, x):
11-
tensor = getattr(torch, str)
12-
return tensor(x)
14+
15+
def tensor(x, dtype, device):
16+
return None if x is None else torch.tensor(x, dtype=dtype, device=device)

torch_scatter/__init__.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
1-
from .functions.add import scatter_add_, scatter_add
2-
from .functions.sub import scatter_sub_, scatter_sub
3-
from .functions.mul import scatter_mul_, scatter_mul
4-
from .functions.div import scatter_div_, scatter_div
5-
from .functions.mean import scatter_mean_, scatter_mean
6-
from .functions.max import scatter_max_, scatter_max
7-
from .functions.min import scatter_min_, scatter_min
1+
from .add import ScatterAdd, scatter_add
82

9-
__version__ = '0.3.0'
3+
__version__ = '1.0.0'
104

11-
__all__ = [
12-
'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
13-
'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div',
14-
'scatter_mean_', 'scatter_mean', 'scatter_max_', 'scatter_max',
15-
'scatter_min_', 'scatter_min', '__version__'
16-
]
5+
__all__ = ['ScatterAdd', 'scatter_add', '__version__']

torch_scatter/add.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from torch.autograd import Function
2+
3+
from .utils.gen import gen
4+
5+
6+
class ScatterAdd(Function):
7+
@staticmethod
8+
def forward(ctx, out, src, index, dim=-1):
9+
ctx.mark_dirty(out)
10+
ctx.save_for_backward(index)
11+
return out.scatter_add_(dim, index, src)
12+
13+
@staticmethod
14+
def backward(ctx, grad_out):
15+
index, = ctx.saved_variables
16+
17+
grad_src = None
18+
if ctx.needs_input_grad[1]:
19+
grad_src = grad_out[index]
20+
21+
return None, grad_src, None, None
22+
23+
24+
def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
25+
r"""
26+
|
27+
28+
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
29+
master/docs/source/_figures/add.svg?sanitize=true
30+
:align: center
31+
:width: 400px
32+
33+
|
34+
35+
Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
36+
specified in the :attr:`index` tensor along an given axis :attr:`dim`. For
37+
each value in :attr:`src`, its output index is specified by its index in
38+
:attr:`input` for dimensions outside of :attr:`dim` and by the
39+
corresponding value in :attr:`index` for dimension :attr:`dim`. If
40+
multiple indices reference the same location, their **contributions add**.
41+
42+
Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with
43+
size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
44+
:attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
45+
size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
46+
values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
47+
48+
For one-dimensional tensors, the operation computes
49+
50+
.. math::
51+
\mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j
52+
53+
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.
54+
55+
Args:
56+
src (Tensor): The source tensor.
57+
index (LongTensor): The indices of elements to scatter.
58+
dim (int, optional): The axis along which to index.
59+
(default: :obj:`-1`)
60+
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
61+
dim_size (int, optional): If :attr:`out` is not given, automatically
62+
create output with size :attr:`dim_size` at dimension :attr:`dim`.
63+
If :attr:`dim_size` is not given, a minimal sized output tensor is
64+
returned. (default: :obj:`None`)
65+
fill_value (int, optional): If :attr:`out` is not given, automatically
66+
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
67+
68+
:rtype: :class:`Tensor`
69+
70+
.. testsetup::
71+
72+
import torch
73+
74+
.. testcode::
75+
76+
from torch_scatter import scatter_add_
77+
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
78+
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
79+
out = src.new_zeros((2, 6))
80+
out = scatter_add(src, index, out=out)
81+
print(out)
82+
83+
.. testoutput::
84+
85+
0 0 4 3 3 0
86+
2 4 4 0 0 0
87+
[torch.FloatTensor of size 2x6]
88+
"""
89+
out, index = gen(src, index, dim, out, dim_size, fill_value)
90+
return ScatterAdd.apply(out, src, index, dim)

0 commit comments

Comments
 (0)