99@pytest .mark .parametrize ('dtype,device' , product (dtypes , devices ))
1010def test_eye (dtype , device ):
1111 mat = SparseTensor .eye (3 , dtype = dtype , device = device )
12- assert mat .storage . col (). device == device
12+ assert mat .device () == device
1313 assert mat .storage .sparse_sizes () == (3 , 3 )
1414 assert mat .storage .row ().tolist () == [0 , 1 , 2 ]
1515 assert mat .storage .rowptr ().tolist () == [0 , 1 , 2 , 3 ]
@@ -18,17 +18,17 @@ def test_eye(dtype, device):
1818 assert mat .storage .value ().dtype == dtype
1919 assert mat .storage .num_cached_keys () == 0
2020
21- mat = SparseTensor .eye (3 , has_value = False )
22- assert mat .storage . col (). device == device
21+ mat = SparseTensor .eye (3 , has_value = False , device = device )
22+ assert mat .device () == device
2323 assert mat .storage .sparse_sizes () == (3 , 3 )
2424 assert mat .storage .row ().tolist () == [0 , 1 , 2 ]
2525 assert mat .storage .rowptr ().tolist () == [0 , 1 , 2 , 3 ]
2626 assert mat .storage .col ().tolist () == [0 , 1 , 2 ]
2727 assert mat .storage .value () is None
2828 assert mat .storage .num_cached_keys () == 0
2929
30- mat = SparseTensor .eye (3 , 4 , fill_cache = True )
31- assert mat .storage . col (). device == device
30+ mat = SparseTensor .eye (3 , 4 , fill_cache = True , device = device )
31+ assert mat .device () == device
3232 assert mat .storage .sparse_sizes () == (3 , 4 )
3333 assert mat .storage .row ().tolist () == [0 , 1 , 2 ]
3434 assert mat .storage .rowptr ().tolist () == [0 , 1 , 2 , 3 ]
@@ -40,8 +40,8 @@ def test_eye(dtype, device):
4040 assert mat .storage .csr2csc ().tolist () == [0 , 1 , 2 ]
4141 assert mat .storage .csc2csr ().tolist () == [0 , 1 , 2 ]
4242
43- mat = SparseTensor .eye (4 , 3 , fill_cache = True )
44- assert mat .storage . col (). device == device
43+ mat = SparseTensor .eye (4 , 3 , fill_cache = True , device = device )
44+ assert mat .device () == device
4545 assert mat .storage .sparse_sizes () == (4 , 3 )
4646 assert mat .storage .row ().tolist () == [0 , 1 , 2 ]
4747 assert mat .storage .rowptr ().tolist () == [0 , 1 , 2 , 3 , 3 ]
0 commit comments