Skip to content

Commit 0485165

Browse files
authored
Merge pull request #63 from mariogeiger/test
add test that fails on my computer but should work
2 parents 2eba313 + 2317ff6 commit 0485165

File tree

1 file changed

+62
-1
lines changed

1 file changed

+62
-1
lines changed

test/test_spspmm.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
import torch
5-
from torch_sparse import spspmm
5+
from torch_sparse import spspmm, SparseTensor, transpose
66

77
from .utils import grad_dtypes, devices, tensor
88

@@ -17,3 +17,64 @@ def test_spspmm(dtype, device):
1717
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
1818
assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
1919
assert valueC.tolist() == [8, 6, 8]
20+
21+
22+
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
23+
def test_spspmm_2(dtype, device):
24+
row = torch.tensor(
25+
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
26+
device=device
27+
)
28+
col = torch.tensor(
29+
[0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
30+
device=device
31+
)
32+
value = torch.tensor(
33+
[1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
34+
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
35+
-2**-0.5, 2**-0.5, -2**-0.5],
36+
dtype=dtype, device=device
37+
)
38+
index = torch.stack([row, col])
39+
40+
m = value.new_zeros(10, 16)
41+
m[index[0], index[1]] = value
42+
43+
index_t, value_t = transpose(index, value, 10, 16)
44+
45+
index, value = spspmm(index, value, index_t, value_t, 10, 16, 10)
46+
47+
mask = value.abs() > 1e-4
48+
index, value = index[:, mask], value[mask]
49+
50+
assert index.tolist() == [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
51+
assert value.tolist() == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
52+
53+
54+
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
55+
def test_sparse_tensor_spspmm(dtype, device):
56+
x = SparseTensor(
57+
row=torch.tensor(
58+
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
59+
device=device
60+
),
61+
col=torch.tensor(
62+
[0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
63+
device=device
64+
),
65+
value=torch.tensor(
66+
[1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
67+
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
68+
-2**-0.5, 2**-0.5, -2**-0.5],
69+
dtype=dtype, device=device
70+
),
71+
)
72+
73+
i0 = torch.eye(10, dtype=dtype, device=device)
74+
75+
i1 = x @ x.to_dense().t()
76+
assert torch.allclose(i0, i1)
77+
78+
i1 = x @ x.t()
79+
i1 = i1.to_dense()
80+
assert torch.allclose(i0, i1)

0 commit comments

Comments
 (0)