Skip to content

Commit 8594006

Browse files
committed
build docs
1 parent 920136e commit 8594006

File tree

4 files changed

+76
-7
lines changed

4 files changed

+76
-7
lines changed

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
if os.getenv('FORCE_NON_CUDA', '0') == '1':
1515
WITH_CUDA = False
1616

17+
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
18+
1719

1820
def get_extensions():
1921
Extension = CppExtension
@@ -74,7 +76,7 @@ def get_extensions():
7476
install_requires=install_requires,
7577
setup_requires=setup_requires,
7678
tests_require=tests_require,
77-
ext_modules=get_extensions(),
79+
ext_modules=get_extensions() if not BUILD_DOCS else [],
7880
cmdclass={
7981
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
8082
},

torch_scatter/scatter.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,29 @@
1+
import warnings
12
import os.path as osp
23
from typing import Optional, Tuple
34

45
import torch
56

6-
torch.ops.load_library(
7-
osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so'))
7+
try:
8+
torch.ops.load_library(
9+
osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so'))
10+
except OSError:
11+
warnings.warn('Failed to load `scatter` binaries.')
12+
13+
def placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
14+
out: Optional[torch.Tensor],
15+
dim_size: Optional[int]) -> torch.Tensor:
16+
raise ImportError
17+
18+
def arg_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
19+
out: Optional[torch.Tensor], dim_size: Optional[int]
20+
) -> Tuple[torch.Tensor, torch.Tensor]:
21+
raise ImportError
22+
23+
torch.ops.torch_scatter.scatter_sum = placeholder
24+
torch.ops.torch_scatter.scatter_mean = placeholder
25+
torch.ops.torch_scatter.scatter_min = arg_placeholder
26+
torch.ops.torch_scatter.scatter_max = arg_placeholder
827

928

1029
@torch.jit.script

torch_scatter/segment_coo.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,35 @@
1+
import warnings
12
import os.path as osp
23
from typing import Optional, Tuple
34

45
import torch
56

6-
torch.ops.load_library(
7-
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_coo.so'))
7+
try:
8+
torch.ops.load_library(
9+
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_coo.so'))
10+
except OSError:
11+
warnings.warn('Failed to load `segment_coo` binaries.')
12+
13+
def segment_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
14+
out: Optional[torch.Tensor],
15+
dim_size: Optional[int]) -> torch.Tensor:
16+
raise ImportError
17+
18+
def segment_coo_with_arg_placeholder(
19+
src: torch.Tensor, index: torch.Tensor,
20+
out: Optional[torch.Tensor],
21+
dim_size: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
22+
raise ImportError
23+
24+
def gather_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
25+
out: Optional[torch.Tensor]) -> torch.Tensor:
26+
raise ImportError
27+
28+
torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
29+
torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder
30+
torch.ops.torch_scatter.segment_min_coo = segment_coo_with_arg_placeholder
31+
torch.ops.torch_scatter.segment_max_coo = segment_coo_with_arg_placeholder
32+
torch.ops.torch_scatter.gather_coo = gather_coo_placeholder
833

934

1035
@torch.jit.script

torch_scatter/segment_csr.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,33 @@
1+
import warnings
12
import os.path as osp
23
from typing import Optional, Tuple
34

45
import torch
56

6-
torch.ops.load_library(
7-
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_csr.so'))
7+
try:
8+
torch.ops.load_library(
9+
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_csr.so'))
10+
except OSError:
11+
warnings.warn('Failed to load `segment_csr` binaries.')
12+
13+
def segment_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
14+
out: Optional[torch.Tensor]) -> torch.Tensor:
15+
raise ImportError
16+
17+
def segment_csr_with_arg_placeholder(
18+
src: torch.Tensor, indptr: torch.Tensor,
19+
out: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
20+
raise ImportError
21+
22+
def gather_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
23+
out: Optional[torch.Tensor]) -> torch.Tensor:
24+
raise ImportError
25+
26+
torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
27+
torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder
28+
torch.ops.torch_scatter.segment_min_csr = segment_csr_with_arg_placeholder
29+
torch.ops.torch_scatter.segment_max_csr = segment_csr_with_arg_placeholder
30+
torch.ops.torch_scatter.gather_csr = gather_csr_placeholder
831

932

1033
@torch.jit.script

0 commit comments

Comments
 (0)