Skip to content

Commit fca6819

Browse files
committed
fix test on pytorch 1.6.0
1 parent 947e036 commit fca6819

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

test/test_matmul.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ def test_spmm_half_precision():
5151
src_dense[:, 2:4] = 0 # Remove multiple columns.
5252
src = SparseTensor.from_dense(src_dense)
5353

54-
other = torch.randn((2, 8, 2), dtype=torch.half, device='cpu')
54+
other = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')
5555

56-
expected = src_dense @ other
57-
out = src @ other
56+
expected = (src_dense.to(torch.float) @ other).to(torch.half)
57+
out = src @ other.to(torch.half)
5858

59-
assert torch.allclose(expected, out, atol=1e-6)
59+
assert torch.allclose(expected, out, atol=1e-2)
6060

6161

6262
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))

0 commit comments

Comments
 (0)