Skip to content

Commit 7fd9091

Browse files
committed
update code and tests
1 parent 5be6d63 commit 7fd9091

37 files changed

+438
-1575
lines changed

.coveragerc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,5 @@ source=torch_scatter
33
[report]
44
exclude_lines =
55
pragma: no cover
6-
cuda
7-
forward
8-
backward
9-
apply
6+
torch.jit.script
107
raise
11-
min_value
12-
max_value

test/composite/test_logsumexp.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import torch
2+
from torch_scatter import scatter_logsumexp
3+
4+
5+
def test_logsumexp():
6+
src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100])
7+
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
8+
9+
out = scatter_logsumexp(src, index)
10+
11+
out0 = torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1)
12+
out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2]), dim=-1)
13+
out2 = torch.logsumexp(torch.tensor(7, dtype=torch.float), dim=-1)
14+
out3 = torch.logsumexp(torch.tensor([], dtype=torch.float), dim=-1)
15+
out4 = torch.tensor(-1, dtype=torch.float)
16+
17+
expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
18+
assert torch.allclose(out, expected)

test/composite/test_softmax.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,38 @@
1-
from itertools import product
2-
3-
import pytest
41
import torch
5-
from torch_scatter.composite import scatter_log_softmax, scatter_softmax
6-
7-
from test.utils import devices, tensor, grad_dtypes
2+
from torch_scatter import scatter_log_softmax, scatter_softmax
83

94

10-
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
11-
def test_softmax(dtype, device):
12-
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
13-
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
5+
def test_softmax():
6+
src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
7+
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
148

159
out = scatter_softmax(src, index)
1610

17-
out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
18-
out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
19-
out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
20-
out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
21-
dim=-1)
11+
out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1)
12+
out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
13+
out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1)
14+
out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1)
2215

2316
expected = torch.stack([
2417
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
25-
], dim=0).to(device)
18+
], dim=0)
2619

2720
assert torch.allclose(out, expected)
2821

2922

30-
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
31-
def test_softmax_broadcasting(dtype, device):
32-
src = torch.randn(10, 5, dtype=dtype, device=device)
33-
index = tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
34-
35-
out = scatter_softmax(src, index, dim=0).view(5, 2, 5)
36-
out = out.sum(dim=1)
37-
assert torch.allclose(out, torch.ones_like(out))
38-
39-
40-
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
41-
def test_log_softmax(dtype, device):
42-
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
43-
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
23+
def test_log_softmax():
24+
src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
25+
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
4426

4527
out = scatter_log_softmax(src, index)
4628

47-
out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
48-
out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
49-
out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
50-
out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
51-
dim=-1)
29+
out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1)
30+
out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
31+
out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1)
32+
out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1)
5233

5334
expected = torch.stack([
5435
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
55-
], dim=0).to(device)
36+
], dim=0)
5637

5738
assert torch.allclose(out, expected)

test/composite/test_std.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
from torch_scatter import scatter_std
3+
4+
5+
def test_std():
6+
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=torch.float)
7+
index = torch.tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.long)
8+
9+
out = scatter_std(src, index, dim=-1, unbiased=True)
10+
std = src.std(dim=-1, unbiased=True)[0]
11+
expected = torch.tensor([[std, 0], [0, std]])
12+
assert torch.allclose(out, expected)

test/test_backward.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

test/test_broadcasting.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,6 @@ def test_broadcasting(device):
1414
out = scatter_add(src, index, dim=2, dim_size=H)
1515
assert out.size() == (B, C, H, W)
1616

17-
src = torch.randn((B, 1, H, W), device=device)
18-
index = torch.randint(0, H, (B, C, H, W)).to(device, torch.long)
19-
out = scatter_add(src, index, dim=2, dim_size=H)
20-
assert out.size() == (B, C, H, W)
21-
22-
src = torch.randn((B, 1, H, W), device=device)
23-
index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
24-
out = scatter_add(src, index, dim=2, dim_size=H)
25-
assert out.size() == (B, 1, H, W)
26-
2717
src = torch.randn((B, C, H, W), device=device)
2818
index = torch.randint(0, H, (H, )).to(device, torch.long)
2919
out = scatter_add(src, index, dim=2, dim_size=H)

test/test_forward.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

test/test_logsumexp.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

test/test_max_min.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

0 commit comments

Comments
 (0)