Skip to content

Commit 238efb1

Browse files
committed
faster spspmm backward + cleanup
1 parent 5586d7a commit 238efb1

File tree

9 files changed

+262
-44
lines changed

9 files changed

+262
-44
lines changed

cpu/spspmm.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#include <torch/extension.h>
2+
3+
// Per-node occurrence count: degree[i] = number of entries in `row` equal to i.
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  auto out = at::zeros(num_nodes, row.options());
  // Scatter-add a one for every occurrence of each node index.
  return out.scatter_add_(0, row, at::ones_like(row));
}
8+
9+
// Convert coalesced COO (row, col) indices into CSR form.
// Returns the row-pointer tensor (length num_nodes + 1, starting at 0)
// together with the unchanged column indices.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  // Input must already be coalesced (sorted rows, no duplicates).
  auto rowptr = degree(row, num_nodes).cumsum(0);
  auto zero = at::zeros(1, rowptr.options());
  rowptr = at::cat({zero, rowptr}, 0); // CSR row pointer begins with 0.
  return std::make_tuple(rowptr, col);
}
16+
17+
// Sparse-sparse matmul backward (CPU).
// For each sparse entry e of `index` with coordinates (i, j), computes
//   value[e] = sum_c A[i, c] * B[j, c]
// i.e. the dot product of row i of A with row j of B, where A and B are
// given in COO form (indexA/valueA, indexB/valueB) and converted to CSR
// on the fly. rowA_max / rowB_max are the row counts of A and B.
at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
                     at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
                     size_t rowB_max) {

  int64_t *index_data = index.data<int64_t>();
  // One output value per column of `index`, initialized to zero.
  auto value = at::zeros(index.size(1), valueA.options());

  // Build CSR row pointers for A (assumes coalesced input — see to_csr).
  at::Tensor rowA, colA;
  std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
  int64_t *rowA_data = rowA.data<int64_t>();
  int64_t *colA_data = colA.data<int64_t>();

  // Same for B.
  at::Tensor rowB, colB;
  std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
  int64_t *rowB_data = rowB.data<int64_t>();
  int64_t *colB_data = colB.data<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
    scalar_t *value_data = value.data<scalar_t>();
    scalar_t *valueA_data = valueA.data<scalar_t>();
    scalar_t *valueB_data = valueB.data<scalar_t>();

    for (int64_t e = 0; e < value.size(0); e++) {
      // index is a 2 x E row-major tensor: row e / row E + e hold (i, j).
      int64_t i = index_data[e], j = index_data[value.size(0) + e];

      // Intersect row i of A with row j of B on column indices.
      for (ptrdiff_t dA = rowA_data[i]; dA < rowA_data[i + 1]; dA++) {
        int64_t cA = colA_data[dA];

        for (ptrdiff_t dB = rowB_data[j]; dB < rowB_data[j + 1]; dB++) {
          int64_t cB = colB_data[dB];

          if (cA == cB) {
            value_data[e] += valueA_data[dA] * valueB_data[dB];
          }

          // Early exit: relies on column indices being sorted ascending
          // within each row (guaranteed by coalesced input). Once
          // cB >= cA, no later entry of B's row can match cA.
          if (cB >= cA) {
            break;
          }
        }
      }
    }
  });

  return value;
}
62+
63+
// Expose the CPU backward kernel to Python as torch_sparse.spspmm_cpu.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spspmm_bw", &spspmm_bw,
        "Sparse-Sparse Matrix Multiplication Backward (CPU)");
}

cuda/spspmm.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,35 @@
44

55
std::tuple<at::Tensor, at::Tensor>
66
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
7-
at::Tensor valueB, int m, int k, int n);
7+
at::Tensor valueB, size_t m, size_t k, size_t n);
8+
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
9+
at::Tensor valueA, at::Tensor indexB,
10+
at::Tensor valueB, size_t rowA_max, size_t rowB_max);
811

