Skip to content

Commit 0e2ddfa

Browse files
committed
added view to storage + rename
1 parent 4dec4df commit 0e2ddfa

File tree

6 files changed

+50
-88
lines changed

6 files changed

+50
-88
lines changed

test/test_storage.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,24 @@ def test_coalesce(dtype, device):
122122
assert storage.row().tolist() == [0, 0, 1, 1]
123123
assert storage.col().tolist() == [0, 1, 0, 1]
124124
assert storage.value().tolist() == [1, 2, 3, 4]
125+
126+
127+
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
128+
def test_sparse_reshape(dtype, device):
129+
row, col = tensor([[0, 1, 2, 3], [0, 1, 2, 3]], torch.long, device)
130+
storage = SparseStorage(row=row, col=col)
131+
132+
storage = storage.sparse_reshape(2, 8)
133+
assert storage.sparse_sizes() == (2, 8)
134+
assert storage.row().tolist() == [0, 0, 1, 1]
135+
assert storage.col().tolist() == [0, 5, 2, 7]
136+
137+
storage = storage.sparse_reshape(-1, 4)
138+
assert storage.sparse_sizes() == (4, 4)
139+
assert storage.row().tolist() == [0, 1, 2, 3]
140+
assert storage.col().tolist() == [0, 1, 2, 3]
141+
142+
storage = storage.sparse_reshape(2, -1)
143+
assert storage.sparse_sizes() == (2, 8)
144+
assert storage.row().tolist() == [0, 0, 1, 1]
145+
assert storage.col().tolist() == [0, 5, 2, 7]

test/test_view.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

torch_sparse/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
from .convert import to_scipy, from_scipy # noqa
5656
from .coalesce import coalesce # noqa
5757
from .transpose import transpose # noqa
58-
from .view import view # noqa
5958
from .eye import eye # noqa
6059
from .spmm import spmm # noqa
6160
from .spspmm import spspmm # noqa
@@ -102,7 +101,6 @@
102101
'from_scipy',
103102
'coalesce',
104103
'transpose',
105-
'view',
106104
'eye',
107105
'spmm',
108106
'spspmm',

torch_sparse/storage.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,31 @@ def sparse_resize(self, sparse_sizes: Tuple[int, int]):
260260
colcount=colcount, csr2csc=self._csr2csc,
261261
csc2csr=self._csc2csr, is_sorted=True)
262262

263+
def sparse_reshape(self, num_rows: int, num_cols: int):
264+
assert num_rows > 0 or num_rows == -1
265+
assert num_cols > 0 or num_cols == -1
266+
assert num_rows > 0 or num_cols > 0
267+
268+
total = self.sparse_size(0) * self.sparse_size(1)
269+
270+
if num_rows == -1:
271+
num_rows = total // num_cols
272+
273+
if num_cols == -1:
274+
num_cols = total // num_rows
275+
276+
assert num_rows * num_cols == total
277+
278+
idx = self.sparse_size(1) * self.row() + self.col()
279+
280+
row = idx / num_cols
281+
col = idx % num_cols
282+
283+
return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
284+
sparse_sizes=(num_rows, num_cols), rowcount=None,
285+
colptr=None, colcount=None, csr2csc=None,
286+
csc2csr=None, is_sorted=True)
287+
263288
def has_rowcount(self) -> bool:
264289
return self._rowcount is not None
265290

torch_sparse/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ def sparse_size(self, dim: int) -> int:
171171
def sparse_resize(self, sparse_sizes: Tuple[int, int]):
172172
return self.from_storage(self.storage.sparse_resize(sparse_sizes))
173173

174+
def sparse_reshape(self, num_rows: int, num_cols: int):
175+
return self.from_storage(
176+
self.storage.sparse_reshape(num_rows, num_cols))
177+
174178
def is_coalesced(self) -> bool:
175179
return self.storage.is_coalesced()
176180

torch_sparse/view.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

0 commit comments

Comments
 (0)