
Commit 0c12788

committed
Address most flake8, pycodestyle errors.
1 parent 7b14c67 commit 0c12788

File tree: 4 files changed (+35, -16 lines)

test/test_logsumexp.py (2 additions, 1 deletion)

@@ -2,12 +2,13 @@
 
 import torch
 import pytest
-from torch_scatter import scatter_max, scatter_logsumexp
+from torch_scatter import scatter_logsumexp
 
 from .utils import devices, tensor
 
 SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
 
+
 @pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
 def test_logsumexp(dtype, device):
     src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
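
The float('-inf') entry in src is the interesting case here: a naive log(sum(exp(x))) overflows for large values and misbehaves at the extremes, which is why the implementation recenters by the per-group maximum. A minimal sketch of that max-shift identity in plain PyTorch (illustrative, not part of this commit; safe_logsumexp is a hypothetical helper name):

import torch

def safe_logsumexp(x):
    # logsumexp(x) = max(x) + log(sum(exp(x - max(x)))); shifting by the
    # maximum keeps exp() from overflowing for large inputs.
    m = x.max()
    return m + torch.log(torch.exp(x - m).sum())

x = torch.tensor([0.5, 0.0, 0.5, -2.1, 3.2, 7.0, -1.0])
print(safe_logsumexp(x))          # matches torch.logsumexp(x, dim=0)
print(torch.logsumexp(x, dim=0))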

test/test_softmax.py (2 additions, 1 deletion)

@@ -9,6 +9,7 @@
 
 SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
 
+
 @pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
 def test_log_softmax(dtype, device):
     src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
@@ -48,4 +49,4 @@ def test_softmax(dtype, device):
         out.tolist(),
         [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
         rtol=1e-05, atol=1e-10
-    )
\ No newline at end of file
+    )
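
A property these softmax tests lean on is that, after exponentiation, the weights within each index group sum to one. A hedged sketch of checking that with the module this commit touches (illustrative values, not the test's actual fixtures):

import torch
from torch_scatter import scatter_add
from torch_scatter.composite.softmax import scatter_softmax

src = torch.tensor([0.25, 0.0, 0.25, -2.1, 3.2, 7.0])
index = torch.tensor([0, 1, 0, 1, 1, 2])

out = scatter_softmax(src, index, dim=-1)
# Summing the weights back per group should give ~1.0 for each index
# (up to the small epsilon the implementation adds for stability).
print(scatter_add(out, index, dim=-1))  # ≈ tensor([1., 1., 1.])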

torch_scatter/composite/softmax.py (17 additions, 8 deletions)

@@ -2,17 +2,19 @@
 
 from torch_scatter import scatter_add, scatter_max
 
+
 def scatter_log_softmax(src, index, dim=-1, dim_size=None):
     r"""
-    Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
+    Numerical safe log-softmax of all values from
+    the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
     :attr:`dim`.If multiple indices reference the same location, their
     **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
 
     For one-dimensional tensors, the operation computes
 
     .. math::
-        \mathrm{out}_i = softmax(\mathrm{src}_i) = 
+        \mathrm{out}_i = softmax(\mathrm{src}_i) =
         \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)
 
     where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
@@ -42,9 +44,12 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('log_softmax can be computed only over tensors with floating point data types.')
+        raise ValueError('log_softmax can be computed only over '
+                         'tensors with floating point data types.')
 
-    max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
+    max_value_per_index, _ = scatter_max(src, index,
+                                         dim=dim,
+                                         dim_size=dim_size)
     max_per_src_element = max_value_per_index.gather(dim, index)
 
     recentered_scores = src - max_per_src_element
@@ -62,15 +67,16 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
 
 def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
     r"""
-    Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
+    Numerical safe log-softmax of all values from
+    the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
     :attr:`dim`. If multiple indices reference the same location, their
     **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
 
     For one-dimensional tensors, the operation computes
 
     .. math::
-        \mathrm{out}_i = softmax(\mathrm{src}_i) = 
+        \mathrm{out}_i = softmax(\mathrm{src}_i) =
         \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)}
 
     where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
@@ -100,9 +106,12 @@ def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('softmax can be computed only over tensors with floating point data types.')
+        raise ValueError('softmax can be computed only over '
+                         'tensors with floating point data types.')
 
-    max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
+    max_value_per_index, _ = scatter_max(src, index,
+                                         dim=dim,
+                                         dim_size=dim_size)
     max_per_src_element = max_value_per_index.gather(dim, index)
 
     recentered_scores = src - max_per_src_element
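
The reflowed scatter_max calls above are the first step of the standard max-shift evaluation: compute each group's maximum, gather it back per element, subtract, and only then exponentiate. A rough end-to-end sketch of that pipeline for the log-softmax case (a plain illustration under assumed 1-D inputs, not the library's exact code):

import torch
from torch_scatter import scatter_add, scatter_max

def log_softmax_sketch(src, index, dim=-1):
    # Per-group maxima, broadcast back onto each source element.
    max_value_per_index, _ = scatter_max(src, index, dim=dim)
    max_per_src_element = max_value_per_index.gather(dim, index)
    # Recentered scores keep exp() in a safe numeric range.
    recentered_scores = src - max_per_src_element
    # log-softmax_i = recentered_i - log(sum_j exp(recentered_j)) per group.
    sum_per_index = scatter_add(recentered_scores.exp(), index, dim=dim)
    log_normalizer = sum_per_index.log().gather(dim, index)
    return recentered_scores - log_normalizer

src = torch.tensor([0.25, 0.0, 0.25, -2.1, 3.2, 7.0])
index = torch.tensor([0, 1, 0, 1, 1, 2])
print(log_softmax_sketch(src, index))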

torch_scatter/logsumexp.py (14 additions, 6 deletions)

@@ -3,17 +3,20 @@
 from . import scatter_add, scatter_max
 
 
-def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
+def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
+                      fill_value=None, epsilon=1e-16):
     r"""
-    Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
+    Numerically safe logsumexp of all values from
+    the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
     :attr:`dim`. If multiple indices reference the same location, their
     **contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`).
 
     For one-dimensional tensors, the operation computes
 
     .. math::
-        \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + \sum_j \exp(\mathrm{src}_j) \right)
+        \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i)
+        + \sum_j \exp(\mathrm{src}_j) \right)
 
     Compute a numerically safe logsumexp operation
     from the :attr:`src` tensor into :attr:`out` at the indices
@@ -40,13 +43,18 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No
     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('logsumexp can be computed over tensors with floating point data types.')
+        raise ValueError('logsumexp can only be computed over '
+                         'tensors with floating point data types.')
 
     if fill_value is None:
         fill_value = torch.finfo(src.dtype).min
 
-    dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size
-    max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
+    dim_size = out.shape[dim] \
+        if dim_size is None and out is not None else dim_size
+
+    max_value_per_index, _ = scatter_max(src, index, dim=dim,
+                                         out=out, dim_size=dim_size,
+                                         fill_value=fill_value)
     max_per_src_element = max_value_per_index.gather(dim, index)
 
     recentered_scores = src - max_per_src_element
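
With the wrapped signature, a call that exercises the optional arguments might look like the following (hedged example; the values and the dim_size choice are illustrative):

import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0.0, 0.5, -2.1, 3.2, 7.0])
index = torch.tensor([0, 1, 0, 1, 1, 2])

# dim_size pads the output to four groups; group 3 receives no
# contributions and is left at fill_value, which the code above
# defaults to torch.finfo(src.dtype).min when None.
out = scatter_logsumexp(src, index, dim=-1, dim_size=4)
print(out.shape)  # torch.Size([4])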
