11from typing import Optional
22
33import torch
4- from torch_scatter import scatter_sum , scatter_max
5-
4+ from torch_scatter import scatter_max , scatter_sum
65from torch_scatter .utils import broadcast
76
87
9- def scatter_logsumexp (src : torch .Tensor , index : torch .Tensor , dim : int = - 1 ,
10- out : Optional [torch .Tensor ] = None ,
11- dim_size : Optional [int ] = None ,
12- eps : float = 1e-12 ) -> torch .Tensor :
8+ def scatter_logsumexp (
9+ src : torch .Tensor ,
10+ index : torch .Tensor ,
11+ dim : int = - 1 ,
12+ out : Optional [torch .Tensor ] = None ,
13+ dim_size : Optional [int ] = None ,
14+ eps : float = 1e-12 ,
15+ ) -> torch .Tensor :
1316 if not torch .is_floating_point (src ):
1417 raise ValueError ('`scatter_logsumexp` can only be computed over '
1518 'tensors with floating point data types.' )
@@ -24,18 +27,30 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
2427
2528 size = list (src .size ())
2629 size [dim ] = dim_size
27- max_value_per_index = torch .full (size , float ('-inf' ), dtype = src .dtype ,
28- device = src .device )
30+ max_value_per_index = torch .full (
31+ size ,
32+ fill_value = float ('-inf' ),
33+ dtype = src .dtype ,
34+ device = src .device ,
35+ )
2936 scatter_max (src , index , dim , max_value_per_index , dim_size = dim_size )[0 ]
3037 max_per_src_element = max_value_per_index .gather (dim , index )
3138 recentered_score = src - max_per_src_element
3239 recentered_score .masked_fill_ (torch .isnan (recentered_score ), float ('-inf' ))
3340
41+ orig_out : Optional [torch .Tensor ] = None
3442 if out is not None :
43+ orig_out = out .clone ()
3544 out = out .sub_ (max_value_per_index ).exp_ ()
3645
3746 sum_per_index = scatter_sum (recentered_score .exp_ (), index , dim , out ,
3847 dim_size )
3948
4049 out = sum_per_index .add_ (eps ).log_ ().add_ (max_value_per_index )
41- return out .nan_to_num_ (neginf = 0.0 )
50+
51+ if orig_out is None :
52+ return out .nan_to_num_ (neginf = 0.0 )
53+
54+ mask = ~ out .isfinite ()
55+ out [mask ] = orig_out [mask ]
56+ return out
0 commit comments