912
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
1013
at::Tensor indexB, at::Tensor valueB,
11-
int m, int k, int n) {
14+
size_t m, size_t k, size_t n) {
1215
CHECK_CUDA(indexA);
1316
CHECK_CUDA(valueA);
1417
CHECK_CUDA(indexB);
1518
CHECK_CUDA(valueB);
1619
return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
1720
}
1821

22+
// Python-facing wrapper for the CUDA backward: validates device placement
// of all five tensors, then dispatches to spspmm_bw_cuda.
// rowA_max / rowB_max are the row counts of the two operand matrices.
at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
                     at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
                     size_t rowB_max) {
  CHECK_CUDA(index);
  CHECK_CUDA(indexA);
  CHECK_CUDA(valueA);
  CHECK_CUDA(indexB);
  CHECK_CUDA(valueB);
  return spspmm_bw_cuda(index, indexA, valueA, indexB, valueB, rowA_max,
                        rowB_max);
}
33+
1934
// Expose the CUDA forward and backward to Python as torch_sparse.spspmm_cuda.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
  m.def("spspmm_bw", &spspmm_bw,
        "Sparse-Sparse Matrix Multiplication Backward (CUDA)");
}

cuda/spspmm_kernel.cu

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
#include <cusparse.h>
44

5+
#define THREADS 1024
6+
#define BLOCKS(N) (N + THREADS - 1) / THREADS
7+
58
#define CSRGEMM(TYPE, ...) \
69
[&] { \
710
const at::Type &the_type = TYPE; \
@@ -29,7 +32,7 @@ static void init_cusparse() {
2932

3033
std::tuple<at::Tensor, at::Tensor>
3134
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
32-
at::Tensor valueB, int m, int k, int n) {
35+
at::Tensor valueB, size_t m, size_t k, size_t n) {
3336
cudaSetDevice(indexA.get_device());
3437
init_cusparse();
3538

@@ -90,3 +93,69 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
9093

9194
return std::make_tuple(indexC, valueC);
9295
}
96+
97+
// Per-node occurrence count of `row` (host-side ATen ops; duplicated from
// cpu/spspmm.cpp — consider sharing a header if this grows).
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  auto zero = at::zeros(num_nodes, row.options());
  auto one = at::ones(row.size(0), row.options());
  return zero.scatter_add_(0, row, one);
}
102+
103+
// Convert coalesced COO (row, col) indices into CSR row pointers + columns.
// The returned row-pointer tensor has length num_nodes + 1 and starts at 0.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  // Assert already coalesced input.
  row = degree(row, num_nodes).cumsum(0);
  row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
  return std::make_tuple(row, col);
}
110+
111+
// CUDA kernel: for each sparse entry e of `index` with coordinates (i, j),
// accumulates value[e] += sum_c A[i, c] * B[j, c] (dot product of row i of
// A and row j of B), where A and B are given in CSR form.
template <typename scalar_t>
__global__ void spspmm_bw_kernel(
    const int64_t *__restrict__ index, scalar_t *__restrict__ value,
    const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
    const scalar_t *__restrict__ valueA, const int64_t *__restrict__ rowB,
    const int64_t *__restrict__ colB, const scalar_t *__restrict__ valueB,
    const size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  // Grid-stride loop: one output entry per iteration, no inter-thread
  // races since each thread writes a distinct value[e].
  for (ptrdiff_t e = idx; e < numel; e += stride) {
    // index is 2 x numel, row-major: (i, j) = (index[e], index[numel + e]).
    int64_t i = index[e], j = index[numel + e];

    for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
      int64_t cA = colA[dA];

      for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
        int64_t cB = colB[dB];

        if (cA == cB) {
          value[e] += valueA[dA] * valueB[dB];
        }

        // Early exit assumes sorted column indices per row (coalesced
        // input): past cB >= cA there can be no further match for cA.
        if (cB >= cA) {
          break;
        }
      }
    }
  }
}
140+
141+
// Host launcher for spspmm_bw_kernel: builds CSR row pointers for both
// operands and computes one gradient value per column of `index`.
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
                          at::Tensor valueA, at::Tensor indexB,
                          at::Tensor valueB, size_t rowA_max, size_t rowB_max) {
  cudaSetDevice(index.get_device());
  // One output value per sparse entry, zero-initialized (kernel accumulates).
  auto value = at::zeros(index.size(1), valueA.options());

  at::Tensor rowA, colA;
  std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);

  at::Tensor rowB, colB;
  std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);

  AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
    // NOTE(review): BLOCKS(0) launches a zero-block grid when `index` is
    // empty, and there is no post-launch error check — confirm callers
    // never pass an empty index, or guard on value.numel() == 0.
    spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
        index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(),
        colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(),
        colB.data<int64_t>(), valueB.data<scalar_t>(), value.numel());
  });

  return value;
}

