Skip to content

Commit 481e81f

Browse files
committed
fix view on zero-element tensors
1 parent b98d1f3 commit 481e81f

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torch_scatter/utils/gen.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from itertools import repeat
22

3+
import torch
4+
35

46
def maybe_dim_size(index, dim_size=None):
57
if dim_size is not None:
@@ -14,7 +16,10 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
1416
if index.dim() == 1:
1517
index_size = list(repeat(1, src.dim()))
1618
index_size[dim] = src.size(dim)
17-
index = index.view(index_size).expand_as(src)
19+
if index.numel() > 0:
20+
index = index.view(index_size).expand_as(src)
21+
else: # PyTorch has a bug when view is used on zero-element tensors.
22+
index = src.new_empty(index_size, dtype=torch.long)
1823

1924
# Generate output tensor if not given.
2025
if out is None:

0 commit comments

Comments
 (0)