22
33import pytest
44import torch
5- from torch_sparse import spspmm
5+ from torch_sparse import spspmm , SparseTensor
66
77from .utils import grad_dtypes , devices , tensor
88
@@ -17,3 +17,32 @@ def test_spspmm(dtype, device):
1717 indexC , valueC = spspmm (indexA , valueA , indexB , valueB , 3 , 3 , 2 )
1818 assert indexC .tolist () == [[0 , 1 , 2 ], [0 , 1 , 1 ]]
1919 assert valueC .tolist () == [8 , 6 , 8 ]
20+
21+
22+ @pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
23+ def test_sparse_tensor_spspmm (dtype , device ):
24+ x = SparseTensor (
25+ row = torch .tensor (
26+ [0 , 1 , 1 , 1 , 2 , 3 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 7 , 8 , 8 , 9 , 9 ],
27+ device = device
28+ ),
29+ col = torch .tensor (
30+ [0 , 5 , 10 , 15 , 1 , 2 , 3 , 7 , 13 , 6 , 9 , 5 , 10 , 15 , 11 , 14 , 5 , 15 ],
31+ device = device
32+ ),
33+ value = torch .tensor (
34+ [1 , 3 ** - 0.5 , 3 ** - 0.5 , 3 ** - 0.5 , 1 , 1 , 1 , - 2 ** - 0.5 , - 2 ** - 0.5 ,
35+ - 2 ** - 0.5 , - 2 ** - 0.5 , 6 ** - 0.5 , - 6 ** 0.5 / 3 , 6 ** - 0.5 , - 2 ** - 0.5 ,
36+ - 2 ** - 0.5 , 2 ** - 0.5 , - 2 ** - 0.5 ],
37+ dtype = dtype , device = device
38+ ),
39+ )
40+
41+ i0 = torch .eye (10 , dtype = dtype , device = device )
42+
43+ i1 = x @ x .to_dense ().t ()
44+ assert torch .allclose (i0 , i1 )
45+
46+ i1 = x @ x .t ()
47+ i1 = i1 .to_dense ()
48+ assert torch .allclose (i0 , i1 )
0 commit comments