Skip to content

Commit 8ec6d0c

Browse files
committed
convert size to a list
1 parent 6a2cc2e commit 8ec6d0c

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

torch_scatter/scatter.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
1111
dim_size: Optional[int] = None) -> torch.Tensor:
1212
index = broadcast(index, src, dim)
1313
if out is None:
14-
size = src.size()
14+
size = list(src.size())
1515
if dim_size is not None:
1616
size[dim] = dim_size
1717
elif index.numel() == 0:
@@ -57,18 +57,18 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
5757

5858

5959
@torch.jit.script
60-
def scatter_min(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
61-
out: Optional[torch.Tensor] = None,
62-
dim_size: Optional[int] = None
63-
) -> Tuple[torch.Tensor, torch.Tensor]:
60+
def scatter_min(
61+
src: torch.Tensor, index: torch.Tensor, dim: int = -1,
62+
out: Optional[torch.Tensor] = None,
63+
dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
6464
return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
6565

6666

6767
@torch.jit.script
68-
def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
69-
out: Optional[torch.Tensor] = None,
70-
dim_size: Optional[int] = None
71-
) -> Tuple[torch.Tensor, torch.Tensor]:
68+
def scatter_max(
69+
src: torch.Tensor, index: torch.Tensor, dim: int = -1,
70+
out: Optional[torch.Tensor] = None,
71+
dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
7272
return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)
7373

7474

0 commit comments

Comments
 (0)