Skip to content

Commit 9f03468

Browse files
committed
test_sparse_tensor_spspmm
1 parent 57852a6 commit 9f03468

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

test/test_spspmm.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
import torch
5-
from torch_sparse import spspmm
5+
from torch_sparse import spspmm, SparseTensor
66

77
from .utils import grad_dtypes, devices, tensor
88

@@ -17,3 +17,32 @@ def test_spspmm(dtype, device):
1717
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
1818
assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
1919
assert valueC.tolist() == [8, 6, 8]
20+
21+
22+
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
23+
def test_sparse_tensor_spspmm(dtype, device):
24+
x = SparseTensor(
25+
row=torch.tensor(
26+
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
27+
device=device
28+
),
29+
col=torch.tensor(
30+
[0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
31+
device=device
32+
),
33+
value=torch.tensor(
34+
[1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
35+
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
36+
-2**-0.5, 2**-0.5, -2**-0.5],
37+
dtype=dtype, device=device
38+
),
39+
)
40+
41+
i0 = torch.eye(10, dtype=dtype, device=device)
42+
43+
i1 = x @ x.to_dense().t()
44+
assert torch.allclose(i0, i1)
45+
46+
i1 = x @ x.t()
47+
i1 = i1.to_dense()
48+
assert torch.allclose(i0, i1)

0 commit comments

Comments
 (0)