
Commit 7b14c67

Bug fixes, testing and other minor edits.

* `log_softmax` is now stand-alone to save one operation (and fix a bug).
* `softmax` is implemented in a similar stand-alone way.
* Address some PR comments.
1 parent 0ef9260 commit 7b14c67
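For orientation, here is a minimal usage sketch of the three scatter ops this commit touches (assumes torch_scatter with this commit installed; the input values are illustrative and not taken from the diff):

    import torch
    from torch_scatter import scatter_logsumexp
    from torch_scatter.composite import scatter_log_softmax, scatter_softmax

    # Five source values grouped into segments 0, 1 and 2 by `index`.
    src = torch.tensor([0.5, -2.1, 3.2, 7.0, -1.0])
    index = torch.tensor([0, 1, 1, 2, 2])

    print(scatter_logsumexp(src, index))     # one log-sum-exp per segment
    print(scatter_softmax(src, index))       # probabilities normalised within each segment
    print(scatter_log_softmax(src, index))   # log of the above, computed stably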

4 files changed: +136 −32 lines
test/test_logsumexp.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from itertools import product
+
+import torch
+import pytest
+from torch_scatter import scatter_max, scatter_logsumexp
+
+from .utils import devices, tensor
+
+SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
+
+@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
+def test_logsumexp(dtype, device):
+    src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_logsumexp(src, index)
+
+    idx0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1).tolist()
+    idx1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
+    idx2 = 7  # Single element
+    idx3 = torch.finfo(dtype).min  # Empty index, returns the fill value
+    idx4 = -1  # logsumexp with -inf is the identity
+
+    assert out.tolist() == [idx0, idx1, idx2, idx3, idx4]
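For reference, both this test's expectations and the implementations below lean on the standard max-shift identity for a numerically safe log-sum-exp (a textbook fact, not something introduced by this commit):

    % Shift by the per-segment maximum m so every exponent is <= 0 and exp() cannot overflow.
    \operatorname{logsumexp}_j(\mathrm{src}_j)
        = m + \log \sum_j \exp(\mathrm{src}_j - m),
    \qquad m = \max_j \mathrm{src}_j .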

test/test_softmax.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+from itertools import product
+
+import numpy as np
+import pytest
+import torch
+from torch_scatter.composite import scatter_log_softmax, scatter_softmax
+
+from .utils import devices, tensor
+
+SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
+
+@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
+def test_log_softmax(dtype, device):
+    src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_log_softmax(src, index)
+
+    # Expected results per index
+    idx0 = [np.log(0.5), np.log(0.5)]
+    idx1 = torch.log_softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
+    idx2 = 0.0  # Single element, has logprob=0
+    # index=3 is empty. Should not matter.
+    idx4 = [0.0, float('-inf')]  # log_softmax with -inf preserves the -inf
+
+    np.testing.assert_allclose(
+        out.tolist(),
+        [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
+        rtol=1e-05, atol=1e-10
+    )
+
+
+@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
+def test_softmax(dtype, device):
+    src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_softmax(src, index)
+
+    # Expected results per index
+    idx0 = [0.5, 0.5]
+    idx1 = torch.softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
+    idx2 = 1  # Single element, has prob=1
+    # index=3 is empty. Should not matter.
+    idx4 = [1.0, 0.0]  # softmax with -inf yields zero probability
+
+    np.testing.assert_allclose(
+        out.tolist(),
+        [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
+        rtol=1e-05, atol=1e-10
+    )
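A quick cross-check that can be run alongside these tests (illustrative only, not part of the commit): on finite inputs, exponentiating `scatter_log_softmax` should recover `scatter_softmax` up to the small `epsilon` in the latter's denominator, which is the relationship the old `.exp()`-based implementation relied on.

    import torch
    from torch_scatter.composite import scatter_log_softmax, scatter_softmax

    src = torch.tensor([0.25, 0.0, 0.25, -2.1, 3.2, 7.0])  # finite values only
    index = torch.tensor([0, 1, 0, 1, 1, 2])

    # exp(log_softmax) and softmax should agree segment by segment.
    assert torch.allclose(scatter_log_softmax(src, index).exp(),
                          scatter_softmax(src, index), atol=1e-6)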

torch_scatter/composite/softmax.py

Lines changed: 40 additions & 7 deletions
@@ -1,6 +1,6 @@
 import torch
 
-from torch_scatter.logsumexp import _scatter_logsumexp
+from torch_scatter import scatter_add, scatter_max
 
 def scatter_log_softmax(src, index, dim=-1, dim_size=None):
     r"""
@@ -12,7 +12,8 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
     For one-dimensional tensors, the operation computes
 
     .. math::
-        \mathrm{out}_i = softmax(\mathrm{src}_i) = \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)
+        \mathrm{out}_i = \log(\mathrm{softmax}(\mathrm{src}_i)) =
+        \mathrm{src}_i - \mathrm{logsumexp}_j(\mathrm{src}_j)
 
     where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
     :math:`\mathrm{index}_j = i`.
@@ -40,11 +41,26 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
 
     :rtype: :class:`Tensor`
     """
-    per_index_logsumexp, recentered_src = _scatter_logsumexp(src, index, dim=dim, dim_size=dim_size)
-    return recentered_src - per_index_logsumexp.gather(dim, index)
+    if not torch.is_floating_point(src):
+        raise ValueError('log_softmax can be computed only over tensors with floating point data types.')
 
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
+    max_per_src_element = max_value_per_index.gather(dim, index)
 
-def scatter_softmax(src, index, dim=-1, dim_size=None):
+    recentered_scores = src - max_per_src_element
+
+    sum_per_index = scatter_add(
+        src=recentered_scores.exp(),
+        index=index,
+        dim=dim,
+        dim_size=dim_size
+    )
+    log_normalizing_constants = sum_per_index.log().gather(dim, index)
+
+    return recentered_scores - log_normalizing_constants
+
+
+def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
     r"""
     Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
@@ -54,7 +70,8 @@ def scatter_softmax(src, index, dim=-1, dim_size=None):
     For one-dimensional tensors, the operation computes
 
     .. math::
-        \mathrm{out}_i = softmax(\mathrm{src}_i) = \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)}
+        \mathrm{out}_i = \mathrm{softmax}(\mathrm{src}_i) =
+        \exp\big(\mathrm{src}_i - \mathrm{logsumexp}_j(\mathrm{src}_j)\big)
 
     where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
     :math:`\mathrm{index}_j = i`.
@@ -82,4 +99,20 @@ def scatter_softmax(src, index, dim=-1, dim_size=None):
 
     :rtype: :class:`Tensor`
    """
-    return scatter_log_softmax(src, index, dim, dim_size).exp()
+    if not torch.is_floating_point(src):
+        raise ValueError('softmax can be computed only over tensors with floating point data types.')
+
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
+    max_per_src_element = max_value_per_index.gather(dim, index)
+
+    recentered_scores = src - max_per_src_element
+    exped_recentered_scores = recentered_scores.exp()
+
+    sum_per_index = scatter_add(
+        src=exped_recentered_scores,
+        index=index,
+        dim=dim,
+        dim_size=dim_size
+    )
+    normalizing_constant = (sum_per_index + epsilon).gather(dim, index)
+    return exped_recentered_scores / normalizing_constant
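The same max-shift / scatter-add recipe, written with plain PyTorch ops for readers without torch_scatter at hand. This is a sketch under the assumptions of 1-D inputs and PyTorch >= 1.12 (for `scatter_reduce`), not the library code itself:

    import torch

    def segment_softmax(src, index, num_segments):
        # Per-segment maximum; -inf init so empty segments stay -inf.
        seg_max = torch.full((num_segments,), float('-inf'), dtype=src.dtype)
        seg_max = seg_max.scatter_reduce(0, index, src, reduce='amax')
        # Shift by the segment max, exponentiate, accumulate the normaliser.
        shifted = (src - seg_max.gather(0, index)).exp()
        denom = torch.zeros(num_segments, dtype=src.dtype).scatter_add(0, index, shifted)
        return shifted / denom.gather(0, index)

    src = torch.tensor([0.25, 0.0, 0.25, -2.1, 3.2, 7.0])
    index = torch.tensor([0, 1, 0, 1, 1, 2])
    print(segment_softmax(src, index, num_segments=3))  # each segment sums to 1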

torch_scatter/logsumexp.py

Lines changed: 21 additions & 25 deletions
@@ -3,30 +3,6 @@
 from . import scatter_add, scatter_max
 
 
-def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
-    if not torch.is_floating_point(src):
-        raise ValueError('logsumexp can be computed over tensors floating point data types.')
-
-    if fill_value is None:
-        fill_value = torch.finfo(src.dtype).min
-
-    dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size
-    max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
-    max_per_src_element = max_value_per_index.gather(dim, index)
-
-    recentered_scores = src - max_per_src_element
-
-    sum_per_index = scatter_add(
-        src=recentered_scores.exp(),
-        index=index,
-        dim=dim,
-        out=(src - max_per_src_element).exp() if out is not None else None,
-        dim_size=dim_size,
-        fill_value=fill_value,
-    )
-    return torch.log(sum_per_index + epsilon) + max_value_per_index, recentered_scores
-
-
 def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
     r"""
     Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
@@ -63,4 +39,24 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No
 
     :rtype: :class:`Tensor`
     """
-    return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value, epsilon=epsilon)[0]
+    if not torch.is_floating_point(src):
+        raise ValueError('logsumexp can be computed only over tensors with floating point data types.')
+
+    if fill_value is None:
+        fill_value = torch.finfo(src.dtype).min
+
+    dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
+    max_per_src_element = max_value_per_index.gather(dim, index)
+
+    recentered_scores = src - max_per_src_element
+
+    sum_per_index = scatter_add(
+        src=recentered_scores.exp(),
+        index=index,
+        dim=dim,
+        out=(out - max_per_src_element).exp() if out is not None else None,
+        dim_size=dim_size,
+        fill_value=0,
+    )
+    return torch.log(sum_per_index + epsilon) + max_value_per_index
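A note on the empty-slot behaviour the new test asserts: an index with no source entries contributes nothing to the scatter_add, so the result is log(epsilon) added to the finfo-min fill value, and that offset (about -36.8) is far below the dtype's resolution near finfo-min, so the sum rounds back to exactly torch.finfo(dtype).min. A small check of that rounding argument (illustrative, not part of the commit):

    import math
    import torch

    for dtype in (torch.float32, torch.float64):
        fill = torch.finfo(dtype).min
        # log(1e-16) ~ -36.8 is negligible next to fill, so the sum rounds back to fill.
        out = torch.tensor(fill, dtype=dtype) + math.log(1e-16)
        assert out.item() == fill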
