|
| 1 | +from itertools import product |
| 2 | + |
| 3 | +import pytest |
| 4 | +import torch |
| 5 | +from torch_sparse.tensor import SparseTensor |
| 6 | +from torch_sparse import view |
| 7 | + |
| 8 | +from .utils import dtypes, devices, tensor |
| 9 | + |
| 10 | + |
| 11 | +@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) |
| 12 | +def test_view_matrix(dtype, device): |
| 13 | + row = torch.tensor([0, 1, 1], device=device) |
| 14 | + col = torch.tensor([1, 0, 2], device=device) |
| 15 | + index = torch.stack([row, col], dim=0) |
| 16 | + value = tensor([1, 2, 3], dtype, device) |
| 17 | + |
| 18 | + index, value = view(index, value, m=2, n=3, new_n=2) |
| 19 | + assert index.tolist() == [[0, 1, 2], [1, 1, 1]] |
| 20 | + assert value.tolist() == [1, 2, 3] |
| 21 | + |
| 22 | + |
| 23 | +@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) |
| 24 | +def test_view_sparse_tensor(dtype, device): |
| 25 | + options = torch.tensor(0, dtype=dtype, device=device) |
| 26 | + |
| 27 | + mat = SparseTensor.eye(4, options=options).view(2, 8) |
| 28 | + assert mat.storage.sparse_sizes() == (2, 8) |
| 29 | + assert mat.storage.row().tolist() == [0, 0, 1, 1] |
| 30 | + assert mat.storage.col().tolist() == [0, 5, 2, 7] |
| 31 | + assert mat.storage.value().tolist() == [1, 1, 1, 1] |
0 commit comments