Skip to content

Commit 0b71aad

Browse files
committed
scatter mean
1 parent f25c0e7 commit 0b71aad

File tree

8 files changed

+125
-6
lines changed

8 files changed

+125
-6
lines changed

test/test_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from .utils import devices
99

10-
funcs = ['add', 'sub']
10+
funcs = ['add', 'sub', 'mean']
1111
indices = [2, 0, 1, 1, 0]
1212

1313

test/test_forward.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,30 @@
3434
'dim': 0,
3535
'fill_value': 9,
3636
'expected': [[3, 4], [3, 5]]
37+
}, {
38+
'name': 'mean',
39+
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
40+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
41+
'dim': 1,
42+
'fill_value': 0,
43+
'expected': [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]]
44+
}, {
45+
'name': 'mean',
46+
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
47+
'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
48+
'dim': 0,
49+
'fill_value': 0,
50+
'expected': [[3, 2.5], [3, 4]]
3751
}]
3852

3953

4054
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
4155
def test_forward(test, dtype, device):
4256
src = tensor(test['src'], dtype, device)
4357
index = tensor(test['index'], torch.long, device)
58+
expected = tensor(test['expected'], dtype, device)
4459

4560
op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
4661
output = op(src, index, test['dim'], fill_value=test['fill_value'])
4762

48-
assert output.tolist() == test['expected']
63+
assert output.tolist() == expected.tolist()

torch_scatter/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .add import scatter_add
22
from .sub import scatter_sub
3+
from .mean import scatter_mean
34

45
__version__ = '1.0.0'
56

6-
__all__ = ['scatter_add', 'scatter_sub', '__version__']
7+
__all__ = ['scatter_add', 'scatter_sub', 'scatter_mean', '__version__']

torch_scatter/add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
class ScatterAdd(Function):
77
@staticmethod
8-
def forward(ctx, out, src, index, dim=-1):
8+
def forward(ctx, out, src, index, dim):
99
ctx.mark_dirty(out)
1010
ctx.save_for_backward(index)
1111
return out.scatter_add_(dim, index, src)
@@ -86,5 +86,5 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
8686
2 4 4 0 0 0
8787
[torch.FloatTensor of size 2x6]
8888
"""
89-
out, index = gen(src, index, dim, out, dim_size, fill_value)
89+
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
9090
return ScatterAdd.apply(out, src, index, dim)

torch_scatter/mean.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from torch.autograd import Function
2+
3+
from .utils.ffi import get_func
4+
from .utils.gen import gen
5+
6+
7+
class ScatterMean(Function):
8+
@staticmethod
9+
def forward(ctx, out, src, index, dim):
10+
ctx.mark_dirty(out)
11+
12+
count = src.new_zeros(out.size())
13+
func = get_func('scatter_mean', src)
14+
func(dim, out, index, src, count)
15+
count[count == 0] = 1
16+
out /= count
17+
18+
ctx.save_for_backward(index, count)
19+
20+
return out
21+
22+
@staticmethod
23+
def backward(ctx, grad_out):
24+
index, count = ctx.saved_variables
25+
26+
grad_src = None
27+
if ctx.needs_input_grad[1]:
28+
grad_src = grad_out[index] / count[index]
29+
30+
return None, grad_src, None, None
31+
32+
33+
def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
34+
r"""
35+
|
36+
37+
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
38+
master/docs/source/_figures/mean.svg?sanitize=true
39+
:align: center
40+
:width: 400px
41+
42+
|
43+
44+
Averages all values from the :attr:`src` tensor into :attr:`out` at the
45+
indices specified in the :attr:`index` tensor along an given axis
46+
:attr:`dim`.If multiple indices reference the same location, their
47+
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
48+
49+
For one-dimensional tensors, the operation computes
50+
51+
.. math::
52+
\mathrm{out}_i = \mathrm{out}_i + \frac{1}{N_i} \cdot
53+
\sum_j \mathrm{src}_j
54+
55+
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i` and
56+
:math:`N_i` indicates the number of indices referencing :math:`i`.
57+
58+
Args:
59+
src (Tensor): The source tensor.
60+
index (LongTensor): The indices of elements to scatter.
61+
dim (int, optional): The axis along which to index.
62+
(default: :obj:`-1`)
63+
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
64+
dim_size (int, optional): If :attr:`out` is not given, automatically
65+
create output with size :attr:`dim_size` at dimension :attr:`dim`.
66+
If :attr:`dim_size` is not given, a minimal sized output tensor is
67+
returned. (default: :obj:`None`)
68+
fill_value (int, optional): If :attr:`out` is not given, automatically
69+
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
70+
71+
:rtype: :class:`Tensor`
72+
73+
.. testsetup::
74+
75+
import torch
76+
77+
.. testcode::
78+
79+
from torch_scatter import scatter_mean
80+
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
81+
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
82+
out = src.new_zeros((2, 6))
83+
out = scatter_mean(src, index, out=out)
84+
print(out)
85+
86+
.. testoutput::
87+
88+
0.0000 0.0000 4.0000 3.0000 1.5000 0.0000
89+
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
90+
[torch.FloatTensor of size 2x6]
91+
"""
92+
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
93+
return ScatterMean.apply(out, src, index, dim)

torch_scatter/utils/__init__.py

Whitespace-only changes.

torch_scatter/utils/ffi.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .._ext import ffi
2+
3+
4+
def get_func(name, tensor):
5+
name += '_'
6+
name += 'cuda_' if tensor.is_cuda else ''
7+
name += tensor.type().split('.')[-1][:-6]
8+
return getattr(ffi, name)

torch_scatter/utils/gen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33

44
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
5+
dim = range(src.dim())[dim] # Get real dim value.
6+
57
# Automatically expand index tensor to the right dimensions.
68
if index.dim() == 1:
79
index_size = [*repeat(1, src.dim())]
@@ -15,4 +17,4 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
1517
out_size[dim] = dim_size
1618
out = src.new_full(out_size, fill_value)
1719

18-
return out, index
20+
return src, out, index, dim

0 commit comments

Comments
 (0)