Skip to content

Commit 8b29061

Browse files
committed
fix sparse_reshape in PyTorch 1.4.0
1 parent 73a89ef commit 8b29061

File tree

2 files changed

+5
-91
lines changed

2 files changed

+5
-91
lines changed

test/test_padding.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

torch_sparse/storage.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
layouts: Final[List[str]] = ['coo', 'csr', 'csc']
99

10+
# FIXME: Remove once `/` on `LongTensors` is officially removed from PyTorch.
11+
warnings.filterwarnings("ignore", category=UserWarning)
12+
1013

1114
def get_layout(layout: Optional[str] = None) -> str:
1215
if layout is None:
@@ -277,8 +280,9 @@ def sparse_reshape(self, num_rows: int, num_cols: int):
277280

278281
idx = self.sparse_size(1) * self.row() + self.col()
279282

280-
row = idx // num_cols
283+
row = idx / num_cols
281284
col = idx % num_cols
285+
assert row.dtype == torch.long and col.dtype == torch.long
282286

283287
return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
284288
sparse_sizes=(num_rows, num_cols), rowcount=None,

0 commit comments

Comments
 (0)