Skip to content

Commit 2317ff6

Browse files
committed
test_spspmm_2
1 parent 9f03468 commit 2317ff6

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

test/test_spspmm.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
import torch
5-
from torch_sparse import spspmm, SparseTensor
5+
from torch_sparse import spspmm, SparseTensor, transpose
66

77
from .utils import grad_dtypes, devices, tensor
88

@@ -19,6 +19,38 @@ def test_spspmm(dtype, device):
1919
assert valueC.tolist() == [8, 6, 8]
2020

2121

22+
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
23+
def test_spspmm_2(dtype, device):
24+
row = torch.tensor(
25+
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
26+
device=device
27+
)
28+
col = torch.tensor(
29+
[0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
30+
device=device
31+
)
32+
value = torch.tensor(
33+
[1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
34+
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
35+
-2**-0.5, 2**-0.5, -2**-0.5],
36+
dtype=dtype, device=device
37+
)
38+
index = torch.stack([row, col])
39+
40+
m = value.new_zeros(10, 16)
41+
m[index[0], index[1]] = value
42+
43+
index_t, value_t = transpose(index, value, 10, 16)
44+
45+
index, value = spspmm(index, value, index_t, value_t, 10, 16, 10)
46+
47+
mask = value.abs() > 1e-4
48+
index, value = index[:, mask], value[mask]
49+
50+
assert index.tolist() == [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
51+
assert value.tolist() == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
52+
53+
2254
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
2355
def test_sparse_tensor_spspmm(dtype, device):
2456
x = SparseTensor(

0 commit comments

Comments
 (0)