Commit d80cd24

broadcasting capabilities
1 parent d847087 commit d80cd24

5 files changed: +48 -2 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ The package consists of the following operations:
 * [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
 * [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
 
-All included operations work on varying data types, are implemented both for CPU and GPU and include a backwards implementation.
+All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
 
 ## Installation
 
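
A minimal usage sketch of the broadcasting behaviour described above, mirroring the first case in test/test_broadcasting.py from this commit (the concrete sizes are illustrative only):

import torch
from torch_scatter import scatter_add

src = torch.randn(4, 3, 8, 8)                                 # (B, C, H, W)
index = torch.randint(0, 8, (4, 1, 8, 8), dtype=torch.long)   # singleton channel dim

# index is broadcast against src before the reduction along dim 2.
out = scatter_add(src, index, dim=2, dim_size=8)
print(out.size())  # torch.Size([4, 3, 8, 8])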

docs/source/index.rst

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ PyTorch Scatter Documentation
 This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package.
 Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
 
-All included operations work on varying data types, are implemented both for CPU and GPU and include a backwards implementation.
+All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
 
 .. toctree::
    :glob:

test/test_broadcasting.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import pytest
+import torch
+from torch_scatter import scatter_add
+
+from .utils import devices
+
+
+@pytest.mark.parametrize('device', devices)
+def test_broadcasting(device):
+    B, C, H, W = (4, 3, 8, 8)
+
+    src = torch.randn((B, C, H, W), device=device)
+    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
+    out = scatter_add(src, index, dim=2, dim_size=H)
+    assert out.size() == (B, C, H, W)
+
+    src = torch.randn((B, 1, H, W), device=device)
+    index = torch.randint(0, H, (B, C, H, W)).to(device, torch.long)
+    out = scatter_add(src, index, dim=2, dim_size=H)
+    assert out.size() == (B, C, H, W)
+
+    src = torch.randn((B, 1, H, W), device=device)
+    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
+    out = scatter_add(src, index, dim=2, dim_size=H)
+    assert out.size() == (B, 1, H, W)
+
+    src = torch.randn((B, C, H, W), device=device)
+    index = torch.randint(0, H, (H, )).to(device, torch.long)
+    out = scatter_add(src, index, dim=2, dim_size=H)
+    assert out.size() == (B, C, H, W)

torch_scatter/add.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
     :attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
     size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
     values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
+    Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions
+    do not match.
 
     For one-dimensional tensors, the operation computes

torch_scatter/utils/gen.py

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,5 @@
+from __future__ import division
+
 from itertools import repeat
 
 
@@ -16,6 +18,18 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
         index_size[dim] = src.size(dim)
         index = index.view(index_size).expand_as(src)
 
+    # Broadcasting capabilities: Expand dimensions to match.
+    if src.dim() != index.dim():
+        raise ValueError(
+            ('Number of dimensions of src and index tensor do not match, '
+             'got {} and {}').format(src.dim(), index.dim()))
+
+    expand_size = []
+    for s, i in zip(src.size(), index.size()):
+        expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
+    src = src.expand(expand_size)
+    index = index.expand_as(src)
+
     # Generate output tensor if not given.
     if out is None:
         out_size = list(src.size())
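
A small standalone sketch (not part of the commit) of the expand rule added to gen() above: dimensions that already match are kept via -1, while singleton dimensions are expanded to the larger size; the tensor shapes here are illustrative only.

import torch

src = torch.randn(4, 1, 8, 8)
index = torch.randint(0, 8, (4, 3, 8, 8), dtype=torch.long)

expand_size = []
for s, i in zip(src.size(), index.size()):
    expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]

print(expand_size)                     # [-1, 3, -1, -1]
print(src.expand(expand_size).size())  # torch.Size([4, 3, 8, 8])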
