22
33import pytest
44import torch
5- from torch_sparse import spspmm , SparseTensor , transpose
5+ from torch_sparse import spspmm , SparseTensor
66
77from .utils import grad_dtypes , devices , tensor
88
@@ -19,62 +19,27 @@ def test_spspmm(dtype, device):
1919 assert valueC .tolist () == [8 , 6 , 8 ]
2020
2121
22- @pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
23- def test_spspmm_2 (dtype , device ):
24- row = torch .tensor (
25- [0 , 1 , 1 , 1 , 2 , 3 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 7 , 8 , 8 , 9 , 9 ],
26- device = device
27- )
28- col = torch .tensor (
29- [0 , 5 , 10 , 15 , 1 , 2 , 3 , 7 , 13 , 6 , 9 , 5 , 10 , 15 , 11 , 14 , 5 , 15 ],
30- device = device
31- )
32- value = torch .tensor (
33- [1 , 3 ** - 0.5 , 3 ** - 0.5 , 3 ** - 0.5 , 1 , 1 , 1 , - 2 ** - 0.5 , - 2 ** - 0.5 ,
34- - 2 ** - 0.5 , - 2 ** - 0.5 , 6 ** - 0.5 , - 6 ** 0.5 / 3 , 6 ** - 0.5 , - 2 ** - 0.5 ,
35- - 2 ** - 0.5 , 2 ** - 0.5 , - 2 ** - 0.5 ],
36- dtype = dtype , device = device
37- )
38- index = torch .stack ([row , col ])
39-
40- m = value .new_zeros (10 , 16 )
41- m [index [0 ], index [1 ]] = value
42-
43- index_t , value_t = transpose (index , value , 10 , 16 )
44-
45- index , value = spspmm (index , value , index_t , value_t , 10 , 16 , 10 )
46-
47- mask = value .abs () > 1e-4
48- index , value = index [:, mask ], value [mask ]
49-
50- assert index .tolist () == [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]]
51- assert value .tolist () == [1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
52-
53-
5422@pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
5523def test_sparse_tensor_spspmm (dtype , device ):
5624 x = SparseTensor (
5725 row = torch .tensor (
5826 [0 , 1 , 1 , 1 , 2 , 3 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 7 , 8 , 8 , 9 , 9 ],
59- device = device
60- ),
27+ device = device ),
6128 col = torch .tensor (
6229 [0 , 5 , 10 , 15 , 1 , 2 , 3 , 7 , 13 , 6 , 9 , 5 , 10 , 15 , 11 , 14 , 5 , 15 ],
63- device = device
64- ),
65- value = torch .tensor (
66- [1 , 3 ** - 0.5 , 3 ** - 0.5 , 3 ** - 0.5 , 1 , 1 , 1 , - 2 ** - 0.5 , - 2 ** - 0.5 ,
67- - 2 ** - 0.5 , - 2 ** - 0.5 , 6 ** - 0.5 , - 6 ** 0.5 / 3 , 6 ** - 0.5 , - 2 ** - 0.5 ,
68- - 2 ** - 0.5 , 2 ** - 0.5 , - 2 ** - 0.5 ],
69- dtype = dtype , device = device
70- ),
30+ device = device ),
31+ value = torch .tensor ([
32+ 1 , 3 ** - 0.5 , 3 ** - 0.5 , 3 ** - 0.5 , 1 , 1 , 1 , - 2 ** - 0.5 , - 2 ** - 0.5 ,
33+ - 2 ** - 0.5 , - 2 ** - 0.5 , 6 ** - 0.5 , - 6 ** 0.5 / 3 , 6 ** - 0.5 , - 2 ** - 0.5 ,
34+ - 2 ** - 0.5 , 2 ** - 0.5 , - 2 ** - 0.5
35+ ], dtype = dtype , device = device ),
7136 )
7237
73- i0 = torch .eye (10 , dtype = dtype , device = device )
38+ expected = torch .eye (10 , dtype = dtype , device = device )
7439
75- i1 = x @ x .to_dense ().t ()
76- assert torch .allclose (i0 , i1 )
40+ out = x @ x .to_dense ().t ()
41+ assert torch .allclose (out , expected , atol = 1e-7 )
7742
78- i1 = x @ x .t ()
79- i1 = i1 .to_dense ()
80- assert torch .allclose (i0 , i1 )
43+ out = x @ x .t ()
44+ out = out .to_dense ()
45+ assert torch .allclose (out , expected , atol = 1e-7 )
0 commit comments