22
33import pytest
44import torch
5- from torch_sparse import spspmm , SparseTensor
5+ from torch_sparse import spspmm , SparseTensor , transpose
66
77from .utils import grad_dtypes , devices , tensor
88
@@ -19,6 +19,38 @@ def test_spspmm(dtype, device):
1919 assert valueC .tolist () == [8 , 6 , 8 ]
2020
2121
22+ @pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
23+ def test_spspmm_2 (dtype , device ):
24+ row = torch .tensor (
25+ [0 , 1 , 1 , 1 , 2 , 3 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 7 , 8 , 8 , 9 , 9 ],
26+ device = device
27+ )
28+ col = torch .tensor (
29+ [0 , 5 , 10 , 15 , 1 , 2 , 3 , 7 , 13 , 6 , 9 , 5 , 10 , 15 , 11 , 14 , 5 , 15 ],
30+ device = device
31+ )
32+ value = torch .tensor (
33+ [1 , 3 ** - 0.5 , 3 ** - 0.5 , 3 ** - 0.5 , 1 , 1 , 1 , - 2 ** - 0.5 , - 2 ** - 0.5 ,
34+ - 2 ** - 0.5 , - 2 ** - 0.5 , 6 ** - 0.5 , - 6 ** 0.5 / 3 , 6 ** - 0.5 , - 2 ** - 0.5 ,
35+ - 2 ** - 0.5 , 2 ** - 0.5 , - 2 ** - 0.5 ],
36+ dtype = dtype , device = device
37+ )
38+ index = torch .stack ([row , col ])
39+
40+ m = value .new_zeros (10 , 16 )
41+ m [index [0 ], index [1 ]] = value
42+
43+ index_t , value_t = transpose (index , value , 10 , 16 )
44+
45+ index , value = spspmm (index , value , index_t , value_t , 10 , 16 , 10 )
46+
47+ mask = value .abs () > 1e-4
48+ index , value = index [:, mask ], value [mask ]
49+
50+ assert index .tolist () == [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]]
51+ assert value .tolist () == [1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
52+
53+
2254@pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
2355def test_sparse_tensor_spspmm (dtype , device ):
2456 x = SparseTensor (
0 commit comments