Commit ff28536

remove @torch.jit.script annotations (move jit compatibility to the test suite)
1 parent 3341dbe commit ff28536

File tree: 12 files changed (+54, -47 lines)

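Across the diff the pattern is the same: the library functions drop their eager @torch.jit.script decorators, and the test suite instead scripts each function on demand and checks that the scripted output matches the eager output. A minimal standalone sketch of that pattern (illustrative only; the tensors below are made up rather than taken from the test fixtures, and it assumes torch and torch_scatter are importable):

import torch
import torch_scatter

src = torch.randn(10)
index = torch.randint(0, 3, (10,))

# Eager call: after this commit, scatter_max is a plain Python function.
out, arg_out = torch_scatter.scatter_max(src, index, dim=-1)

# Script it on demand; the type annotations on the signature keep it scriptable.
jit = torch.jit.script(torch_scatter.scatter_max)
out_jit, arg_out_jit = jit(src, index, dim=-1)

assert out.tolist() == out_jit.tolist()
assert arg_out.tolist() == arg_out_jit.tolist()

The test changes below apply this check to the scatter, segment, and composite reductions.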

test/composite/test_logsumexp.py
Lines changed: 3 additions & 0 deletions

@@ -18,3 +18,6 @@ def test_logsumexp():
     assert out.tolist() == torch.logsumexp(src, dim=0).tolist()
 
     outputs.backward(torch.randn_like(outputs))
+
+    jit = torch.jit.script(scatter_logsumexp)
+    assert jit(inputs, index).tolist() == outputs.tolist()

test/composite/test_softmax.py
Lines changed: 6 additions & 0 deletions

@@ -22,6 +22,9 @@ def test_softmax():
 
     out.backward(torch.randn_like(out))
 
+    jit = torch.jit.script(scatter_softmax)
+    assert jit(src, index).tolist() == out.tolist()
+
 
 def test_log_softmax():
     src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
@@ -42,3 +45,6 @@ def test_log_softmax():
     assert torch.allclose(out, expected)
 
     out.backward(torch.randn_like(out))
+
+    jit = torch.jit.script(scatter_log_softmax)
+    assert jit(src, index).tolist() == out.tolist()

test/composite/test_std.py
Lines changed: 3 additions & 0 deletions

@@ -13,3 +13,6 @@ def test_std():
     assert torch.allclose(out, expected)
 
     out.backward(torch.randn_like(out))
+
+    jit = torch.jit.script(scatter_std)
+    assert jit(src, index, dim=-1, unbiased=True).tolist() == out.tolist()

test/test_scatter.py
Lines changed: 11 additions & 5 deletions

@@ -99,12 +99,18 @@ def test_forward(test, reduce, dtype, device):
     dim = test['dim']
     expected = tensor(test[reduce], dtype, device)
 
-    out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim)
-    if isinstance(out, tuple):
-        out, arg_out = out
+    fn = getattr(torch_scatter, 'scatter_' + reduce)
+    jit = torch.jit.script(fn)
+    out1 = fn(src, index, dim)
+    out2 = jit(src, index, dim)
+    if isinstance(out1, tuple):
+        out1, arg_out1 = out1
+        out2, arg_out2 = out2
         arg_expected = tensor(test['arg_' + reduce], torch.long, device)
-        assert torch.all(arg_out == arg_expected)
-    assert torch.all(out == expected)
+        assert torch.all(arg_out1 == arg_expected)
+        assert arg_out1.tolist() == arg_out1.tolist()
+    assert torch.all(out1 == expected)
+    assert out1.tolist() == out2.tolist()
 
 
 @pytest.mark.parametrize('test,reduce,device',

test/test_segment.py
Lines changed: 23 additions & 11 deletions

@@ -91,19 +91,31 @@ def test_forward(test, reduce, dtype, device):
     indptr = tensor(test['indptr'], torch.long, device)
     expected = tensor(test[reduce], dtype, device)
 
-    out = getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr)
-    if isinstance(out, tuple):
-        out, arg_out = out
+    fn = getattr(torch_scatter, 'segment_' + reduce + '_csr')
+    jit = torch.jit.script(fn)
+    out1 = fn(src, indptr)
+    out2 = jit(src, indptr)
+    if isinstance(out1, tuple):
+        out1, arg_out1 = out1
+        out2, arg_out2 = out2
         arg_expected = tensor(test['arg_' + reduce], torch.long, device)
-        assert torch.all(arg_out == arg_expected)
-    assert torch.all(out == expected)
-
-    out = getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index)
-    if isinstance(out, tuple):
-        out, arg_out = out
+        assert torch.all(arg_out1 == arg_expected)
+        assert arg_out1.tolist() == arg_out2.tolist()
+    assert torch.all(out1 == expected)
+    assert out1.tolist() == out2.tolist()
+
+    fn = getattr(torch_scatter, 'segment_' + reduce + '_coo')
+    jit = torch.jit.script(fn)
+    out1 = fn(src, index)
+    out2 = jit(src, index)
+    if isinstance(out1, tuple):
+        out1, arg_out1 = out1
+        out2, arg_out2 = out2
         arg_expected = tensor(test['arg_' + reduce], torch.long, device)
-        assert torch.all(arg_out == arg_expected)
-    assert torch.all(out == expected)
+        assert torch.all(arg_out1 == arg_expected)
+        assert arg_out1.tolist() == arg_out2.tolist()
+    assert torch.all(out1 == expected)
+    assert out1.tolist() == out2.tolist()
 
 
 @pytest.mark.parametrize('test,reduce,device',

torch_scatter/composite/logsumexp.py
Lines changed: 0 additions & 1 deletion

@@ -6,7 +6,6 @@
 from torch_scatter.utils import broadcast
 
 
-@torch.jit.script
 def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                       out: Optional[torch.Tensor] = None,
                       dim_size: Optional[int] = None,

torch_scatter/composite/softmax.py
Lines changed: 0 additions & 2 deletions

@@ -4,7 +4,6 @@
 from torch_scatter.utils import broadcast
 
 
-@torch.jit.script
 def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                     eps: float = 1e-12) -> torch.Tensor:
     if not torch.is_floating_point(src):
@@ -25,7 +24,6 @@ def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     return recentered_scores_exp.div(normalizing_constants)
 
 
-@torch.jit.script
 def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                         eps: float = 1e-12) -> torch.Tensor:
     if not torch.is_floating_point(src):

torch_scatter/composite/std.py
Lines changed: 0 additions & 1 deletion

@@ -5,7 +5,6 @@
 from torch_scatter.utils import broadcast
 
 
-@torch.jit.script
 def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None,

torch_scatter/scatter.py
Lines changed: 0 additions & 6 deletions

@@ -5,7 +5,6 @@
 from .utils import broadcast
 
 
-@torch.jit.script
 def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
@@ -24,21 +23,18 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     return out.scatter_add_(dim, index, src)
 
 
-@torch.jit.script
 def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
     return scatter_sum(src, index, dim, out, dim_size)
 
 
-@torch.jit.script
 def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
     return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)
 
 
-@torch.jit.script
 def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                  out: Optional[torch.Tensor] = None,
                  dim_size: Optional[int] = None) -> torch.Tensor:
@@ -63,15 +59,13 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     return out
 
 
-@torch.jit.script
 def scatter_min(
         src: torch.Tensor, index: torch.Tensor, dim: int = -1,
         out: Optional[torch.Tensor] = None,
         dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
     return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
 
 
-@torch.jit.script
 def scatter_max(
         src: torch.Tensor, index: torch.Tensor, dim: int = -1,
         out: Optional[torch.Tensor] = None,
torch_scatter/segment_coo.py
Lines changed: 8 additions & 14 deletions

@@ -3,40 +3,35 @@
 import torch
 
 
-@torch.jit.script
 def segment_sum_coo(src: torch.Tensor, index: torch.Tensor,
                     out: Optional[torch.Tensor] = None,
                     dim_size: Optional[int] = None) -> torch.Tensor:
     return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)
 
 
-@torch.jit.script
 def segment_add_coo(src: torch.Tensor, index: torch.Tensor,
                     out: Optional[torch.Tensor] = None,
                     dim_size: Optional[int] = None) -> torch.Tensor:
     return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)
 
 
-@torch.jit.script
 def segment_mean_coo(src: torch.Tensor, index: torch.Tensor,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None) -> torch.Tensor:
     return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size)
 
 
-@torch.jit.script
-def segment_min_coo(src: torch.Tensor, index: torch.Tensor,
-                    out: Optional[torch.Tensor] = None,
-                    dim_size: Optional[int] = None
-                    ) -> Tuple[torch.Tensor, torch.Tensor]:
+def segment_min_coo(
+        src: torch.Tensor, index: torch.Tensor,
+        out: Optional[torch.Tensor] = None,
+        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
     return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size)
 
 
-@torch.jit.script
-def segment_max_coo(src: torch.Tensor, index: torch.Tensor,
-                    out: Optional[torch.Tensor] = None,
-                    dim_size: Optional[int] = None
-                    ) -> Tuple[torch.Tensor, torch.Tensor]:
+def segment_max_coo(
+        src: torch.Tensor, index: torch.Tensor,
+        out: Optional[torch.Tensor] = None,
+        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
     return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size)
 
 
@@ -137,7 +132,6 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor,
         raise ValueError
 
 
-@torch.jit.script
 def gather_coo(src: torch.Tensor, index: torch.Tensor,
                out: Optional[torch.Tensor] = None) -> torch.Tensor:
     return torch.ops.torch_scatter.gather_coo(src, index, out)
