Skip to content

Commit d367c0b

Browse files
committed
scatter div
1 parent 4b654c2 commit d367c0b

File tree

8 files changed

+121
-14
lines changed

8 files changed

+121
-14
lines changed

test/test_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from .utils import devices
99

10-
funcs = ['add', 'sub', 'mul', 'mean']
10+
funcs = ['add', 'sub', 'mul', 'div', 'mean']
1111
indices = [2, 0, 1, 1, 0]
1212

1313

test/test_forward.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,20 @@
4848
'dim': 0,
4949
'fill_value': 1,
5050
'expected': [[5, 6], [8, 15]]
51+
}, {
52+
'name': 'div',
53+
'src': [[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]],
54+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
55+
'dim': -1,
56+
'fill_value': 1,
57+
'expected': [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]]
58+
}, {
59+
'name': 'div',
60+
'src': [[4, 2], [2, 1], [4, 2], [1, 2]],
61+
'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
62+
'dim': 0,
63+
'fill_value': 1,
64+
'expected': [[0.25, 0.25], [0.125, 0.5]]
5165
}, {
5266
'name': 'mean',
5367
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],

torch_scatter/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from .add import scatter_add
22
from .sub import scatter_sub
33
from .mul import scatter_mul
4+
from .div import scatter_div
45
from .mean import scatter_mean
56

67
__version__ = '1.0.0'
78

89
__all__ = [
9-
'scatter_add', 'scatter_sub', 'scatter_mul', 'scatter_mean', '__version__'
10+
'scatter_add', 'scatter_sub', 'scatter_mul', 'scatter_div', 'scatter_mean',
11+
'__version__'
1012
]

torch_scatter/add.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
5050
.. math::
5151
\mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j
5252
53-
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.
53+
where :math:`\sum` is over :math:`j` such that
54+
:math:`\mathrm{index}_j = i`.
5455
5556
Args:
5657
src (Tensor): The source tensor.

torch_scatter/div.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from torch.autograd import Function
2+
3+
from .utils.ffi import get_func
4+
from .utils.gen import gen
5+
6+
7+
class ScatterDiv(Function):
8+
@staticmethod
9+
def forward(ctx, out, src, index, dim):
10+
func = get_func('scatter_div', src)
11+
func(dim, out, index, src)
12+
13+
ctx.mark_dirty(out)
14+
ctx.save_for_backward(out, src, index)
15+
16+
return out
17+
18+
@staticmethod
19+
def backward(ctx, grad_out):
20+
out, src, index = ctx.saved_variables
21+
22+
grad_src = None
23+
if ctx.needs_input_grad[1]:
24+
grad_src = -(out * grad_out)[index] / src
25+
26+
return None, grad_src, None, None
27+
28+
29+
def scatter_div(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
30+
r"""
31+
|
32+
33+
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
34+
master/docs/source/_figures/div.svg?sanitize=true
35+
:align: center
36+
:width: 400px
37+
38+
|
39+
40+
Divides all values from the :attr:`src` tensor into :attr:`out` at the
41+
indices specified in the :attr:`index` tensor along an given axis
42+
:attr:`dim`.If multiple indices reference the same location, their
43+
**contributions divide** (`cf.` :meth:`~torch_scatter.scatter_add`).
44+
45+
For one-dimensional tensors, the operation computes
46+
47+
.. math::
48+
\mathrm{out}_i = \mathrm{out}_i \cdot \prod_j
49+
\frac{1}{\mathrm{src}_j}
50+
51+
where :math:`\prod` is over :math:`j` such that
52+
:math:`\mathrm{index}_j = i`.
53+
54+
Args:
55+
src (Tensor): The source tensor.
56+
index (LongTensor): The indices of elements to scatter.
57+
dim (int, optional): The axis along which to index.
58+
(default: :obj:`-1`)
59+
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
60+
dim_size (int, optional): If :attr:`out` is not given, automatically
61+
create output with size :attr:`dim_size` at dimension :attr:`dim`.
62+
If :attr:`dim_size` is not given, a minimal sized output tensor is
63+
returned. (default: :obj:`None`)
64+
fill_value (int, optional): If :attr:`out` is not given, automatically
65+
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
66+
67+
:rtype: :class:`Tensor`
68+
69+
.. testsetup::
70+
71+
import torch
72+
73+
.. testcode::
74+
75+
from torch_scatter import scatter_div
76+
src = torch.tensor([[2, 0, 3, 4, 3], [2, 3, 4, 2, 4]])
77+
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
78+
out = src.new_ones((2, 6))
79+
out = scatter_div(src, index, out=out)
80+
print(out)
81+
82+
.. testoutput::
83+
84+
1.0000 1.0000 0.2500 0.3333 0.2500 1.0000
85+
0.5000 0.2500 0.1667 1.0000 1.0000 1.0000
86+
[torch.FloatTensor of size 2x6]
87+
"""
88+
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
89+
return ScatterDiv.apply(out, src, index, dim)

torch_scatter/mean.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
5151
\mathrm{out}_i = \mathrm{out}_i + \frac{1}{N_i} \cdot
5252
\sum_j \mathrm{src}_j
5353
54-
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i` and
55-
:math:`N_i` indicates the number of indices referencing :math:`i`.
54+
where :math:`\sum` is over :math:`j` such that :math:`\mathrm{index}_j = i`
55+
add :math:`N_i` indicates the number of indices referencing :math:`i`.
5656
5757
Args:
5858
src (Tensor): The source tensor.

torch_scatter/mul.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ def forward(ctx, out, src, index, dim):
1010
func = get_func('scatter_mul', src)
1111
func(dim, out, index, src)
1212

13-
ctx.dim = dim
1413
ctx.mark_dirty(out)
1514
ctx.save_for_backward(out, src, index)
1615

@@ -48,7 +47,8 @@ def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
4847
.. math::
4948
\mathrm{out}_i = \mathrm{out}_i \cdot \prod_j \mathrm{src}_j
5049
51-
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.
50+
where :math:`\prod` is over :math:`j` such that
51+
:math:`\mathrm{index}_j = i`.
5252
5353
Args:
5454
src (Tensor): The source tensor.
@@ -71,17 +71,17 @@ def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
7171
7272
.. testcode::
7373
74-
from torch_scatter import scatter_mean
75-
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
74+
from torch_scatter import scatter_mul
75+
src = torch.tensor([[2, 0, 3, 4, 3], [2, 3, 4, 2, 4]])
7676
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
77-
out = src.new_zeros((2, 6))
78-
out = scatter_mean(src, index, out=out)
77+
out = src.new_ones((2, 6))
78+
out = scatter_mul(src, index, out=out)
7979
print(out)
8080
8181
.. testoutput::
8282
83-
0.0000 0.0000 4.0000 3.0000 1.5000 0.0000
84-
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
83+
1 1 4 3 6 0
84+
6 4 8 1 1 1
8585
[torch.FloatTensor of size 2x6]
8686
"""
8787
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)

torch_scatter/sub.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
2222
.. math::
2323
\mathrm{out}_i = \mathrm{out}_i - \sum_j \mathrm{src}_j
2424
25-
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.
25+
where :math:`\sum` is over :math:`j` such that
26+
:math:`\mathrm{index}_j = i`.
2627
2728
Args:
2829
src (Tensor): The source tensor.

0 commit comments

Comments
 (0)