|
3 | 3 | import torch |
4 | 4 |
|
5 | 5 |
|
6 | | -@torch.jit.script |
7 | 6 | def segment_sum_coo(src: torch.Tensor, index: torch.Tensor, |
8 | 7 | out: Optional[torch.Tensor] = None, |
9 | 8 | dim_size: Optional[int] = None) -> torch.Tensor: |
10 | 9 | return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size) |
11 | 10 |
|
12 | 11 |
|
13 | | -@torch.jit.script |
14 | 12 | def segment_add_coo(src: torch.Tensor, index: torch.Tensor, |
15 | 13 | out: Optional[torch.Tensor] = None, |
16 | 14 | dim_size: Optional[int] = None) -> torch.Tensor: |
17 | 15 | return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size) |
18 | 16 |
|
19 | 17 |
|
20 | | -@torch.jit.script |
21 | 18 | def segment_mean_coo(src: torch.Tensor, index: torch.Tensor, |
22 | 19 | out: Optional[torch.Tensor] = None, |
23 | 20 | dim_size: Optional[int] = None) -> torch.Tensor: |
24 | 21 | return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size) |
25 | 22 |
|
26 | 23 |
|
27 | | -@torch.jit.script |
28 | | -def segment_min_coo(src: torch.Tensor, index: torch.Tensor, |
29 | | - out: Optional[torch.Tensor] = None, |
30 | | - dim_size: Optional[int] = None |
31 | | - ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 24 | +def segment_min_coo( |
| 25 | + src: torch.Tensor, index: torch.Tensor, |
| 26 | + out: Optional[torch.Tensor] = None, |
| 27 | + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: |
32 | 28 | return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size) |
33 | 29 |
|
34 | 30 |
|
35 | | -@torch.jit.script |
36 | | -def segment_max_coo(src: torch.Tensor, index: torch.Tensor, |
37 | | - out: Optional[torch.Tensor] = None, |
38 | | - dim_size: Optional[int] = None |
39 | | - ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 31 | +def segment_max_coo( |
| 32 | + src: torch.Tensor, index: torch.Tensor, |
| 33 | + out: Optional[torch.Tensor] = None, |
| 34 | + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: |
40 | 35 | return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size) |
41 | 36 |
|
42 | 37 |
|
@@ -137,7 +132,6 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor, |
137 | 132 | raise ValueError |
138 | 133 |
|
139 | 134 |
|
140 | | -@torch.jit.script |
141 | 135 | def gather_coo(src: torch.Tensor, index: torch.Tensor, |
142 | 136 | out: Optional[torch.Tensor] = None) -> torch.Tensor: |
143 | 137 | return torch.ops.torch_scatter.gather_coo(src, index, out) |
0 commit comments