1212except OSError :
1313 warnings .warn ('Failed to load `scatter` binaries.' )
1414
15- def scatter_placeholder (src : torch .Tensor , index : torch .Tensor , dim : int ,
16- out : Optional [torch .Tensor ],
17- dim_size : Optional [int ]) -> torch .Tensor :
18- raise ImportError
19- return src
20-
2115 def scatter_with_arg_placeholder (src : torch .Tensor , index : torch .Tensor ,
2216 dim : int , out : Optional [torch .Tensor ],
2317 dim_size : Optional [int ]
2418 ) -> Tuple [torch .Tensor , torch .Tensor ]:
2519 raise ImportError
2620 return src , index
2721
28- torch .ops .torch_scatter .scatter_mean = scatter_placeholder
2922 torch .ops .torch_scatter .scatter_min = scatter_with_arg_placeholder
3023 torch .ops .torch_scatter .scatter_max = scatter_with_arg_placeholder
3124
@@ -37,11 +30,13 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
3730 index = broadcast (index , src , dim )
3831 if out is None :
3932 size = src .size ()
40- if dim_size is None :
41- size [dim ] = int (index .max ()) + 1
42- else :
33+ if dim_size is not None :
4334 size [dim ] = dim_size
44- out = src .new_zeros (size )
35+ elif index .numel () == 0 :
36+ size [dim ] = 0
37+ else :
38+ size [dim ] = int (index .max ()) + 1
39+ out = torch .zeros (size , dtype = src .dtype , device = src .device )
4540 return out .scatter_add_ (dim , index , src )
4641 else :
4742 return out .scatter_add_ (dim , index , src )
@@ -58,7 +53,22 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
5853def scatter_mean (src : torch .Tensor , index : torch .Tensor , dim : int = - 1 ,
5954 out : Optional [torch .Tensor ] = None ,
6055 dim_size : Optional [int ] = None ) -> torch .Tensor :
61- return torch .ops .torch_scatter .scatter_mean (src , index , dim , out , dim_size )
56+
57+ out = scatter_sum (src , index , dim , out , dim_size )
58+ dim_size = out .size (dim )
59+
60+ index_dim = dim
61+ if index_dim < 0 :
62+ index_dim = index_dim + src .dim ()
63+ if index .dim () <= dim :
64+ index_dim = index .dim () - 1
65+
66+ ones = torch .ones (index .size (), dtype = src .dtype , device = src .device )
67+ count = scatter_sum (ones , index , index_dim , None , dim_size )
68+ count .clamp_ (1 )
69+ count = broadcast (count , out , dim )
70+ out .div_ (count )
71+ return out
6272
6373
6474@torch .jit .script
0 commit comments