99
1010@pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
1111def test_spspmm (dtype , device ):
12- if dtype == torch .half :
13- return # TODO
14-
1512 indexA = torch .tensor ([[0 , 0 , 1 , 2 , 2 ], [1 , 2 , 0 , 0 , 1 ]], device = device )
1613 valueA = tensor ([1 , 2 , 3 , 4 , 5 ], dtype , device )
1714 indexB = torch .tensor ([[0 , 2 ], [1 , 0 ]], device = device )
@@ -24,9 +21,6 @@ def test_spspmm(dtype, device):
2421
2522@pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
2623def test_sparse_tensor_spspmm (dtype , device ):
27- if dtype == torch .half :
28- return # TODO
29-
3024 x = SparseTensor (
3125 row = torch .tensor (
3226 [0 , 1 , 1 , 1 , 2 , 3 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 7 , 8 , 8 , 9 , 9 ],
@@ -44,8 +38,8 @@ def test_sparse_tensor_spspmm(dtype, device):
4438 expected = torch .eye (10 , dtype = dtype , device = device )
4539
4640 out = x @ x .to_dense ().t ()
47- assert torch .allclose (out , expected , atol = 1e-7 )
41+ assert torch .allclose (out , expected , atol = 1e-2 )
4842
4943 out = x @ x .t ()
5044 out = out .to_dense ()
51- assert torch .allclose (out , expected , atol = 1e-7 )
45+ assert torch .allclose (out , expected , atol = 1e-2 )
0 commit comments