22
33import pytest
44import torch
5- from torch_sparse import spspmm
5+ from torch_sparse import spspmm , SparseTensor , transpose
66
77from .utils import grad_dtypes , devices , tensor
88
@@ -17,3 +17,64 @@ def test_spspmm(dtype, device):
1717 indexC , valueC = spspmm (indexA , valueA , indexB , valueB , 3 , 3 , 2 )
1818 assert indexC .tolist () == [[0 , 1 , 2 ], [0 , 1 , 1 ]]
1919 assert valueC .tolist () == [8 , 6 , 8 ]
20+
21+
22+ @pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
23+ def test_spspmm_2 (dtype , device ):
24+ row = torch .tensor (
25+ [0 , 1 , 1 , 1 , 2 , 3 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 7 , 8 , 8 , 9 , 9 ],
26+ device = device
27+ )
28+ col = torch .tensor (
29+ [0 , 5 , 10 , 15 , 1 , 2 , 3 , 7 , 13 , 6 , 9 , 5 , 10 , 15 , 11 , 14 , 5 , 15 ],
30+ device = device
31+ )
32+ value = torch .tensor (
33+ [1 , 3 ** - 0.5 , 3 ** - 0.5 , 3 ** - 0.5 , 1 , 1 , 1 , - 2 ** - 0.5 , - 2 ** - 0.5 ,
34+ - 2 ** - 0.5 , - 2 ** - 0.5 , 6 ** - 0.5 , - 6 ** 0.5 / 3 , 6 ** - 0.5 , - 2 ** - 0.5 ,
35+ - 2 ** - 0.5 , 2 ** - 0.5 , - 2 ** - 0.5 ],
36+ dtype = dtype , device = device
37+ )
38+ index = torch .stack ([row , col ])
39+
40+ m = value .new_zeros (10 , 16 )
41+ m [index [0 ], index [1 ]] = value
42+
43+ index_t , value_t = transpose (index , value , 10 , 16 )
44+
45+ index , value = spspmm (index , value , index_t , value_t , 10 , 16 , 10 )
46+
47+ mask = value .abs () > 1e-4
48+ index , value = index [:, mask ], value [mask ]
49+
50+ assert index .tolist () == [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]]
51+ assert value .tolist () == [1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
52+
53+
54+ @pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
55+ def test_sparse_tensor_spspmm (dtype , device ):
56+ x = SparseTensor (
57+ row = torch .tensor (
58+ [0 , 1 , 1 , 1 , 2 , 3 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 7 , 8 , 8 , 9 , 9 ],
59+ device = device
60+ ),
61+ col = torch .tensor (
62+ [0 , 5 , 10 , 15 , 1 , 2 , 3 , 7 , 13 , 6 , 9 , 5 , 10 , 15 , 11 , 14 , 5 , 15 ],
63+ device = device
64+ ),
65+ value = torch .tensor (
66+ [1 , 3 ** - 0.5 , 3 ** - 0.5 , 3 ** - 0.5 , 1 , 1 , 1 , - 2 ** - 0.5 , - 2 ** - 0.5 ,
67+ - 2 ** - 0.5 , - 2 ** - 0.5 , 6 ** - 0.5 , - 6 ** 0.5 / 3 , 6 ** - 0.5 , - 2 ** - 0.5 ,
68+ - 2 ** - 0.5 , 2 ** - 0.5 , - 2 ** - 0.5 ],
69+ dtype = dtype , device = device
70+ ),
71+ )
72+
73+ i0 = torch .eye (10 , dtype = dtype , device = device )
74+
75+ i1 = x @ x .to_dense ().t ()
76+ assert torch .allclose (i0 , i1 )
77+
78+ i1 = x @ x .t ()
79+ i1 = i1 .to_dense ()
80+ assert torch .allclose (i0 , i1 )
0 commit comments