
Commit 7d5034b

Merge commit: 2 parents b5c8953 + efac08d

6 files changed, 55 insertions(+), 3 deletions(-)


.coveragerc

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ exclude_lines =
     pragma: no cover
     cuda
     backward
+    raise

README.md

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ The package consists of the following operations:
 * [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
 * [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
 
-All included operations work on varying data types, are implemented both for CPU and GPU and include a backwards implementation.
+All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
 
 ## Installation
 
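For readers skimming the diff, here is a minimal sketch of what the new broadcasting claim means in practice, mirroring the first case in test/test_broadcasting.py below (shapes are illustrative only):

```python
import torch
from torch_scatter import scatter_add

# index has a singleton channel dimension and is broadcast
# against src before the scatter is applied along dim=2.
src = torch.randn(4, 3, 8, 8)                 # (B, C, H, W)
index = torch.randint(0, 8, (4, 1, 8, 8))     # C-dimension is 1
out = scatter_add(src, index, dim=2, dim_size=8)
print(out.size())  # torch.Size([4, 3, 8, 8])
```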

docs/source/index.rst

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ PyTorch Scatter Documentation
 This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package.
 Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
 
-All included operations work on varying data types, are implemented both for CPU and GPU and include a backwards implementation.
+All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
 
 .. toctree::
    :glob:

test/test_broadcasting.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import pytest
+import torch
+from torch_scatter import scatter_add
+
+from .utils import devices
+
+
+@pytest.mark.parametrize('device', devices)
+def test_broadcasting(device):
+    B, C, H, W = (4, 3, 8, 8)
+
+    src = torch.randn((B, C, H, W), device=device)
+    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
+    out = scatter_add(src, index, dim=2, dim_size=H)
+    assert out.size() == (B, C, H, W)
+
+    src = torch.randn((B, 1, H, W), device=device)
+    index = torch.randint(0, H, (B, C, H, W)).to(device, torch.long)
+    out = scatter_add(src, index, dim=2, dim_size=H)
+    assert out.size() == (B, C, H, W)
+
+    src = torch.randn((B, 1, H, W), device=device)
+    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
+    out = scatter_add(src, index, dim=2, dim_size=H)
+    assert out.size() == (B, 1, H, W)
+
+    src = torch.randn((B, C, H, W), device=device)
+    index = torch.randint(0, H, (H, )).to(device, torch.long)
+    out = scatter_add(src, index, dim=2, dim_size=H)
+    assert out.size() == (B, C, H, W)

torch_scatter/add.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
     :attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
     size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
     values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
+    Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions
+    do not match.
 
     For one-dimensional tensors, the operation computes
 
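The docstring addition above also covers the one-dimensional case: a 1-D index is viewed along `dim` and expanded to `src`'s shape before scattering. A quick sketch, mirroring the last case in test/test_broadcasting.py:

```python
import torch
from torch_scatter import scatter_add

# A 1-D index is expanded along dim=2 to match src's full shape.
src = torch.randn(4, 3, 8, 8)
index = torch.randint(0, 8, (8,))
out = scatter_add(src, index, dim=2, dim_size=8)
assert out.size() == (4, 3, 8, 8)
```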

torch_scatter/utils/gen.py

Lines changed: 20 additions & 1 deletion
@@ -1,5 +1,9 @@
+from __future__ import division
+
 from itertools import repeat
 
+import torch
+
 
 def maybe_dim_size(index, dim_size=None):
     if dim_size is not None:
@@ -14,7 +18,22 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
     if index.dim() == 1:
         index_size = list(repeat(1, src.dim()))
         index_size[dim] = src.size(dim)
-        index = index.view(index_size).expand_as(src)
+        if index.numel() > 0:
+            index = index.view(index_size).expand_as(src)
+        else:  # PyTorch has a bug when view is used on zero-element tensors.
+            index = src.new_empty(index_size, dtype=torch.long)
+
+    # Broadcasting capabilities: Expand dimensions to match.
+    if src.dim() != index.dim():
+        raise ValueError(
+            ('Number of dimensions of src and index tensor do not match, '
+             'got {} and {}').format(src.dim(), index.dim()))
+
+    expand_size = []
+    for s, i in zip(src.size(), index.size()):
+        expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
+    src = src.expand(expand_size)
+    index = index.expand_as(src)
 
     # Generate output tensor if not given.
     if out is None:
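The expand-size rule added here is compact, so a standalone sketch may help (the helper name `broadcast_sizes` is hypothetical, introduced only for illustration): for each dimension pair, a dimension is kept as-is (`-1`) when both sizes already match and neither is 1; otherwise the singleton side is expanded to the larger size.

```python
# Hypothetical helper illustrating the expand_size rule from gen.py.
# Assumes src and index already have the same number of dimensions,
# which the ValueError check above enforces.
def broadcast_sizes(src_size, index_size):
    return [-1 if s == i and s != 1 and i != 1 else max(i, s)
            for s, i in zip(src_size, index_size)]

print(broadcast_sizes((4, 3, 8, 8), (4, 1, 8, 8)))  # [-1, 3, -1, -1]
print(broadcast_sizes((4, 1, 8, 8), (4, 3, 8, 8)))  # [-1, 3, -1, -1]
```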
