We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c38e20a commit a2a85feCopy full SHA for a2a85fe
torch_scatter/composite/logsumexp.py
@@ -37,4 +37,5 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
37
sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
38
dim_size)
39
40
- return sum_per_index.add_(eps).log_().add_(max_value_per_index)
+ out = sum_per_index.add_(eps).log_().add_(max_value_per_index)
41
+ return out.nan_to_num_(neginf=0.0)
0 commit comments