Skip to content

Commit 9c7af8d

Browse files
committed
Move epsilon to an argument.
1 parent dd50d35 commit 9c7af8d

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torch_scatter/logsumexp.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
from . import scatter_add, scatter_max
44

5-
EPSILON = 1e-16
65

7-
def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
6+
def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
87
if not torch.is_floating_point(src):
98
raise ValueError('logsumexp can be computed over tensors floating point data types.')
109

@@ -25,9 +24,10 @@ def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=N
2524
dim_size=dim_size,
2625
fill_value=fill_value,
2726
)
28-
return torch.log(sum_per_index + EPSILON) + max_value_per_index, recentered_scores
27+
return torch.log(sum_per_index + epsilon) + max_value_per_index, recentered_scores
2928

30-
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
29+
30+
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
3131
r"""
3232
Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
3333
indices specified in the :attr:`index` tensor along a given axis
@@ -63,4 +63,4 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No
6363
6464
:rtype: :class:`Tensor`
6565
"""
66-
return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value)[0]
66+
return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value, epsilon=epsilon)[0]

0 commit comments

Comments
 (0)