@@ -11,7 +11,7 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
1111 dim_size : Optional [int ] = None ) -> torch .Tensor :
1212 index = broadcast (index , src , dim )
1313 if out is None :
14- size = src .size ()
14+ size = list ( src .size () )
1515 if dim_size is not None :
1616 size [dim ] = dim_size
1717 elif index .numel () == 0 :
@@ -57,18 +57,18 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
5757
5858
5959@torch .jit .script
60- def scatter_min (src : torch . Tensor , index : torch . Tensor , dim : int = - 1 ,
61- out : Optional [ torch .Tensor ] = None ,
62- dim_size : Optional [int ] = None
63- ) -> Tuple [torch .Tensor , torch .Tensor ]:
60+ def scatter_min (
61+ src : torch . Tensor , index : torch .Tensor , dim : int = - 1 ,
62+ out : Optional [torch . Tensor ] = None ,
63+ dim_size : Optional [ int ] = None ) -> Tuple [torch .Tensor , torch .Tensor ]:
6464 return torch .ops .torch_scatter .scatter_min (src , index , dim , out , dim_size )
6565
6666
6767@torch .jit .script
68- def scatter_max (src : torch . Tensor , index : torch . Tensor , dim : int = - 1 ,
69- out : Optional [ torch .Tensor ] = None ,
70- dim_size : Optional [int ] = None
71- ) -> Tuple [torch .Tensor , torch .Tensor ]:
68+ def scatter_max (
69+ src : torch . Tensor , index : torch .Tensor , dim : int = - 1 ,
70+ out : Optional [torch . Tensor ] = None ,
71+ dim_size : Optional [ int ] = None ) -> Tuple [torch .Tensor , torch .Tensor ]:
7272 return torch .ops .torch_scatter .scatter_max (src , index , dim , out , dim_size )
7373
7474
0 commit comments