@@ -11,6 +11,9 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
1111 csr2csc = src .storage ._csr2csc
1212 colptr = src .storage ._colptr
1313
14+ if value is not None :
15+ value = value .to (other .dtype )
16+
1417 if value is not None and value .requires_grad :
1518 row = src .storage .row ()
1619
@@ -35,6 +38,9 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
3538 csr2csc = src .storage ._csr2csc
3639 colptr = src .storage ._colptr
3740
41+ if value is not None :
42+ value = value .to (other .dtype )
43+
3844 if value is not None and value .requires_grad :
3945 row = src .storage .row ()
4046
@@ -51,12 +57,20 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
5157def spmm_min (src : SparseTensor ,
5258 other : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
5359 rowptr , col , value = src .csr ()
60+
61+ if value is not None :
62+ value = value .to (other .dtype )
63+
5464 return torch .ops .torch_sparse .spmm_min (rowptr , col , value , other )
5565
5666
5767def spmm_max (src : SparseTensor ,
5868 other : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
5969 rowptr , col , value = src .csr ()
70+
71+ if value is not None :
72+ value = value .to (other .dtype )
73+
6074 return torch .ops .torch_sparse .spmm_max (rowptr , col , value , other )
6175
6276
@@ -81,8 +95,8 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
8195 value = valueA
8296 if valueA is not None and valueA .dtype == torch .half :
8397 valueA = valueA .to (torch .float )
84- if valueB is not None and valueB . dtype == torch . half :
85- valueB = valueB .to (torch . float )
98+ if valueB is not None :
99+ valueB = valueB .to (valueA . dtype )
86100 M , K = src .sparse_size (0 ), other .sparse_size (1 )
87101 rowptrC , colC , valueC = torch .ops .torch_sparse .spspmm_sum (
88102 rowptrA , colA , valueA , rowptrB , colB , valueB , K )
0 commit comments