setup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import platform
22
from setuptools import setup, find_packages
3-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
3+
import torch
4+
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
45

5-
__version__ = '0.2.4'
6+
__version__ = '0.3.0'
67
url = 'https://github.com/rusty1s/pytorch_sparse'
78

89
install_requires = ['scipy']
910
setup_requires = ['pytest-runner']
1011
tests_require = ['pytest', 'pytest-cov']
11-
ext_modules = []
12+
ext_modules = [CppExtension('torch_sparse.spspmm_cpu', ['cpu/spspmm.cpp'])]
1213
cmdclass = {}
1314

1415
if CUDA_HOME is not None:
@@ -25,7 +26,7 @@
2526
CUDAExtension('torch_sparse.unique_cuda',
2627
['cuda/unique.cpp', 'cuda/unique_kernel.cu']),
2728
]
28-
cmdclass['build_ext'] = BuildExtension
29+
cmdclass['build_ext'] = torch.utils.cpp_extension.BuildExtension
2930

3031
setup(
3132
name='torch_sparse',

test/test_transpose.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
from itertools import product
2+
3+
import pytest
14
import torch
2-
from torch_sparse import transpose
5+
from torch_sparse import transpose, transpose_matrix
6+
7+
from .utils import dtypes, devices, tensor
38

49

510
def test_transpose():
@@ -11,3 +16,15 @@ def test_transpose():
1116
index, value = transpose(index, value, m=3, n=2)
1217
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
1318
assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
19+
20+
21+
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
22+
def test_transpose_matrix(dtype, device):
23+
row = torch.tensor([1, 0, 1, 2], device=device)
24+
col = torch.tensor([0, 1, 1, 0], device=device)
25+
index = torch.stack([row, col], dim=0)
26+
value = tensor([1, 2, 3, 4], dtype, device)
27+
28+
index, value = transpose_matrix(index, value, m=3, n=2)
29+
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
30+
assert value.tolist() == [1, 4, 2, 3]

torch_sparse/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
from .convert import to_scipy, from_scipy
12
from .coalesce import coalesce
2-
from .transpose import transpose
3+
from .transpose import transpose, transpose_matrix
34
from .eye import eye
45
from .spmm import spmm
56
from .spspmm import spspmm
67

7-
__version__ = '0.2.4'
8+
__version__ = '0.3.0'
89

910
__all__ = [
1011
'__version__',
12+
'to_scipy',
13+
'from_scipy',
1114
'coalesce',
1215
'transpose',
16+
'transpose_matrix',
1317
'eye',
1418
'spmm',
1519
'spspmm',

torch_sparse/convert.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import numpy as np
2+
import scipy.sparse
3+
import torch
4+
from torch import from_numpy
5+
6+
7+
def to_scipy(index, value, m, n):
    """Convert a COO (index, value) pair into an m x n scipy COO matrix.

    Only CPU tensors are supported; gradients are dropped via detach().
    """
    assert not index.is_cuda and not value.is_cuda
    row, col = index.detach()
    return scipy.sparse.coo_matrix((value.detach(), (row, col)), (m, n))
11+
12+
13+
def from_scipy(A):
    """Convert any scipy sparse matrix to a COO (index, value) tensor pair."""
    coo = A.tocoo()
    row = torch.from_numpy(coo.row.astype(np.int64))
    col = torch.from_numpy(coo.col.astype(np.int64))
    value = torch.from_numpy(coo.data)
    return torch.stack([row, col], dim=0), value

torch_sparse/spspmm.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import torch
2-
from torch import from_numpy
3-
import numpy as np
4-
import scipy.sparse
5-
from torch_sparse import transpose
2+
from torch_sparse import transpose_matrix, to_scipy, from_scipy
3+
4+
import torch_sparse.spspmm_cpu
65

76
if torch.cuda.is_available():
87
import torch_sparse.spspmm_cuda
@@ -38,22 +37,39 @@ def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
3837

3938
@staticmethod
def backward(ctx, grad_indexC, grad_valueC):
    """Backward of C = A @ B for sparse COO operands.

    Only gradients w.r.t. valueA (input index 1) and valueB (input
    index 3) are produced; index tensors and the sizes m, k, n are
    non-differentiable, so ``None`` is returned in their slots.
    """
    m, k, n = ctx.m, ctx.k, ctx.n
    indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors

    grad_valueA = grad_valueB = None

    if not grad_valueC.is_cuda:
        # BUGFIX: the second operand previously repeated
        # `needs_input_grad[1]`; the clone must also happen when only
        # the gradient for valueB (input 3) is requested.
        if ctx.needs_input_grad[1] or ctx.needs_input_grad[3]:
            grad_valueC = grad_valueC.clone()

        if ctx.needs_input_grad[1]:
            # dA[i, j] = sum_c dC[i, c] * B[j, c]  ==  (dC @ B^T)[i, j],
            # evaluated only at the sparse positions of A.
            grad_valueA = torch_sparse.spspmm_cpu.spspmm_bw(
                indexA, indexC.detach(), grad_valueC, indexB.detach(),
                valueB, m, k)

        if ctx.needs_input_grad[3]:
            # dB = A^T @ dC, computed by transposing both operands and
            # evaluating at the sparse positions of B.
            indexA, valueA = transpose_matrix(indexA, valueA, m, k)
            indexC, grad_valueC = transpose_matrix(indexC, grad_valueC, m,
                                                   n)
            grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
                indexB, indexA.detach(), valueA, indexC.detach(),
                grad_valueC, k, n)
    else:
        if ctx.needs_input_grad[1]:
            grad_valueA = torch_sparse.spspmm_cuda.spspmm_bw(
                indexA, indexC.detach(), grad_valueC.clone(),
                indexB.detach(), valueB, m, k)

        if ctx.needs_input_grad[3]:
            # NOTE(review): the CUDA grad-B path still goes through the
            # scipy-based `mm` + `lift` fallback — confirm this is
            # intentional and that the tensors are usable there.
            indexA_T, valueA_T = transpose_matrix(indexA, valueA, m, k)
            grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
                                          grad_valueC, k, m, n)
            grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)

    return None, grad_valueA, None, grad_valueB, None, None, None
5975

@@ -67,23 +83,11 @@ def mm(indexA, valueA, indexB, valueB, m, k, n):
6783

6884
A = to_scipy(indexA, valueA, m, k)
6985
B = to_scipy(indexB, valueB, k, n)
70-
indexC, valueC = from_scipy(A.tocsr().dot(B.tocsr()).tocoo())
71-
86+
C = A.dot(B).tocoo().tocsr().tocoo() # Force coalesce.
87+
indexC, valueC = from_scipy(C)
7288
return indexC, valueC
7389

7490

75-
def to_scipy(index, value, m, n):
76-
(row, col), data = index.detach(), value.detach()
77-
return scipy.sparse.coo_matrix((data, (row, col)), (m, n))
78-
79-
80-
def from_scipy(A):
81-
row, col, value = A.row.astype(np.int64), A.col.astype(np.int64), A.data
82-
row, col, value = from_numpy(row), from_numpy(col), from_numpy(value)
83-
index = torch.stack([row, col], dim=0)
84-
return index, value
85-
86-
8791
def lift(indexA, valueA, indexB, n): # pragma: no cover
8892
idxA = indexA[0] * n + indexA[1]
8993
idxB = indexB[0] * n + indexB[1]

0 commit comments

Comments
 (0)