44
55import torch
66
7+ from .utils import broadcast
8+
79try :
810 torch .ops .load_library (
911 osp .join (osp .dirname (osp .abspath (__file__ )), '_scatter.so' ))
@@ -23,7 +25,6 @@ def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
2325 raise ImportError
2426 return src , index
2527
26- torch .ops .torch_scatter .scatter_sum = scatter_placeholder
2728 torch .ops .torch_scatter .scatter_mean = scatter_placeholder
2829 torch .ops .torch_scatter .scatter_min = scatter_with_arg_placeholder
2930 torch .ops .torch_scatter .scatter_max = scatter_with_arg_placeholder
@@ -33,14 +34,24 @@ def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
3334def scatter_sum (src : torch .Tensor , index : torch .Tensor , dim : int = - 1 ,
3435 out : Optional [torch .Tensor ] = None ,
3536 dim_size : Optional [int ] = None ) -> torch .Tensor :
36- return torch .ops .torch_scatter .scatter_sum (src , index , dim , out , dim_size )
37+ index = broadcast (index , src , dim )
38+ if out is None :
39+ size = src .size ()
40+ if dim_size is None :
41+ size [dim ] = int (index .max ()) + 1
42+ else :
43+ size [dim ] = dim_size
44+ out = src .new_zeros (size )
45+ return out .scatter_add_ (dim , index , src )
46+ else :
47+ return out .scatter_add_ (dim , index , src )
3748
3849
3950@torch .jit .script
4051def scatter_add (src : torch .Tensor , index : torch .Tensor , dim : int = - 1 ,
4152 out : Optional [torch .Tensor ] = None ,
4253 dim_size : Optional [int ] = None ) -> torch .Tensor :
43- return torch . ops . torch_scatter . scatter_sum (src , index , dim , out , dim_size )
54+ return scatter_sum (src , index , dim , out , dim_size )
4455
4556
4657@torch .jit .script
0 commit comments