Skip to content

Commit 4774586

Browse files
committed
fix doc build
1 parent 8ca6d86 commit 4774586

File tree

3 files changed

+95
-8
lines changed

3 files changed

+95
-8
lines changed

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[run]
22
source=torch_scatter
3+
omit=torch_scatter/placeholder.py
34
[report]
45
exclude_lines =
56
pragma: no cover

torch_scatter/__init__.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# flake8: noqa
2-
31
import os
42
import importlib
53
import os.path as osp
@@ -10,8 +8,9 @@
108
expected_torch_version = (1, 4)
119

1210
try:
13-
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
14-
'_version', [osp.dirname(__file__)]).origin)
11+
for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
12+
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
13+
library, [osp.dirname(__file__)]).origin)
1514
except OSError as e:
1615
if 'undefined symbol' in str(e):
1716
major, minor = [int(x) for x in torch.__version__.split('.')[:2]]
@@ -25,6 +24,31 @@
2524
if os.getenv('BUILD_DOCS', '0') != '1':
2625
raise AttributeError(e)
2726

27+
from .placeholder import cuda_version_placeholder
28+
torch.ops.torch_scatter.cuda_version = cuda_version_placeholder
29+
30+
from .placeholder import scatter_arg_placeholder
31+
torch.ops.torch_scatter.scatter_min = scatter_arg_placeholder
32+
torch.ops.torch_scatter.scatter_max = scatter_arg_placeholder
33+
34+
from .placeholder import segment_csr_placeholder
35+
from .placeholder import segment_csr_arg_placeholder
36+
from .placeholder import gather_csr_placeholder
37+
torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
38+
torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder
39+
torch.ops.torch_scatter.segment_min_csr = segment_csr_arg_placeholder
40+
torch.ops.torch_scatter.segment_max_csr = segment_csr_arg_placeholder
41+
torch.ops.torch_scatter.gather_csr = gather_csr_placeholder
42+
43+
from .placeholder import segment_coo_placeholder
44+
from .placeholder import segment_coo_arg_placeholder
45+
from .placeholder import gather_coo_placeholder
46+
torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
47+
torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder
48+
torch.ops.torch_scatter.segment_min_coo = segment_coo_arg_placeholder
49+
torch.ops.torch_scatter.segment_max_coo = segment_coo_arg_placeholder
50+
torch.ops.torch_scatter.gather_coo = gather_coo_placeholder
51+
2852
if torch.version.cuda is not None: # pragma: no cover
2953
cuda_version = torch.ops.torch_scatter.cuda_version()
3054

@@ -45,15 +69,15 @@
4569
f'matches your PyTorch install.')
4670

4771
from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
48-
scatter_max, scatter)
72+
scatter_max, scatter) # noqa: E402
4973
from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
5074
segment_min_csr, segment_max_csr, segment_csr,
51-
gather_csr)
75+
gather_csr) # noqa: E402
5276
from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
5377
segment_min_coo, segment_max_coo, segment_coo,
54-
gather_coo)
78+
gather_coo) # noqa: E402
5579
from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
56-
scatter_log_softmax)
80+
scatter_log_softmax) # noqa: E402
5781

5882
__all__ = [
5983
'scatter_sum',

torch_scatter/placeholder.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from typing import Optional, Tuple
2+
3+
import torch
4+
5+
6+
def cuda_version_placeholder() -> int:
7+
return -1
8+
9+
10+
def scatter_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
11+
out: Optional[torch.Tensor],
12+
dim_size: Optional[int]) -> torch.Tensor:
13+
raise ImportError
14+
return src
15+
16+
17+
def scatter_arg_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
18+
out: Optional[torch.Tensor],
19+
dim_size: Optional[int]
20+
) -> Tuple[torch.Tensor, torch.Tensor]:
21+
raise ImportError
22+
return src, index
23+
24+
25+
def segment_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
26+
out: Optional[torch.Tensor]) -> torch.Tensor:
27+
raise ImportError
28+
return src
29+
30+
31+
def segment_csr_arg_placeholder(src: torch.Tensor, indptr: torch.Tensor,
32+
out: Optional[torch.Tensor]
33+
) -> Tuple[torch.Tensor, torch.Tensor]:
34+
raise ImportError
35+
return src, indptr
36+
37+
38+
def gather_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
39+
out: Optional[torch.Tensor]) -> torch.Tensor:
40+
raise ImportError
41+
return src
42+
43+
44+
def segment_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
45+
out: Optional[torch.Tensor],
46+
dim_size: Optional[int]) -> torch.Tensor:
47+
raise ImportError
48+
return src
49+
50+
51+
def segment_coo_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
52+
out: Optional[torch.Tensor],
53+
dim_size: Optional[int]
54+
) -> Tuple[torch.Tensor, torch.Tensor]:
55+
raise ImportError
56+
return src, index
57+
58+
59+
def gather_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
60+
out: Optional[torch.Tensor]) -> torch.Tensor:
61+
raise ImportError
62+
return src

0 commit comments

Comments
 (0)