Skip to content

Commit 8e6635b

Browse files
authored
fix test (#340)
1 parent 003abd5 commit 8e6635b

File tree

8 files changed

+18
-20
lines changed

8 files changed

+18
-20
lines changed

test/__init__.py

Whitespace-only changes.

test/test_broadcasting.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import pytest
44
import torch
55
from torch_scatter import scatter
6-
7-
from .utils import reductions, devices
6+
from torch_scatter.testing import devices, reductions
87

98

109
@pytest.mark.parametrize('reduce,device', product(reductions, devices))

test/test_gather.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
import pytest
44
import torch
55
from torch.autograd import gradcheck
6-
from torch_scatter import gather_csr, gather_coo
7-
8-
from .utils import tensor, dtypes, devices
6+
from torch_scatter import gather_coo, gather_csr
7+
from torch_scatter.testing import devices, dtypes, tensor
98

109
tests = [
1110
{

test/test_multi_gpu.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import pytest
44
import torch
55
import torch_scatter
6-
7-
from .utils import reductions, tensor, dtypes
6+
from torch_scatter.testing import dtypes, reductions, tensor
87

98
tests = [
109
{

test/test_scatter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import torch
55
import torch_scatter
66
from torch.autograd import gradcheck
7-
8-
from .utils import devices, dtypes, reductions, tensor
7+
from torch_scatter.testing import devices, dtypes, reductions, tensor
98

109
reductions = reductions + ['mul']
1110

test/test_segment.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import pytest
44
import torch
5-
from torch.autograd import gradcheck
65
import torch_scatter
7-
8-
from .utils import reductions, tensor, dtypes, devices
6+
from torch.autograd import gradcheck
7+
from torch_scatter.testing import devices, dtypes, reductions, tensor
98

109
tests = [
1110
{

test/test_zero_tensors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import pytest
44
import torch
5-
from torch_scatter import scatter, segment_coo, gather_coo
6-
from torch_scatter import segment_csr, gather_csr
7-
8-
from .utils import reductions, tensor, grad_dtypes, devices
5+
from torch_scatter import (gather_coo, gather_csr, scatter, segment_coo,
6+
segment_csr)
7+
from torch_scatter.testing import devices, grad_dtypes, reductions, tensor
98

109

1110
@pytest.mark.parametrize('reduce,dtype,device',
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
from typing import Any
2+
13
import torch
24

35
reductions = ['sum', 'add', 'mean', 'min', 'max']
46

5-
dtypes = [torch.half, torch.bfloat16, torch.float, torch.double,
6-
torch.int, torch.long]
7+
dtypes = [
8+
torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
9+
torch.long
10+
]
711
grad_dtypes = [torch.float, torch.double]
812

913
devices = [torch.device('cpu')]
1014
if torch.cuda.is_available():
11-
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
15+
devices += [torch.device('cuda:0')]
1216

1317

14-
def tensor(x, dtype, device):
18+
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
1519
return None if x is None else torch.tensor(x, device=device).to(dtype)

0 commit comments

Comments
 (0)