Skip to content

Commit 87c88d9

Browse files
committed
enable autocast
1 parent bdd1ced commit 87c88d9

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

torch_sparse/matmul.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
1111
csr2csc = src.storage._csr2csc
1212
colptr = src.storage._colptr
1313

14+
if value is not None:
15+
value = value.to(other.dtype)
16+
1417
if value is not None and value.requires_grad:
1518
row = src.storage.row()
1619

@@ -35,6 +38,9 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
3538
csr2csc = src.storage._csr2csc
3639
colptr = src.storage._colptr
3740

41+
if value is not None:
42+
value = value.to(other.dtype)
43+
3844
if value is not None and value.requires_grad:
3945
row = src.storage.row()
4046

@@ -51,12 +57,20 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
5157
def spmm_min(src: SparseTensor,
5258
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
5359
rowptr, col, value = src.csr()
60+
61+
if value is not None:
62+
value = value.to(other.dtype)
63+
5464
return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)
5565

5666

5767
def spmm_max(src: SparseTensor,
5868
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
5969
rowptr, col, value = src.csr()
70+
71+
if value is not None:
72+
value = value.to(other.dtype)
73+
6074
return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
6175

6276

@@ -81,8 +95,8 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
8195
value = valueA
8296
if valueA is not None and valueA.dtype == torch.half:
8397
valueA = valueA.to(torch.float)
84-
if valueB is not None and valueB.dtype == torch.half:
85-
valueB = valueB.to(torch.float)
98+
if valueB is not None:
99+
valueB = valueB.to(valueA.dtype)
86100
M, K = src.sparse_size(0), other.sparse_size(1)
87101
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
88102
rowptrA, colA, valueA, rowptrB, colB, valueB, K)

0 commit comments

Comments
 (0)