|
| 1 | +from itertools import product |
| 2 | + |
1 | 3 | import pytest |
2 | 4 | import torch |
3 | | -from torch_scatter import scatter_add |
| 5 | +from torch_scatter import scatter |
4 | 6 |
|
5 | | -from .utils import devices |
| 7 | +from .utils import reductions, devices |
6 | 8 |
|
7 | 9 |
|
8 | | -@pytest.mark.parametrize('device', devices) |
9 | | -def test_broadcasting(device): |
| 10 | +@pytest.mark.parametrize('reduce,device', product(reductions, devices)) |
| 11 | +def test_broadcasting(reduce, device): |
10 | 12 | B, C, H, W = (4, 3, 8, 8) |
11 | 13 |
|
| 14 | + src = torch.randn((B, C, H, W), device=device) |
| 15 | + index = torch.randint(0, H, (H, )).to(device, torch.long) |
| 16 | + out = scatter(src, index, dim=2, dim_size=H, reduce=reduce) |
| 17 | + assert out.size() == (B, C, H, W) |
| 18 | + |
12 | 19 | src = torch.randn((B, C, H, W), device=device) |
13 | 20 | index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long) |
14 | | - out = scatter_add(src, index, dim=2, dim_size=H) |
| 21 | + out = scatter(src, index, dim=2, dim_size=H, reduce=reduce) |
15 | 22 | assert out.size() == (B, C, H, W) |
16 | 23 |
|
17 | 24 | src = torch.randn((B, C, H, W), device=device) |
18 | 25 | index = torch.randint(0, H, (H, )).to(device, torch.long) |
19 | | - out = scatter_add(src, index, dim=2, dim_size=H) |
| 26 | + out = scatter(src, index, dim=2, dim_size=H, reduce=reduce) |
20 | 27 | assert out.size() == (B, C, H, W) |
0 commit comments