11import torch
2- from torch import from_numpy
3- import numpy as np
4- import scipy .sparse
5- from torch_sparse import transpose
2+ from torch_sparse import transpose_matrix , to_scipy , from_scipy
3+
4+ import torch_sparse .spspmm_cpu
65
76if torch .cuda .is_available ():
87 import torch_sparse .spspmm_cuda
@@ -38,22 +37,39 @@ def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
3837
3938 @staticmethod
4039 def backward (ctx , grad_indexC , grad_valueC ):
41- m , k , n = ctx .m , ctx .k , ctx .n
40+ m , k = ctx .m , ctx .k
41+ n = ctx .n
4242 indexA , valueA , indexB , valueB , indexC = ctx .saved_tensors
4343
4444 grad_valueA = grad_valueB = None
4545
46- if ctx .needs_input_grad [1 ]:
47- indexB_T , valueB_T = transpose (indexB , valueB , k , n )
48- grad_indexA , grad_valueA = mm (indexC , grad_valueC , indexB_T ,
49- valueB_T , m , n , k )
50- grad_valueA = lift (grad_indexA , grad_valueA , indexA , k )
51-
52- if ctx .needs_input_grad [3 ]:
53- indexA_T , valueA_T = transpose (indexA , valueA , m , k )
54- grad_indexB , grad_valueB = mm (indexA_T , valueA_T , indexC ,
55- grad_valueC , k , m , n )
56- grad_valueB = lift (grad_indexB , grad_valueB , indexB , n )
46+ if not grad_valueC .is_cuda :
47+ if ctx .needs_input_grad [1 ] or ctx .needs_input_grad [1 ]:
48+ grad_valueC = grad_valueC .clone ()
49+
50+ if ctx .needs_input_grad [1 ]:
51+ grad_valueA = torch_sparse .spspmm_cpu .spspmm_bw (
52+ indexA , indexC .detach (), grad_valueC , indexB .detach (),
53+ valueB , m , k )
54+
55+ if ctx .needs_input_grad [3 ]:
56+ indexA , valueA = transpose_matrix (indexA , valueA , m , k )
57+ indexC , grad_valueC = transpose_matrix (indexC , grad_valueC , m ,
58+ n )
59+ grad_valueB = torch_sparse .spspmm_cpu .spspmm_bw (
60+ indexB , indexA .detach (), valueA , indexC .detach (),
61+ grad_valueC , k , n )
62+ else :
63+ if ctx .needs_input_grad [1 ]:
64+ grad_valueA = torch_sparse .spspmm_cuda .spspmm_bw (
65+ indexA , indexC .detach (), grad_valueC .clone (),
66+ indexB .detach (), valueB , m , k )
67+
68+ if ctx .needs_input_grad [3 ]:
69+ indexA_T , valueA_T = transpose_matrix (indexA , valueA , m , k )
70+ grad_indexB , grad_valueB = mm (indexA_T , valueA_T , indexC ,
71+ grad_valueC , k , m , n )
72+ grad_valueB = lift (grad_indexB , grad_valueB , indexB , n )
5773
5874 return None , grad_valueA , None , grad_valueB , None , None , None
5975
@@ -67,23 +83,11 @@ def mm(indexA, valueA, indexB, valueB, m, k, n):
6783
6884 A = to_scipy (indexA , valueA , m , k )
6985 B = to_scipy (indexB , valueB , k , n )
70- indexC , valueC = from_scipy ( A . tocsr (). dot ( B .tocsr ()) .tocoo ())
71-
86+ C = A . dot ( B ). tocoo () .tocsr ().tocoo () # Force coalesce.
87+ indexC , valueC = from_scipy ( C )
7288 return indexC , valueC
7389
7490
75- def to_scipy (index , value , m , n ):
76- (row , col ), data = index .detach (), value .detach ()
77- return scipy .sparse .coo_matrix ((data , (row , col )), (m , n ))
78-
79-
80- def from_scipy (A ):
81- row , col , value = A .row .astype (np .int64 ), A .col .astype (np .int64 ), A .data
82- row , col , value = from_numpy (row ), from_numpy (col ), from_numpy (value )
83- index = torch .stack ([row , col ], dim = 0 )
84- return index , value
85-
86-
8791def lift (indexA , valueA , indexB , n ): # pragma: no cover
8892 idxA = indexA [0 ] * n + indexA [1 ]
8993 idxB = indexB [0 ] * n + indexB [1 ]
0 commit comments