Skip to content

Commit 93540a3

Browse files
committed
clean up test
1 parent 0485165 commit 93540a3

File tree

1 file changed

+14
-49
lines changed

1 file changed

+14
-49
lines changed

test/test_spspmm.py

Lines changed: 14 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
import torch
5-
from torch_sparse import spspmm, SparseTensor, transpose
5+
from torch_sparse import spspmm, SparseTensor
66

77
from .utils import grad_dtypes, devices, tensor
88

@@ -19,62 +19,27 @@ def test_spspmm(dtype, device):
1919
assert valueC.tolist() == [8, 6, 8]
2020

2121

22-
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
23-
def test_spspmm_2(dtype, device):
24-
row = torch.tensor(
25-
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
26-
device=device
27-
)
28-
col = torch.tensor(
29-
[0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
30-
device=device
31-
)
32-
value = torch.tensor(
33-
[1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
34-
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
35-
-2**-0.5, 2**-0.5, -2**-0.5],
36-
dtype=dtype, device=device
37-
)
38-
index = torch.stack([row, col])
39-
40-
m = value.new_zeros(10, 16)
41-
m[index[0], index[1]] = value
42-
43-
index_t, value_t = transpose(index, value, 10, 16)
44-
45-
index, value = spspmm(index, value, index_t, value_t, 10, 16, 10)
46-
47-
mask = value.abs() > 1e-4
48-
index, value = index[:, mask], value[mask]
49-
50-
assert index.tolist() == [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
51-
assert value.tolist() == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
52-
53-
5422
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
5523
def test_sparse_tensor_spspmm(dtype, device):
5624
x = SparseTensor(
5725
row=torch.tensor(
5826
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
59-
device=device
60-
),
27+
device=device),
6128
col=torch.tensor(
6229
[0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
63-
device=device
64-
),
65-
value=torch.tensor(
66-
[1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
67-
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
68-
-2**-0.5, 2**-0.5, -2**-0.5],
69-
dtype=dtype, device=device
70-
),
30+
device=device),
31+
value=torch.tensor([
32+
1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
33+
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
34+
-2**-0.5, 2**-0.5, -2**-0.5
35+
], dtype=dtype, device=device),
7136
)
7237

73-
i0 = torch.eye(10, dtype=dtype, device=device)
38+
expected = torch.eye(10, dtype=dtype, device=device)
7439

75-
i1 = x @ x.to_dense().t()
76-
assert torch.allclose(i0, i1)
40+
out = x @ x.to_dense().t()
41+
assert torch.allclose(out, expected, atol=1e-7)
7742

78-
i1 = x @ x.t()
79-
i1 = i1.to_dense()
80-
assert torch.allclose(i0, i1)
43+
out = x @ x.t()
44+
out = out.to_dense()
45+
assert torch.allclose(out, expected, atol=1e-7)

0 commit comments

Comments
 (0)