1 | | -from itertools import product |
2 | | - |
3 | | -import pytest |
4 | 1 | import torch |
5 | | -from torch_scatter.composite import scatter_log_softmax, scatter_softmax |
6 | | - |
7 | | -from test.utils import devices, tensor, grad_dtypes |
| 2 | +from torch_scatter import scatter_log_softmax, scatter_softmax |
8 | 3 |
9 | 4 |
10 | | -@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) |
11 | | -def test_softmax(dtype, device): |
12 | | - src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) |
13 | | - index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) |
| 5 | +def test_softmax(): |
| 6 | + src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')]) |
| 7 | + index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4]) |
14 | 8 |
15 | 9 | out = scatter_softmax(src, index) |
16 | 10 |
17 | | - out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1) |
18 | | - out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) |
19 | | - out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1) |
20 | | - out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype), |
21 | | - dim=-1) |
| 11 | + out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1) |
| 12 | + out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1) |
| 13 | + out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1) |
| 14 | + out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1) |
22 | 15 |
23 | 16 | expected = torch.stack([ |
24 | 17 | out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] |
25 | | - ], dim=0).to(device) |
| 18 | + ], dim=0) |
26 | 19 |
27 | 20 | assert torch.allclose(out, expected) |
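
For context, scatter_softmax applies torch.softmax independently within each group of entries sharing the same index value, which is exactly what the expected tensor above encodes. A minimal reference sketch of that invariant in plain PyTorch (the loop over index.unique() is purely illustrative, not how the library computes it):

```python
import torch
from torch_scatter import scatter_softmax

src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

# Reference: run torch.softmax separately over each group of entries
# that share the same index value.
expected = torch.empty_like(src)
for group in index.unique():
    mask = index == group
    expected[mask] = torch.softmax(src[mask], dim=-1)

assert torch.allclose(scatter_softmax(src, index), expected)
```
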
28 | 21 |
29 | 22 |
30 | | -@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) |
31 | | -def test_softmax_broadcasting(dtype, device): |
32 | | - src = torch.randn(10, 5, dtype=dtype, device=device) |
33 | | - index = tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device) |
34 | | - |
35 | | - out = scatter_softmax(src, index, dim=0).view(5, 2, 5) |
36 | | - out = out.sum(dim=1) |
37 | | - assert torch.allclose(out, torch.ones_like(out)) |
38 | | - |
39 | | - |
40 | | -@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) |
41 | | -def test_log_softmax(dtype, device): |
42 | | - src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) |
43 | | - index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) |
| 23 | +def test_log_softmax(): |
| 24 | + src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')]) |
| 25 | + index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4]) |
44 | 26 |
45 | 27 | out = scatter_log_softmax(src, index) |
46 | 28 |
47 | | - out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1) |
48 | | - out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) |
49 | | - out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1) |
50 | | - out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype), |
51 | | - dim=-1) |
| 29 | + out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1) |
| 30 | + out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1) |
| 31 | + out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1) |
| 32 | + out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1) |
52 | 33 |
53 | 34 | expected = torch.stack([ |
54 | 35 | out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] |
55 | | - ], dim=0).to(device) |
| 36 | + ], dim=0) |
56 | 37 |
57 | 38 | assert torch.allclose(out, expected) |
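
Both ops are composites over the library's scatter primitives. Below is a sketch of the standard numerically stable formulation, assuming 1-D src and index; the shipped implementation generalizes over dim and may differ in its details:

```python
import torch
from torch_scatter import scatter_add, scatter_max

def scatter_softmax_1d(src, index):
    # Subtract each group's max before exponentiating so that large
    # inputs (or -inf entries, as in the tests above) stay stable.
    max_per_group, _ = scatter_max(src, index)
    centered = src - max_per_group[index]  # broadcast group max back to src
    exp = centered.exp()
    denom = scatter_add(exp, index)
    return exp / denom[index]

def scatter_log_softmax_1d(src, index):
    # Same recentering trick, kept in log space throughout.
    max_per_group, _ = scatter_max(src, index)
    centered = src - max_per_group[index]
    log_denom = scatter_add(centered.exp(), index).log()
    return centered - log_denom[index]
```

By construction, the softmax output sums to one within every index group, which is the invariant the removed broadcasting test asserted via out.view(5, 2, 5).sum(dim=1).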