Skip to content

Commit 3c7253a

Browse files
committed
spspmm args
1 parent b2ba34b commit 3c7253a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torch_sparse/spspmm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import spspmm_cuda
88

99

10-
class SpSpMM(torch.autograd.Function):
10+
def spspmm(indexA, valueA, indexB, valueB, m, k, n):
1111
"""Matrix product of two sparse tensors. Both input sparse matrices need to
1212
be coalesced.
1313
@@ -23,7 +23,10 @@ class SpSpMM(torch.autograd.Function):
2323
2424
:rtype: (:class:`LongTensor`, :class:`Tensor`)
2525
"""
26+
return SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
27+
2628

29+
class SpSpMM(torch.autograd.Function):
2730
@staticmethod
2831
def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
2932
indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
@@ -53,9 +56,6 @@ def backward(ctx, grad_indexC, grad_valueC):
5356
return None, grad_valueA, None, grad_valueB, None, None, None
5457

5558

56-
spspmm = SpSpMM.apply
57-
58-
5959
def mm(indexA, valueA, indexB, valueB, m, k, n):
6060
assert valueA.dtype == valueB.dtype
6161

0 commit comments

Comments
 (0)