@@ -76,7 +76,8 @@ def spmm_max(src: SparseTensor,
7676 return torch .ops .torch_sparse .spmm_max (rowptr , col , value , other )
7777
7878
79- def spmm (src : SparseTensor , other : torch .Tensor ,
79+ def spmm (src : SparseTensor ,
80+ other : torch .Tensor ,
8081 reduce : str = "sum" ) -> torch .Tensor :
8182 if reduce == 'sum' or reduce == 'add' :
8283 return spmm_sum (src , other )
@@ -97,7 +98,7 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
9798 edge_index = C ._indices ()
9899 row , col = edge_index [0 ], edge_index [1 ]
99100 value : Optional [Tensor ] = None
100- if src .has_value () and other .has_value ():
101+ if src .has_value () or other .has_value ():
101102 value = C ._values ()
102103
103104 return SparseTensor (
@@ -114,7 +115,8 @@ def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
114115 return spspmm_sum (src , other )
115116
116117
117- def spspmm (src : SparseTensor , other : SparseTensor ,
118+ def spspmm (src : SparseTensor ,
119+ other : SparseTensor ,
118120 reduce : str = "sum" ) -> SparseTensor :
119121 if reduce == 'sum' or reduce == 'add' :
120122 return spspmm_sum (src , other )
0 commit comments