Skip to content

Commit 64a8e2c

Browse files
committed
view
1 parent 57852a6 commit 64a8e2c

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

test/test_view.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from itertools import product
2+
3+
import pytest
4+
import torch
5+
from torch_sparse.tensor import SparseTensor
6+
from torch_sparse import view
7+
8+
from .utils import dtypes, devices, tensor
9+
10+
11+
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
12+
def test_view_matrix(dtype, device):
13+
row = torch.tensor([0, 1, 1], device=device)
14+
col = torch.tensor([1, 0, 2], device=device)
15+
index = torch.stack([row, col], dim=0)
16+
value = tensor([1, 2, 3], dtype, device)
17+
18+
index, value = view(index, value, m=2, n=3, new_n=2)
19+
assert index.tolist() == [[0, 1, 2], [1, 1, 1]]
20+
assert value.tolist() == [1, 2, 3]
21+
22+
23+
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
24+
def test_view_sparse_tensor(dtype, device):
25+
options = torch.tensor(0, dtype=dtype, device=device)
26+
27+
mat = SparseTensor.eye(4, options=options).view(2, 8)
28+
assert mat.storage.sparse_sizes() == (2, 8)
29+
assert mat.storage.row().tolist() == [0, 0, 1, 1]
30+
assert mat.storage.col().tolist() == [0, 5, 2, 7]
31+
assert mat.storage.value().tolist() == [1, 1, 1, 1]

torch_sparse/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from .convert import to_scipy, from_scipy # noqa
5656
from .coalesce import coalesce # noqa
5757
from .transpose import transpose # noqa
58+
from .view import view # noqa
5859
from .eye import eye # noqa
5960
from .spmm import spmm # noqa
6061
from .spspmm import spspmm # noqa
@@ -101,6 +102,7 @@
101102
'from_scipy',
102103
'coalesce',
103104
'transpose',
105+
'view',
104106
'eye',
105107
'spmm',
106108
'spspmm',

torch_sparse/view.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
3+
from torch_sparse.storage import SparseStorage
4+
from torch_sparse.tensor import SparseTensor
5+
6+
7+
def _view(src: SparseTensor, n: int, layout: str = 'csr') -> SparseTensor:
8+
row, col, value = src.coo()
9+
sparse_sizes = src.storage.sparse_sizes()
10+
11+
if sparse_sizes[0] * sparse_sizes[1] % n == 0:
12+
raise RuntimeError(
13+
f"shape '[-1, {n}]' is invalid for input of size {sparse_sizes[0] * sparse_sizes[1]}")
14+
15+
assert layout == 'csr' or layout == 'csc'
16+
17+
if layout == 'csr':
18+
idx = sparse_sizes[1] * row + col
19+
row = idx // n
20+
col = idx % n
21+
sparse_sizes = (sparse_sizes[0] * sparse_sizes[1] // n, n)
22+
if layout == 'csc':
23+
idx = sparse_sizes[0] * col + row
24+
row = idx % n
25+
col = idx // n
26+
sparse_sizes = (n, sparse_sizes[0] * sparse_sizes[1] // n)
27+
28+
storage = SparseStorage(
29+
row=row,
30+
rowptr=src.storage._rowptr,
31+
col=col,
32+
value=value,
33+
sparse_sizes=sparse_sizes,
34+
rowcount=src.storage._rowcount,
35+
colptr=src.storage._colptr,
36+
colcount=src.storage._colcount,
37+
csr2csc=src.storage._csr2csc,
38+
csc2csr=src.storage._csc2csr,
39+
is_sorted=True,
40+
)
41+
42+
return src.from_storage(storage)
43+
44+
45+
SparseTensor.view = lambda self, m, n: _view(self, n, layout='csr')
46+
47+
###############################################################################
48+
49+
50+
def view(index, value, m, n, new_n):
51+
assert m * n % new_n == 0
52+
53+
row, col = index
54+
idx = n * row + col
55+
row = idx // new_n
56+
col = idx % new_n
57+
58+
return torch.stack([row, col], dim=0), value

0 commit comments

Comments
 (0)