Skip to content

Commit 9732a51

Browse files
committed
torch sparse convert + transpose cleanup
1 parent 0b79077 commit 9732a51

File tree

6 files changed

+59
-44
lines changed

6 files changed

+59
-44
lines changed

test/test_convert.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import torch
2+
from torch_sparse import to_scipy, from_scipy
3+
from torch_sparse import to_torch_sparse, from_torch_sparse
4+
5+
6+
def test_convert_scipy():
7+
index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
8+
value = torch.Tensor([1, 2, 4, 1, 3])
9+
N = 3
10+
11+
out = from_scipy(to_scipy(index, value, N, N))
12+
assert out[0].tolist() == index.tolist()
13+
assert out[1].tolist() == value.tolist()
14+
15+
16+
def test_convert_torch_sparse():
17+
index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
18+
value = torch.Tensor([1, 2, 4, 1, 3])
19+
N = 3
20+
21+
out = from_torch_sparse(to_torch_sparse(index, value, N, N).coalesce())
22+
assert out[0].tolist() == index.tolist()
23+
assert out[1].tolist() == value.tolist()

test/test_transpose.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,31 @@
22

33
import pytest
44
import torch
5-
from torch_sparse import transpose, transpose_matrix
5+
from torch_sparse import transpose
66

77
from .utils import dtypes, devices, tensor
88

99

10-
def test_transpose():
11-
row = torch.tensor([1, 0, 1, 0, 2, 1])
12-
col = torch.tensor([0, 1, 1, 1, 0, 0])
10+
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
11+
def test_transpose_matrix(dtype, device):
12+
row = torch.tensor([1, 0, 1, 2], device=device)
13+
col = torch.tensor([0, 1, 1, 0], device=device)
1314
index = torch.stack([row, col], dim=0)
14-
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
15+
value = tensor([1, 2, 3, 4], dtype, device)
1516

1617
index, value = transpose(index, value, m=3, n=2)
1718
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
18-
assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
19+
assert value.tolist() == [1, 4, 2, 3]
1920

2021

2122
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
22-
def test_transpose_matrix(dtype, device):
23-
row = torch.tensor([1, 0, 1, 2], device=device)
24-
col = torch.tensor([0, 1, 1, 0], device=device)
23+
def test_transpose(dtype, device):
24+
row = torch.tensor([1, 0, 1, 0, 2, 1], device=device)
25+
col = torch.tensor([0, 1, 1, 1, 0, 0], device=device)
2526
index = torch.stack([row, col], dim=0)
26-
value = tensor([1, 2, 3, 4], dtype, device)
27+
value = tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]], dtype,
28+
device)
2729

28-
index, value = transpose_matrix(index, value, m=3, n=2)
30+
index, value = transpose(index, value, m=3, n=2)
2931
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
30-
assert value.tolist() == [1, 4, 2, 3]
32+
assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]

torch_sparse/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from .convert import to_scipy, from_scipy
1+
from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
22
from .coalesce import coalesce
3-
from .transpose import transpose, transpose_matrix
3+
from .transpose import transpose
44
from .eye import eye
55
from .spmm import spmm
66
from .spspmm import spspmm
@@ -9,11 +9,12 @@
99

1010
__all__ = [
1111
'__version__',
12+
'to_torch_sparse',
13+
'from_torch_sparse',
1214
'to_scipy',
1315
'from_scipy',
1416
'coalesce',
1517
'transpose',
16-
'transpose_matrix',
1718
'eye',
1819
'spmm',
1920
'spspmm',

torch_sparse/convert.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
from torch import from_numpy
55

66

7+
def to_torch_sparse(index, value, m, n):
8+
return torch.sparse_coo_tensor(index.detach(), value, torch.Size([m, n]))
9+
10+
11+
def from_torch_sparse(A):
12+
return A.indices().detach(), A.values()
13+
14+
715
def to_scipy(index, value, m, n):
816
assert not index.is_cuda and not value.is_cuda
917
(row, col), data = index.detach(), value.detach()

torch_sparse/spspmm.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from torch_sparse import transpose_matrix, to_scipy, from_scipy
2+
from torch_sparse import transpose, to_scipy, from_scipy
33

44
import torch_sparse.spspmm_cpu
55

@@ -53,9 +53,8 @@ def backward(ctx, grad_indexC, grad_valueC):
5353
valueB, m, k)
5454

5555
if ctx.needs_input_grad[3]:
56-
indexA, valueA = transpose_matrix(indexA, valueA, m, k)
57-
indexC, grad_valueC = transpose_matrix(indexC, grad_valueC, m,
58-
n)
56+
indexA, valueA = transpose(indexA, valueA, m, k)
57+
indexC, grad_valueC = transpose(indexC, grad_valueC, m, n)
5958
grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
6059
indexB, indexA.detach(), valueA, indexC.detach(),
6160
grad_valueC, k, n)
@@ -66,7 +65,7 @@ def backward(ctx, grad_indexC, grad_valueC):
6665
indexB.detach(), valueB, m, k)
6766

6867
if ctx.needs_input_grad[3]:
69-
indexA_T, valueA_T = transpose_matrix(indexA, valueA, m, k)
68+
indexA_T, valueA_T = transpose(indexA, valueA, m, k)
7069
grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
7170
grad_valueC, k, m, n)
7271
grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)

torch_sparse/transpose.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,13 @@ def transpose(index, value, m, n):
1414
:rtype: (:class:`LongTensor`, :class:`Tensor`)
1515
"""
1616

17-
row, col = index
18-
index = torch.stack([col, row], dim=0)
19-
index, value = coalesce(index, value, n, m)
20-
return index, value
21-
22-
23-
def transpose_matrix(index, value, m, n):
24-
"""Transposes dimensions 0 and 1 of a sparse matrix, where :args:`value` is
25-
one-dimensional.
26-
27-
Args:
28-
index (:class:`LongTensor`): The index tensor of sparse matrix.
29-
value (:class:`Tensor`): The value tensor of sparse matrix.
30-
m (int): The first dimension of sparse matrix.
31-
n (int): The second dimension of sparse matrix.
32-
33-
:rtype: (:class:`LongTensor`, :class:`Tensor`)
34-
"""
35-
36-
assert value.dim() == 1
37-
38-
if index.is_cuda:
39-
return transpose(index, value, m, n)
40-
else:
17+
if value.dim() == 1 and not value.is_cuda:
4118
mat = to_scipy(index, value, m, n).tocsc()
4219
(col, row), value = from_scipy(mat)
4320
index = torch.stack([row, col], dim=0)
4421
return index, value
22+
23+
row, col = index
24+
index = torch.stack([col, row], dim=0)
25+
index, value = coalesce(index, value, n, m)
26+
return index, value

0 commit comments

Comments
 (0)