
Commit 56ec830

spspmm half and version up
1 parent 3e87af1 commit 56ec830

File tree

7 files changed, +13 -15 lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.0)
 project(torchsparse)
 set(CMAKE_CXX_STANDARD 14)
-set(TORCHSPARSE_VERSION 0.6.10)
+set(TORCHSPARSE_VERSION 0.6.11)
 
 option(WITH_CUDA "Enable CUDA support" OFF)
 

conda/pytorch-sparse/meta.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 package:
   name: pytorch-sparse
-  version: 0.6.10
+  version: 0.6.11
 
 source:
   path: ../..

setup.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def get_extensions():
 
 setup(
     name='torch_sparse',
-    version='0.6.10',
+    version='0.6.11',
     author='Matthias Fey',
     author_email='[email protected]',
     url='https://github.com/rusty1s/pytorch_sparse',

test/test_matmul.py

Lines changed: 0 additions & 3 deletions
@@ -47,9 +47,6 @@ def test_spmm(dtype, device, reduce):
 
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_spspmm(dtype, device):
-    if dtype == torch.half:
-        return  # TODO
-
     src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                        device=device)
 

test/test_spspmm.py

Lines changed: 2 additions & 8 deletions
@@ -9,9 +9,6 @@
 
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_spspmm(dtype, device):
-    if dtype == torch.half:
-        return  # TODO
-
     indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
     valueA = tensor([1, 2, 3, 4, 5], dtype, device)
     indexB = torch.tensor([[0, 2], [1, 0]], device=device)
@@ -24,9 +21,6 @@ def test_spspmm(dtype, device):
 
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_sparse_tensor_spspmm(dtype, device):
-    if dtype == torch.half:
-        return  # TODO
-
     x = SparseTensor(
         row=torch.tensor(
             [0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
@@ -44,8 +38,8 @@ def test_sparse_tensor_spspmm(dtype, device):
     expected = torch.eye(10, dtype=dtype, device=device)
 
     out = x @ x.to_dense().t()
-    assert torch.allclose(out, expected, atol=1e-7)
+    assert torch.allclose(out, expected, atol=1e-2)
 
     out = x @ x.t()
     out = out.to_dense()
-    assert torch.allclose(out, expected, atol=1e-7)
+    assert torch.allclose(out, expected, atol=1e-2)
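The relaxed tolerance tracks float16's precision: with the half skips removed, these asserts now see fp16 results, and fp16's machine epsilon (~1e-3) already exceeds the old atol of 1e-7. A standalone illustration (not part of the commit):

import torch

# float16 carries only ~3 decimal digits of precision, so an absolute
# tolerance of 1e-7 (fine for float32/float64) can never hold for it.
print(torch.finfo(torch.half).eps)   # 0.0009765625
print(torch.finfo(torch.float).eps)  # 1.1920928955078125e-07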

torch_sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
__version__ = '0.6.10'
6+
__version__ = '0.6.11'
77

88
suffix = 'cuda' if torch.cuda.is_available() else 'cpu'
99

torch_sparse/matmul.py

Lines changed: 7 additions & 0 deletions
@@ -78,9 +78,16 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
     assert src.sparse_size(1) == other.sparse_size(0)
     rowptrA, colA, valueA = src.csr()
     rowptrB, colB, valueB = other.csr()
+    value = valueA
+    if valueA is not None and valueA.dtype == torch.half:
+        valueA = valueA.to(torch.float)
+    if valueB is not None and valueB.dtype == torch.half:
+        valueB = valueB.to(torch.float)
     M, K = src.sparse_size(0), other.sparse_size(1)
     rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
         rowptrA, colA, valueA, rowptrB, colB, valueB, K)
+    if valueC is not None and value is not None:
+        valueC = valueC.to(value.dtype)
     return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
                         sparse_sizes=(M, K), is_sorted=True)
 
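This matmul.py hunk is the substance of the commit: torch.ops.torch_sparse.spspmm_sum has no half-precision kernel, so half values are upcast to float32 before the call and the result is cast back to the original dtype afterwards. A minimal usage sketch against the patched API (the 3x3 sparse identity input is illustrative, not from the commit):

import torch
from torch_sparse import SparseTensor

# A 3x3 sparse identity matrix with half-precision values.
index = torch.tensor([0, 1, 2])
value = torch.ones(3, dtype=torch.half)
x = SparseTensor(row=index, col=index, value=value, sparse_sizes=(3, 3))

# SparseTensor @ SparseTensor routes to spspmm_sum; the patch upcasts
# the half values to float32 for the kernel and casts the result back,
# so callers no longer need a manual round-trip (previously the tests
# simply skipped torch.half).
out = x @ x
assert out.to_dense().dtype == torch.half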
