Commit d790f1c

Merge pull request #77 from mallamanis/master
Implement scatter_logsumexp, scatter_softmax, scatter_log_softmax
2 parents: 78a5549 + 62c6122

File tree

10 files changed: +244 -0 lines changed

README.md

Lines changed: 6 additions & 0 deletions

@@ -34,6 +34,12 @@ The package consists of the following operations:
 * [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
 * [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
 * [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
+* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
+
+In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
+
+* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
+* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
 
 All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
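For orientation, here is a minimal usage sketch of the three operations this pull request adds. It is illustrative only and not part of the diff; it assumes a build of torch_scatter that includes this commit, and the tensors are made up for the example.

import torch
from torch_scatter import scatter_logsumexp
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])

# Softmax is computed independently inside each index group:
# group 0 -> softmax([0.2, 0.2]), group 1 -> softmax([0.0, -2.1, 3.2]).
print(scatter_softmax(src, index))
print(scatter_log_softmax(src, index))

# One log-sum-exp value per group, i.e. a tensor of shape [2] here.
print(scatter_logsumexp(src, index))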

docs/source/composite/softmax.rst

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+Scatter Softmax
+===============
+
+.. automodule:: torch_scatter.composite
+    :noindex:
+
+.. autofunction:: scatter_softmax
+
+.. autofunction:: scatter_log_softmax
docs/source/functions/logsumexp.rst

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+Scatter LogSumExp
+=================
+
+.. automodule:: torch_scatter
+    :noindex:
+
+.. autofunction:: scatter_logsumexp

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i
    :caption: Package reference

    functions/*
+   composite/*

 Indices and tables
 ==================

test/composite/test_softmax.py

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+from itertools import product
+
+import pytest
+import torch
+from torch_scatter.composite import scatter_log_softmax, scatter_softmax
+
+from test.utils import devices, tensor, grad_dtypes
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_softmax(dtype, device):
+    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_softmax(src, index)
+
+    out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
+    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
+    out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
+                         dim=-1)
+
+    expected = torch.stack([
+        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
+    ], dim=0)
+
+    assert torch.allclose(out, expected)
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_log_softmax(dtype, device):
+    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_log_softmax(src, index)
+
+    out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
+    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
+    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
+                             dim=-1)
+
+    expected = torch.stack([
+        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
+    ], dim=0)
+
+    assert torch.allclose(out, expected)

test/test_logsumexp.py

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+from itertools import product
+
+import torch
+import pytest
+from torch_scatter import scatter_logsumexp
+
+from .utils import devices, tensor, grad_dtypes
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_logsumexp(dtype, device):
+    src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_logsumexp(src, index)
+
+    out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
+    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
+    out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
+    out4 = torch.tensor(-1, dtype=dtype)
+
+    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
+    assert torch.allclose(out, expected)
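As a cross-check on the grouping semantics exercised by the test above, the following sketch (illustrative only, not part of the commit) recomputes each occupied group directly with torch.logsumexp. Group 3 receives no contributions, which is why the test expects roughly the smallest representable value of the dtype in that slot.

import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0.0, 0.5, -2.1, 3.2, 7.0, -1.0, float('-inf')])
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

out = scatter_logsumexp(src, index)

# Every occupied group should agree with a direct per-group logsumexp.
for group in [0, 1, 2, 4]:
    reference = torch.logsumexp(src[index == group], dim=-1)
    assert torch.allclose(out[group], reference)

# Group 3 is empty, so it stays at (approximately) the fill value,
# i.e. the minimum of the floating point range.
print(out[3])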

torch_scatter/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -6,6 +6,8 @@
 from .std import scatter_std
 from .max import scatter_max
 from .min import scatter_min
+from .logsumexp import scatter_logsumexp
+import torch_scatter.composite

 __version__ = '1.3.2'

@@ -18,5 +20,7 @@
     'scatter_std',
     'scatter_max',
     'scatter_min',
+    'scatter_logsumexp',
+    'torch_scatter',
     '__version__',
 ]
torch_scatter/composite/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+from .softmax import scatter_log_softmax, scatter_softmax
+
+__all__ = [
+    'scatter_softmax',
+    'scatter_log_softmax',
+]

torch_scatter/composite/softmax.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+import torch
+
+from torch_scatter import scatter_add, scatter_max
+
+
+def scatter_softmax(src, index, dim=-1, eps=1e-12):
+    r"""
+    Softmax operation over all values in :attr:`src` tensor that share indices
+    specified in the :attr:`index` tensor along a given axis :attr:`dim`.
+
+    For one-dimensional tensors, the operation computes
+
+    .. math::
+        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
+        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
+
+    where :math:`\sum_j` is over :math:`j` such that
+    :math:`\mathrm{index}_j = i`.
+
+    Args:
+        src (Tensor): The source tensor.
+        index (LongTensor): The indices of elements to scatter.
+        dim (int, optional): The axis along which to index.
+            (default: :obj:`-1`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)
+
+    :rtype: :class:`Tensor`
+    """
+    if not torch.is_floating_point(src):
+        raise ValueError('`scatter_softmax` can only be computed over tensors '
+                         'with floating point data types.')
+
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
+    max_per_src_element = max_value_per_index.gather(dim, index)
+
+    recentered_scores = src - max_per_src_element
+    recentered_scores_exp = recentered_scores.exp()
+
+    sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
+    normalizing_constants = (sum_per_index + eps).gather(dim, index)
+
+    return recentered_scores_exp / normalizing_constants
+
+
+def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
+    r"""
+    Log-softmax operation over all values in :attr:`src` tensor that share
+    indices specified in the :attr:`index` tensor along a given axis
+    :attr:`dim`.
+
+    For one-dimensional tensors, the operation computes
+
+    .. math::
+        \mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i =
+        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
+        \right)
+
+    where :math:`\sum_j` is over :math:`j` such that
+    :math:`\mathrm{index}_j = i`.
+
+    Args:
+        src (Tensor): The source tensor.
+        index (LongTensor): The indices of elements to scatter.
+        dim (int, optional): The axis along which to index.
+            (default: :obj:`-1`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)
+
+    :rtype: :class:`Tensor`
+    """
+    if not torch.is_floating_point(src):
+        raise ValueError('`scatter_log_softmax` can only be computed over '
+                         'tensors with floating point data types.')
+
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
+    max_per_src_element = max_value_per_index.gather(dim, index)
+
+    recentered_scores = src - max_per_src_element
+
+    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
+                                dim=dim)
+
+    normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)
+
+    return recentered_scores - normalizing_constants
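Both functions above rely on the standard max-shift stabilization: subtracting the per-group maximum before exponentiating keeps every exponent non-positive, so nothing overflows, and the shift cancels once the result is normalized. Below is a loop-based reference that the vectorized code should agree with; it is a hedged sketch for illustration only, and `naive_scatter_softmax` is a hypothetical helper defined just for this comparison, not part of the commit.

import torch
from torch_scatter.composite import scatter_softmax


def naive_scatter_softmax(src, index):
    # Reference: apply torch.softmax separately within each index group.
    out = torch.empty_like(src)
    for group in index.unique():
        mask = index == group
        out[mask] = torch.softmax(src[mask], dim=-1)
    return out


src = torch.randn(8)
index = torch.tensor([0, 1, 0, 1, 1, 2, 2, 0])

assert torch.allclose(scatter_softmax(src, index),
                      naive_scatter_softmax(src, index))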

torch_scatter/logsumexp.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+import torch
+
+from . import scatter_add, scatter_max
+
+
+def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
+                      fill_value=None, eps=1e-12):
+    r"""Fills :attr:`out` with the log of summed exponentials of all values
+    from the :attr:`src` tensor at the indices specified in the :attr:`index`
+    tensor along a given axis :attr:`dim`.
+    If multiple indices reference the same location, their
+    **exponential contributions add**
+    (`cf.` :meth:`~torch_scatter.scatter_add`).
+
+    For one-dimensional tensors, the operation computes
+
+    .. math::
+        \mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
+        \exp(\mathrm{src}_j) \right)
+
+    where :math:`\sum_j` is over :math:`j` such that
+    :math:`\mathrm{index}_j = i`.
+
+    Args:
+        src (Tensor): The source tensor.
+        index (LongTensor): The indices of elements to scatter.
+        dim (int, optional): The axis along which to index.
+            (default: :obj:`-1`)
+        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
+        dim_size (int, optional): If :attr:`out` is not given, automatically
+            create output with size :attr:`dim_size` at dimension :attr:`dim`.
+            If :attr:`dim_size` is not given, a minimal sized output tensor is
+            returned. (default: :obj:`None`)
+        fill_value (int, optional): If :attr:`out` is not given, automatically
+            fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)
+
+    :rtype: :class:`Tensor`
+    """
+    if not torch.is_floating_point(src):
+        raise ValueError('`scatter_logsumexp` can only be computed over '
+                         'tensors with floating point data types.')
+
+    max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
+                                         fill_value)
+    max_per_src_element = max_value_per_index.gather(dim, index)
+    recentered_scores = src - max_per_src_element
+    out = (out - max_per_src_element).exp() if out is not None else None
+
+    sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
+                                dim_size, fill_value=0)
+
+    return torch.log(sum_per_index + eps) + max_value_per_index
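The same recentering appears here: the return value uses the identity log(sum_j exp(src_j)) = m + log(sum_j exp(src_j - m)), with m the per-group maximum, so the exponentials never overflow. A short hedged demonstration (illustrative only, assuming a build that includes this commit):

import torch
from torch_scatter import scatter_logsumexp

# exp() of these values overflows in float32 ...
src = torch.tensor([1000.0, 999.0, 1000.5])
index = torch.tensor([0, 0, 0])

print(torch.log(torch.exp(src).sum()))   # naive computation: inf

# ... but the recentered computation stays finite and matches torch.logsumexp.
print(scatter_logsumexp(src, index))     # ~1001.10, finite
print(torch.logsumexp(src, dim=-1))      # same value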
