Skip to content

Commit 96aa2e3

Browse files
authored
Fix logsumexp when out is passed (#445)
* update * update * update
1 parent 521d26f commit 96aa2e3

File tree

2 files changed: 33 additions and 9 deletions.

test/composite/test_logsumexp.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,12 @@ def test_logsumexp():
2929

3030
jit = torch.jit.script(scatter_logsumexp)
3131
assert jit(inputs, index).tolist() == outputs.tolist()
32+
33+
34+
def test_logsumexp_out():
35+
src = torch.tensor([-1.0, -50.0])
36+
index = torch.tensor([0, 0])
37+
out = torch.tensor([-10.0, -10.0])
38+
39+
scatter_logsumexp(src=src, index=index, out=out)
40+
assert out.allclose(torch.tensor([-0.9999, -10.0]), atol=1e-4)
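torch_scatter/composite/logsumexp.py (file header missing from this extract; path inferred from the test file's location and the `torch_scatter` imports below)

Lines changed: 24 additions & 9 deletions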
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
from typing import Optional
22

33
import torch
4-
from torch_scatter import scatter_sum, scatter_max
5-
4+
from torch_scatter import scatter_max, scatter_sum
65
from torch_scatter.utils import broadcast
76

87

9-
def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
10-
out: Optional[torch.Tensor] = None,
11-
dim_size: Optional[int] = None,
12-
eps: float = 1e-12) -> torch.Tensor:
8+
def scatter_logsumexp(
9+
src: torch.Tensor,
10+
index: torch.Tensor,
11+
dim: int = -1,
12+
out: Optional[torch.Tensor] = None,
13+
dim_size: Optional[int] = None,
14+
eps: float = 1e-12,
15+
) -> torch.Tensor:
1316
if not torch.is_floating_point(src):
1417
raise ValueError('`scatter_logsumexp` can only be computed over '
1518
'tensors with floating point data types.')
@@ -24,18 +27,30 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
2427

2528
size = list(src.size())
2629
size[dim] = dim_size
27-
max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
28-
device=src.device)
30+
max_value_per_index = torch.full(
31+
size,
32+
fill_value=float('-inf'),
33+
dtype=src.dtype,
34+
device=src.device,
35+
)
2936
scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
3037
max_per_src_element = max_value_per_index.gather(dim, index)
3138
recentered_score = src - max_per_src_element
3239
recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))
3340

41+
orig_out: Optional[torch.Tensor] = None
3442
if out is not None:
43+
orig_out = out.clone()
3544
out = out.sub_(max_value_per_index).exp_()
3645

3746
sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
3847
dim_size)
3948

4049
out = sum_per_index.add_(eps).log_().add_(max_value_per_index)
41-
return out.nan_to_num_(neginf=0.0)
50+
51+
if orig_out is None:
52+
return out.nan_to_num_(neginf=0.0)
53+
54+
mask = ~out.isfinite()
55+
out[mask] = orig_out[mask]
56+
return out

0 commit comments

Comments
 (0)