Commit 62c6122

clean up code base / added new functions to readme / added docs for softmax functions

1 parent d63eb9c
File tree

9 files changed (+129 / -187 lines)

README.md

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,12 @@ The package consists of the following operations:
 * [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
 * [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
 * [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
+* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
+
+In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
+
+* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
+* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)

 All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
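As context for the new README entries, a minimal usage sketch of the two composite functions (the values are illustrative; the call signatures follow the torch_scatter/composite/softmax.py diff below):

import torch
from torch_scatter.composite import scatter_softmax, scatter_log_softmax

src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])

# Each index group is normalized independently: entries 0 and 2
# (index 0) form one softmax, entries 1, 3 and 4 (index 1) another.
out = scatter_softmax(src, index)
log_out = scatter_log_softmax(src, index)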

docs/source/composite/softmax.rst

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+Scatter Softmax
+===============
+
+.. automodule:: torch_scatter.composite
+    :noindex:
+
+.. autofunction:: scatter_softmax
+
+.. autofunction:: scatter_log_softmax

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i
    :caption: Package reference

    functions/*
+   composite/*

 Indices and tables
 ==================

test/composite/test_softmax.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+from itertools import product
+
+import pytest
+import torch
+from torch_scatter.composite import scatter_log_softmax, scatter_softmax
+
+from test.utils import devices, tensor, grad_dtypes
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_softmax(dtype, device):
+    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_softmax(src, index)
+
+    out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
+    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
+    out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
+                         dim=-1)
+
+    expected = torch.stack([
+        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
+    ], dim=0)
+
+    assert torch.allclose(out, expected)
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_log_softmax(dtype, device):
+    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_log_softmax(src, index)
+
+    out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
+    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
+    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
+                             dim=-1)
+
+    expected = torch.stack([
+        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
+    ], dim=0)
+
+    assert torch.allclose(out, expected)
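Assuming the repository's existing pytest layout (the `from test.utils import ...` import resolves when running from the repository root), the new tests can presumably be invoked with:

python -m pytest test/composite/test_softmax.py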

test/test_logsumexp.py

Lines changed: 9 additions & 15 deletions
@@ -4,27 +4,21 @@
 import pytest
 from torch_scatter import scatter_logsumexp

-from .utils import devices, tensor
+from .utils import devices, tensor, grad_dtypes

-SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}

-
-@pytest.mark.parametrize('dtype,device',
-                         product(SUPPORTED_FLOAT_DTYPES, devices))
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_logsumexp(dtype, device):
     src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
     index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

     out = scatter_logsumexp(src, index)

-    idx0 = torch.logsumexp(
-        torch.tensor([0.5, 0.5], dtype=dtype),
-        dim=-1).tolist()
-    idx1 = torch.logsumexp(
-        torch.tensor([0, -2.1, 3.2], dtype=dtype),
-        dim=-1).tolist()
-    idx2 = 7  # Single element
-    idx3 = torch.finfo(dtype).min  # Empty index, returns yield value
-    idx4 = -1  # logsumexp with -inf is the identity
+    out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
+    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
+    out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
+    out4 = torch.tensor(-1, dtype=dtype)

-    assert out.tolist() == [idx0, idx1, idx2, idx3, idx4]
+    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
+    assert torch.allclose(out, expected)
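Two of the expected values encode edge cases carried over from the deleted inline comments: bucket 3 never occurs in `index`, so the output keeps scatter_logsumexp's fill value (the dtype minimum), and exp(-inf) == 0, so a -inf entry simply drops out of its group's sum. A quick sanity check in plain PyTorch:

import torch

torch.finfo(torch.float32).min  # fill value returned for an empty bucket
torch.logsumexp(torch.tensor([-1.0, float('-inf')]), dim=-1)  # tensor(-1.)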

test/test_softmax.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

torch_scatter/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from .max import scatter_max
 from .min import scatter_min
 from .logsumexp import scatter_logsumexp
+import torch_scatter.composite

 __version__ = '1.3.2'

@@ -20,5 +21,6 @@
     'scatter_max',
     'scatter_min',
     'scatter_logsumexp',
+    'torch_scatter',
     '__version__',
 ]
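With this change, importing the package also pulls in the composite namespace; a minimal sketch of the resulting import paths:

import torch_scatter
from torch_scatter.composite import scatter_softmax, scatter_log_softmax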

torch_scatter/composite/softmax.py

Lines changed: 33 additions & 74 deletions
@@ -3,125 +3,84 @@
 from torch_scatter import scatter_add, scatter_max


-def scatter_log_softmax(src, index, dim=-1, dim_size=None):
+def scatter_softmax(src, index, dim=-1, eps=1e-12):
     r"""
-    Numerical safe log-softmax of all values from
-    the :attr:`src` tensor into :attr:`out` at the
-    indices specified in the :attr:`index` tensor along a given axis
-    :attr:`dim`.If multiple indices reference the same location, their
-    **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
+    Softmax operation over all values in :attr:`src` tensor that share indices
+    specified in the :attr:`index` tensor along a given axis :attr:`dim`.

     For one-dimensional tensors, the operation computes

     .. math::
-        \mathrm{out}_i = softmax(\mathrm{src}_i) =
-        \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)
+        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
+        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}

-    where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
+    where :math:`\sum_j` is over :math:`j` such that
     :math:`\mathrm{index}_j = i`.

-    Compute a numerically safe log softmax operation
-    from the :attr:`src` tensor into :attr:`out` at the indices
-    specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
-    each value in :attr:`src`, its output index is specified by its index in
-    :attr:`input` for dimensions outside of :attr:`dim` and by the
-    corresponding value in :attr:`index` for dimension :attr:`dim`.
-
     Args:
         src (Tensor): The source tensor.
         index (LongTensor): The indices of elements to scatter.
         dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
-        dim_size (int, optional): If :attr:`out` is not given, automatically
-            create output with size :attr:`dim_size` at dimension :attr:`dim`.
-            If :attr:`dim_size` is not given, a minimal sized output tensor is
-            returned. (default: :obj:`None`)
-        fill_value (int, optional): If :attr:`out` is not given, automatically
-            fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
-            the output tensor is filled with the smallest possible value of
-            :obj:`src.dtype`. (default: :obj:`None`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)

     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('log_softmax can be computed only over '
-                         'tensors with floating point data types.')
+        raise ValueError('`scatter_softmax` can only be computed over tensors '
+                         'with floating point data types.')

-    max_value_per_index, _ = scatter_max(src, index,
-                                         dim=dim,
-                                         dim_size=dim_size)
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
     max_per_src_element = max_value_per_index.gather(dim, index)

     recentered_scores = src - max_per_src_element
+    recentered_scores_exp = recentered_scores.exp()

-    sum_per_index = scatter_add(
-        src=recentered_scores.exp(),
-        index=index,
-        dim=dim,
-        dim_size=dim_size
-    )
-    log_normalizing_constants = sum_per_index.log().gather(dim, index)
+    sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
+    normalizing_constants = (sum_per_index + eps).gather(dim, index)

-    return recentered_scores - log_normalizing_constants
+    return recentered_scores_exp / normalizing_constants


-def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
+def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
     r"""
-    Numerical safe log-softmax of all values from
-    the :attr:`src` tensor into :attr:`out` at the
+    Log-softmax operation over all values in :attr:`src` tensor that share
     indices specified in the :attr:`index` tensor along a given axis
-    :attr:`dim`. If multiple indices reference the same location, their
-    **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
+    :attr:`dim`.

     For one-dimensional tensors, the operation computes

     .. math::
-        \mathrm{out}_i = softmax(\mathrm{src}_i) =
-        \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)}
+        \mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i =
+        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
+        \right)

-    where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
+    where :math:`\sum_j` is over :math:`j` such that
     :math:`\mathrm{index}_j = i`.

-    Compute a numerically safe softmax operation
-    from the :attr:`src` tensor into :attr:`out` at the indices
-    specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
-    each value in :attr:`src`, its output index is specified by its index in
-    :attr:`input` for dimensions outside of :attr:`dim` and by the
-    corresponding value in :attr:`index` for dimension :attr:`dim`.
-
     Args:
         src (Tensor): The source tensor.
         index (LongTensor): The indices of elements to scatter.
         dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
-        dim_size (int, optional): If :attr:`out` is not given, automatically
-            create output with size :attr:`dim_size` at dimension :attr:`dim`.
-            If :attr:`dim_size` is not given, a minimal sized output tensor is
-            returned. (default: :obj:`None`)
-        fill_value (int, optional): If :attr:`out` is not given, automatically
-            fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
-            the output tensor is filled with the smallest possible value of
-            :obj:`src.dtype`. (default: :obj:`None`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)

     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('softmax can be computed only over '
+        raise ValueError('`scatter_log_softmax` can only be computed over '
                          'tensors with floating point data types.')

-    max_value_per_index, _ = scatter_max(src, index,
-                                         dim=dim,
-                                         dim_size=dim_size)
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
     max_per_src_element = max_value_per_index.gather(dim, index)

     recentered_scores = src - max_per_src_element
-    exped_recentered_scores = recentered_scores.exp()
-
-    sum_per_index = scatter_add(
-        src=exped_recentered_scores,
-        index=index,
-        dim=dim,
-        dim_size=dim_size
-    )
-    normalizing_constant = (sum_per_index + epsilon).gather(dim, index)
-    return exped_recentered_scores / normalizing_constant
+
+    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
+                                dim=dim)
+
+    normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)
+
+    return recentered_scores - normalizing_constants
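Both functions rely on the same max-shift trick; here it is as a standalone single-group sketch in plain PyTorch (not the scattered version — the library code applies it per index bucket via scatter_max):

import torch

src = torch.tensor([0.2, -2.1, 3.2])

# Softmax is shift-invariant, so subtracting the max changes nothing
# mathematically but keeps exp() from overflowing for large inputs.
shifted = src - src.max()
softmax = shifted.exp() / shifted.exp().sum()
assert torch.allclose(softmax, torch.softmax(src, dim=-1))

# log-softmax uses the same recentering but stays in log space.
log_softmax = shifted - shifted.exp().sum().log()
assert torch.allclose(log_softmax, torch.log_softmax(src, dim=-1))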
