Skip to content

Commit 28fce4c

Browse files
committed
use in-place exp() + remove eps
1 parent 6a1525b commit 28fce4c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torch_scatter/composite/softmax.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from torch_scatter.utils import broadcast
55

66

7-
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
8-
eps: float = 1e-12) -> torch.Tensor:
7+
def scatter_softmax(src: torch.Tensor, index: torch.Tensor,
8+
dim: int = -1) -> torch.Tensor:
99
if not torch.is_floating_point(src):
1010
raise ValueError('`scatter_softmax` can only be computed over tensors '
1111
'with floating point data types.')
@@ -16,10 +16,10 @@ def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
1616
max_per_src_element = max_value_per_index.gather(dim, index)
1717

1818
recentered_scores = src - max_per_src_element
19-
recentered_scores_exp = recentered_scores.exp()
19+
recentered_scores_exp = recentered_scores.exp_()
2020

2121
sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
22-
normalizing_constants = sum_per_index.add_(eps).gather(dim, index)
22+
normalizing_constants = sum_per_index.gather(dim, index)
2323

2424
return recentered_scores_exp.div(normalizing_constants)
2525

0 commit comments

Comments
 (0)