Skip to content

Commit a2a85fe

Browse files
authored
update (#369)
1 parent c38e20a commit a2a85fe

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torch_scatter/composite/logsumexp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
3737
sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
3838
dim_size)
3939

40-
return sum_per_index.add_(eps).log_().add_(max_value_per_index)
40+
out = sum_per_index.add_(eps).log_().add_(max_value_per_index)
41+
return out.nan_to_num_(neginf=0.0)

0 commit comments

Comments
 (0)