
Commit 6a6e86c

Merge commit (2 parents: 2219f43 + d33d29b)

File tree: 17 files changed (+156, −155 lines)

.coveragerc

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 [report]
 exclude_lines =
     pragma: no cover
-    def backward
     cuda

cuda/matmul.cpp

Lines changed: 0 additions & 15 deletions
This file was deleted.

cuda/spspmm.cpp

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#include <torch/torch.h>
+
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
+
+std::tuple<at::Tensor, at::Tensor>
+spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
+            at::Tensor valueB, int m, int k, int n);
+
+std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
+                                          at::Tensor indexB, at::Tensor valueB,
+                                          int m, int k, int n) {
+  CHECK_CUDA(indexA);
+  CHECK_CUDA(valueA);
+  CHECK_CUDA(indexB);
+  CHECK_CUDA(valueB);
+  return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
+}
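
Editor's note: a quick way to exercise this binding is PyTorch's JIT extension loader. The snippet below is a hedged sketch, not part of the commit; it assumes a CUDA-capable machine, a working nvcc toolchain, and that it is run from the repository root.

    # Sketch: JIT-compile the extension and call the bound function on tiny inputs.
    import torch
    from torch.utils.cpp_extension import load

    ext = load(name='spspmm_cuda',
               sources=['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'])

    indexA = torch.tensor([[0, 0, 1], [1, 2, 0]], device='cuda')  # COO indices of A
    valueA = torch.tensor([1., 2., 3.], device='cuda')
    indexB = torch.tensor([[0, 2], [1, 0]], device='cuda')        # COO indices of B
    valueB = torch.tensor([2., 4.], device='cuda')
    # A is m x k = 3 x 3, B is k x n = 3 x 2.
    indexC, valueC = ext.spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)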

cuda/matmul_cuda.cu renamed to cuda/spspmm_kernel.cu

Lines changed: 20 additions & 16 deletions
@@ -27,28 +27,32 @@ static void init_cusparse() {
   }
 }
 
-std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
+std::tuple<at::Tensor, at::Tensor>
+spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
+            at::Tensor valueB, int m, int k, int n) {
   init_cusparse();
 
-  auto m = A.size(0);
-  auto k = A.size(1);
-  auto n = B.size(1);
+  indexA = indexA.contiguous();
+  valueA = valueA.contiguous();
+  indexB = indexB.contiguous();
+  valueB = valueB.contiguous();
 
-  auto nnzA = A._nnz();
-  auto nnzB = B._nnz();
+  auto nnzA = valueA.size(0);
+  auto nnzB = valueB.size(0);
 
-  auto valueA = A._values();
-  auto indexA = A._indices().toType(at::kInt);
-  auto row_ptrA = at::empty(indexA.type(), {m + 1});
+  indexA = indexA.toType(at::kInt);
+  indexB = indexB.toType(at::kInt);
+
+  // Convert A to CSR format.
+  auto row_ptrA = at::empty(m + 1, indexA.type());
   cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k,
                    row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
   auto colA = indexA[1];
   cudaMemcpy(row_ptrA.data<int>() + m, &nnzA, sizeof(int),
              cudaMemcpyHostToDevice);
 
-  auto valueB = B._values();
-  auto indexB = B._indices().toType(at::kInt);
-  auto row_ptrB = at::empty(indexB.type(), {k + 1});
+  // Convert B to CSR format.
+  auto row_ptrB = at::empty(k + 1, indexB.type());
   cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
                    row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
   auto colB = indexB[1];
@@ -61,14 +65,14 @@ std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
   cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
 
   int nnzC;
-  auto row_ptrC = at::empty(indexA.type(), {m + 1});
+  auto row_ptrC = at::empty(m + 1, indexB.type());
   cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                       CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
                       row_ptrA.data<int>(), colA.data<int>(), descr, nnzB,
                       row_ptrB.data<int>(), colB.data<int>(), descr,
                       row_ptrC.data<int>(), &nnzC);
-  auto colC = at::empty(indexA.type(), {nnzC});
-  auto valueC = at::empty(valueA.type(), {nnzC});
+  auto colC = at::empty(nnzC, indexA.type());
+  auto valueC = at::empty(nnzC, valueA.type());
 
   CSRGEMM(valueC.type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
           CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
@@ -77,7 +81,7 @@ std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
           colB.data<int>(), descr, valueC.data<scalar_t>(),
           row_ptrC.data<int>(), colC.data<int>());
 
-  auto rowC = at::empty(indexA.type(), {nnzC});
+  auto rowC = at::empty(nnzC, indexA.type());
   cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m,
                    rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
 
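
Editor's note: the cusparseXcoo2csr calls above compress sorted COO row indices into CSR row pointers, i.e. an exclusive prefix sum of per-row nonzero counts. An illustrative re-implementation in plain PyTorch (the coo_to_csr_row_ptr helper is hypothetical, not part of this repo):

    # row_ptr[i] is the offset of row i's first nonzero; row_ptr[m] == nnz.
    import torch

    def coo_to_csr_row_ptr(row, m):
        counts = torch.bincount(row, minlength=m)   # nonzeros per row
        row_ptr = torch.zeros(m + 1, dtype=torch.int32)
        row_ptr[1:] = torch.cumsum(counts, dim=0)   # exclusive prefix sum
        return row_ptr

    row = torch.tensor([0, 0, 1, 2, 2])   # sorted COO row indices, m = 3
    coo_to_csr_row_ptr(row, 3)            # tensor([0, 2, 3, 5], dtype=torch.int32)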

setup.py

Lines changed: 3 additions & 3 deletions
@@ -5,16 +5,16 @@
 __version__ = '0.2.0'
 url = 'https://github.com/rusty1s/pytorch_sparse'
 
-install_requires = ['numpy', 'scipy']
+install_requires = ['scipy']
 setup_requires = ['pytest-runner']
 tests_require = ['pytest', 'pytest-cov']
 ext_modules = []
 cmdclass = {}
 
 if torch.cuda.is_available():
     ext_modules += [
-        CUDAExtension('matmul_cuda',
-                      ['cuda/matmul.cpp', 'cuda/matmul_cuda.cu'])
+        CUDAExtension('spspmm_cuda',
+                      ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'])
     ]
     cmdclass['build_ext'] = BuildExtension
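
Editor's note: a hedged sketch of how the built artifact is consumed (an assumption about usage, not code from the commit). On a CUDA machine, installing the package compiles the two sources above into a top-level extension module:

    # Assumption: the package was installed (e.g. `pip install .`) with CUDA available.
    import torch

    if torch.cuda.is_available():
        import spspmm_cuda  # compiled from cuda/spspmm.cpp + cuda/spspmm_kernel.cu
        print(spspmm_cuda.spspmm.__doc__)  # "Sparse-Sparse Matrix Multiplication (CUDA)"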

test/test_coalesce.py

Lines changed: 1 addition & 1 deletion
@@ -8,6 +8,6 @@ def test_coalesce():
     index = torch.stack([row, col], dim=0)
     value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
 
-    index, value = coalesce(index, value, torch.Size([4, 2]))
+    index, value = coalesce(index, value, m=3, n=2)
     assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]]
     assert value.tolist() == [[6, 8], [7, 9], [3, 4], [5, 6]]

test/test_matmul.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

test/test_spmm.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import torch
+from torch_sparse import spmm
+
+
+def test_spmm():
+    row = torch.tensor([0, 0, 1, 2, 2])
+    col = torch.tensor([0, 2, 1, 0, 1])
+    index = torch.stack([row, col], dim=0)
+    value = torch.tensor([1, 2, 4, 1, 3])
+
+    matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])
+    out = spmm(index, value, 3, matrix)
+    assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
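
Editor's note: the expected output can be cross-checked against the dense equivalent of the sparse operand (illustrative only, not part of the commit):

    # Dense form of the 3 x 3 sparse matrix defined by (row, col, value) above.
    import torch

    A = torch.tensor([[1, 0, 2],
                      [0, 4, 0],
                      [1, 3, 0]])
    matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])
    A @ matrix  # tensor([[ 7, 16], [ 8, 20], [ 7, 19]])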

test/test_spspmm.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from itertools import product
+
+import pytest
+import torch
+from torch_sparse import spspmm
+
+from .utils import dtypes, devices, tensor
+
+
+@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+def test_spspmm(dtype, device):
+    indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
+    valueA = tensor([1, 2, 3, 4, 5], dtype, device)
+    sizeA = torch.Size([3, 3])
+    indexB = torch.tensor([[0, 2], [1, 0]], device=device)
+    valueB = tensor([2, 4], dtype, device)
+    sizeB = torch.Size([3, 2])
+
+    indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
+    assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
+    assert valueC.tolist() == [8, 6, 8]
+
+    A = torch.sparse_coo_tensor(indexA, valueA, sizeA, device=device)
+    A = A.to_dense().requires_grad_()
+    B = torch.sparse_coo_tensor(indexB, valueB, sizeB, device=device)
+    B = B.to_dense().requires_grad_()
+    torch.matmul(A, B).sum().backward()
+
+    valueA = valueA.requires_grad_()
+    valueB = valueB.requires_grad_()
+    indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
+    valueC.sum().backward()
+
+    assert valueA.grad.tolist() == A.grad[indexA[0], indexA[1]].tolist()
+    assert valueB.grad.tolist() == B.grad[indexB[0], indexB[1]].tolist()
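
Editor's note: the expected (indexC, valueC) pair can be verified with a dense cross-check (illustrative only, not part of the commit):

    # Densify both operands from their COO definitions above and multiply.
    import torch

    A = torch.sparse_coo_tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]],
                                [1., 2., 3., 4., 5.], (3, 3)).to_dense()
    B = torch.sparse_coo_tensor([[0, 2], [1, 0]], [2., 4.], (3, 2)).to_dense()
    A @ B
    # tensor([[8., 0.],
    #         [0., 6.],
    #         [0., 8.]])  -> nonzeros at (0,0)=8, (1,1)=6, (2,1)=8, as asserted.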

test/test_transpose.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import torch
+from torch_sparse import transpose
+
+
+def test_transpose():
+    row = torch.tensor([1, 0, 1, 0, 2, 1])
+    col = torch.tensor([0, 1, 1, 1, 0, 0])
+    index = torch.stack([row, col], dim=0)
+    value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
+
+    index, value = transpose(index, value, m=3, n=2)
+    assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
+    assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
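
Editor's note: the expected result here amounts to swapping row/col and coalescing duplicates, which can be cross-checked with torch's own sparse routines (illustrative only, not part of the commit):

    # Build the transposed COO tensor directly; coalesce() sorts indices and
    # sums the duplicate entries, reproducing the asserted index/value pair.
    import torch

    row = torch.tensor([1, 0, 1, 0, 2, 1])
    col = torch.tensor([0, 1, 1, 1, 0, 0])
    value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])

    At = torch.sparse_coo_tensor(torch.stack([col, row]), value, (2, 3, 2)).coalesce()
    At.indices().tolist()  # [[0, 0, 1, 1], [1, 2, 0, 1]]
    At.values().tolist()   # [[7, 9], [5, 6], [6, 8], [3, 4]]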
