Skip to content

Commit 35b09b7

Browse files
committed
fix device eye test
1 parent d474adf commit 35b09b7

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

test/test_eye.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
1010
def test_eye(dtype, device):
1111
mat = SparseTensor.eye(3, dtype=dtype, device=device)
12-
assert mat.storage.col().device == device
12+
assert mat.device() == device
1313
assert mat.storage.sparse_sizes() == (3, 3)
1414
assert mat.storage.row().tolist() == [0, 1, 2]
1515
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
@@ -18,17 +18,17 @@ def test_eye(dtype, device):
1818
assert mat.storage.value().dtype == dtype
1919
assert mat.storage.num_cached_keys() == 0
2020

21-
mat = SparseTensor.eye(3, has_value=False)
22-
assert mat.storage.col().device == device
21+
mat = SparseTensor.eye(3, has_value=False, device=device)
22+
assert mat.device() == device
2323
assert mat.storage.sparse_sizes() == (3, 3)
2424
assert mat.storage.row().tolist() == [0, 1, 2]
2525
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
2626
assert mat.storage.col().tolist() == [0, 1, 2]
2727
assert mat.storage.value() is None
2828
assert mat.storage.num_cached_keys() == 0
2929

30-
mat = SparseTensor.eye(3, 4, fill_cache=True)
31-
assert mat.storage.col().device == device
30+
mat = SparseTensor.eye(3, 4, fill_cache=True, device=device)
31+
assert mat.device() == device
3232
assert mat.storage.sparse_sizes() == (3, 4)
3333
assert mat.storage.row().tolist() == [0, 1, 2]
3434
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
@@ -40,8 +40,8 @@ def test_eye(dtype, device):
4040
assert mat.storage.csr2csc().tolist() == [0, 1, 2]
4141
assert mat.storage.csc2csr().tolist() == [0, 1, 2]
4242

43-
mat = SparseTensor.eye(4, 3, fill_cache=True)
44-
assert mat.storage.col().device == device
43+
mat = SparseTensor.eye(4, 3, fill_cache=True, device=device)
44+
assert mat.device() == device
4545
assert mat.storage.sparse_sizes() == (4, 3)
4646
assert mat.storage.row().tolist() == [0, 1, 2]
4747
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3]

0 commit comments

Comments
 (0)