Skip to content

Commit 2147841

Browse files
committed
dim size fix for index.numel() == 0
1 parent 50e4214 commit 2147841

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

torch_scatter/utils/gen.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
from itertools import repeat
22

33

4+
def maybe_dim_size(index, dim_size=None):
5+
if dim_size is not None:
6+
return dim_size
7+
return index.max().item() + 1 if index.numel() > 0 else 0
8+
9+
410
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
511
dim = range(src.dim())[dim] # Get real dim value.
612

@@ -12,8 +18,8 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
1218

1319
# Generate output tensor if not given.
1420
if out is None:
15-
dim_size = index.max().item() + 1 if dim_size is None else dim_size
1621
out_size = list(src.size())
22+
dim_size = maybe_dim_size(index, dim_size)
1723
out_size[dim] = dim_size
1824
out = src.new_full(out_size, fill_value)
1925

0 commit comments

Comments
 (0)