Skip to content

Commit dd50d35

Browse files
committed
A first round of implementation of scatter_logsumexp/softmax/logsoftmax ops.
1 parent 78a5549 commit dd50d35

File tree

5 files changed

+166
-0
lines changed

5 files changed

+166
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Scatter LogSumExp
2+
===========
3+
4+
.. automodule:: torch_scatter
5+
:noindex:
6+
7+
.. autofunction:: scatter_logsumexp

torch_scatter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .std import scatter_std
77
from .max import scatter_max
88
from .min import scatter_min
9+
from .logsumexp import scatter_logsumexp
910

1011
__version__ = '1.3.2'
1112

@@ -18,5 +19,6 @@
1819
'scatter_std',
1920
'scatter_max',
2021
'scatter_min',
22+
'scatter_logsumexp',
2123
'__version__',
2224
]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .softmax import scatter_log_softmax, scatter_softmax
2+
3+
__all__ = [
4+
'scatter_softmax',
5+
'scatter_log_softmax'
6+
]

torch_scatter/composite/softmax.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import torch
2+
3+
from torch_scatter.logsumexp import _scatter_logsumexp
4+
5+
def scatter_log_softmax(src, index, dim=-1, dim_size=None):
6+
r"""
7+
Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
8+
indices specified in the :attr:`index` tensor along a given axis
9+
:attr:`dim`.If multiple indices reference the same location, their
10+
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
11+
12+
For one-dimensional tensors, the operation computes
13+
14+
.. math::
15+
\mathrm{out}_i = softmax(\mathrm{src}_i) = \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)
16+
17+
where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
18+
:math:`\mathrm{index}_j = i`.
19+
20+
Compute a numerically safe log softmax operation
21+
from the :attr:`src` tensor into :attr:`out` at the indices
22+
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
23+
each value in :attr:`src`, its output index is specified by its index in
24+
:attr:`input` for dimensions outside of :attr:`dim` and by the
25+
corresponding value in :attr:`index` for dimension :attr:`dim`.
26+
27+
Args:
28+
src (Tensor): The source tensor.
29+
index (LongTensor): The indices of elements to scatter.
30+
dim (int, optional): The axis along which to index.
31+
(default: :obj:`-1`)
32+
dim_size (int, optional): If :attr:`out` is not given, automatically
33+
create output with size :attr:`dim_size` at dimension :attr:`dim`.
34+
If :attr:`dim_size` is not given, a minimal sized output tensor is
35+
returned. (default: :obj:`None`)
36+
fill_value (int, optional): If :attr:`out` is not given, automatically
37+
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
38+
the output tensor is filled with the smallest possible value of
39+
:obj:`src.dtype`. (default: :obj:`None`)
40+
41+
:rtype: :class:`Tensor`
42+
"""
43+
per_index_logsumexp, recentered_src = _scatter_logsumexp(src, index, dim=dim, dim_size=dim_size)
44+
return recentered_src - per_index_logsumexp.gather(dim, index)
45+
46+
47+
def scatter_softmax(src, index, dim=-1, dim_size=None):
48+
r"""
49+
Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
50+
indices specified in the :attr:`index` tensor along a given axis
51+
:attr:`dim`. If multiple indices reference the same location, their
52+
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
53+
54+
For one-dimensional tensors, the operation computes
55+
56+
.. math::
57+
\mathrm{out}_i = softmax(\mathrm{src}_i) = \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)}
58+
59+
where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
60+
:math:`\mathrm{index}_j = i`.
61+
62+
Compute a numerically safe softmax operation
63+
from the :attr:`src` tensor into :attr:`out` at the indices
64+
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
65+
each value in :attr:`src`, its output index is specified by its index in
66+
:attr:`input` for dimensions outside of :attr:`dim` and by the
67+
corresponding value in :attr:`index` for dimension :attr:`dim`.
68+
69+
Args:
70+
src (Tensor): The source tensor.
71+
index (LongTensor): The indices of elements to scatter.
72+
dim (int, optional): The axis along which to index.
73+
(default: :obj:`-1`)
74+
dim_size (int, optional): If :attr:`out` is not given, automatically
75+
create output with size :attr:`dim_size` at dimension :attr:`dim`.
76+
If :attr:`dim_size` is not given, a minimal sized output tensor is
77+
returned. (default: :obj:`None`)
78+
fill_value (int, optional): If :attr:`out` is not given, automatically
79+
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
80+
the output tensor is filled with the smallest possible value of
81+
:obj:`src.dtype`. (default: :obj:`None`)
82+
83+
:rtype: :class:`Tensor`
84+
"""
85+
return scatter_log_softmax(src, index, dim, dim_size).exp()

torch_scatter/logsumexp.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import torch
2+
3+
from . import scatter_add, scatter_max
4+
5+
EPSILON = 1e-16
6+
7+
def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
8+
if not torch.is_floating_point(src):
9+
raise ValueError('logsumexp can be computed over tensors floating point data types.')
10+
11+
if fill_value is None:
12+
fill_value = torch.finfo(src.dtype).min
13+
14+
dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size
15+
max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
16+
max_per_src_element = max_value_per_index.gather(dim, index)
17+
18+
recentered_scores = src - max_per_src_element
19+
20+
sum_per_index = scatter_add(
21+
src=recentered_scores.exp(),
22+
index=index,
23+
dim=dim,
24+
out=(src - max_per_src_element).exp() if out is not None else None,
25+
dim_size=dim_size,
26+
fill_value=fill_value,
27+
)
28+
return torch.log(sum_per_index + EPSILON) + max_value_per_index, recentered_scores
29+
30+
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
31+
r"""
32+
Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
33+
indices specified in the :attr:`index` tensor along a given axis
34+
:attr:`dim`. If multiple indices reference the same location, their
35+
**contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`).
36+
37+
For one-dimensional tensors, the operation computes
38+
39+
.. math::
40+
\mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + \sum_j \exp(\mathrm{src}_j) \right)
41+
42+
Compute a numerically safe logsumexp operation
43+
from the :attr:`src` tensor into :attr:`out` at the indices
44+
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
45+
each value in :attr:`src`, its output index is specified by its index in
46+
:attr:`input` for dimensions outside of :attr:`dim` and by the
47+
corresponding value in :attr:`index` for dimension :attr:`dim`.
48+
49+
Args:
50+
src (Tensor): The source tensor.
51+
index (LongTensor): The indices of elements to scatter.
52+
dim (int, optional): The axis along which to index.
53+
(default: :obj:`-1`)
54+
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
55+
dim_size (int, optional): If :attr:`out` is not given, automatically
56+
create output with size :attr:`dim_size` at dimension :attr:`dim`.
57+
If :attr:`dim_size` is not given, a minimal sized output tensor is
58+
returned. (default: :obj:`None`)
59+
fill_value (int, optional): If :attr:`out` is not given, automatically
60+
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
61+
the output tensor is filled with the smallest possible value of
62+
:obj:`src.dtype`. (default: :obj:`None`)
63+
64+
:rtype: :class:`Tensor`
65+
"""
66+
return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value)[0]

0 commit comments

Comments
 (0)