Skip to content

Commit eeddc5d

Browse files
Allow pytorch lazy loading
1 parent a669090 commit eeddc5d

File tree

1 file changed

+47
-35
lines changed

1 file changed

+47
-35
lines changed

setup.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,65 @@
22

33
try:
44
import torch
5+
from torch.utils.cpp_extension import (
6+
BuildExtension,
7+
CUDAExtension,
8+
CUDA_HOME,
9+
CppExtension,
10+
)
11+
HAS_TORCH=True
512
except:
6-
raise ImportError("Please install pytorch before installing torch-points-kernels")
7-
8-
from torch.utils.cpp_extension import (
9-
BuildExtension,
10-
CUDAExtension,
11-
CUDA_HOME,
12-
CppExtension,
13-
)
13+
HAS_TORCH=False
14+
1415
import glob
1516

1617
from os import path
1718
this_directory = path.abspath(path.dirname(__file__))
1819
with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
1920
long_description = f.read()
2021

21-
TORCH_MAJOR = int(torch.__version__.split(".")[0])
22-
TORCH_MINOR = int(torch.__version__.split(".")[1])
23-
extra_compile_args = ["-O3"]
24-
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
25-
extra_compile_args += ["-DVERSION_GE_1_3"]
2622

27-
ext_src_root = "cuda"
28-
ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root))
23+
def get_ext_modules():
24+
if not HAS_TORCH:
25+
return []
26+
TORCH_MAJOR = int(torch.__version__.split(".")[0])
27+
TORCH_MINOR = int(torch.__version__.split(".")[1])
28+
extra_compile_args = ["-O3"]
29+
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
30+
extra_compile_args += ["-DVERSION_GE_1_3"]
2931

30-
ext_modules = []
31-
if CUDA_HOME:
32-
ext_modules.append(
33-
CUDAExtension(
34-
name="torch_points_kernels.points_cuda",
35-
sources=ext_sources,
36-
include_dirs=["{}/include".format(ext_src_root)],
37-
extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,},
32+
ext_src_root = "cuda"
33+
ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root))
34+
35+
ext_modules = []
36+
if CUDA_HOME:
37+
ext_modules.append(
38+
CUDAExtension(
39+
name="torch_points_kernels.points_cuda",
40+
sources=ext_sources,
41+
include_dirs=["{}/include".format(ext_src_root)],
42+
extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,},
43+
)
3844
)
39-
)
4045

41-
cpu_ext_src_root = "cpu"
42-
cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root))
46+
cpu_ext_src_root = "cpu"
47+
cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root))
4348

44-
ext_modules.append(
45-
CppExtension(
46-
name="torch_points_kernels.points_cpu",
47-
sources=cpu_ext_sources,
48-
include_dirs=["{}/include".format(cpu_ext_src_root)],
49-
extra_compile_args={"cxx": extra_compile_args,},
49+
ext_modules.append(
50+
CppExtension(
51+
name="torch_points_kernels.points_cpu",
52+
sources=cpu_ext_sources,
53+
include_dirs=["{}/include".format(cpu_ext_src_root)],
54+
extra_compile_args={"cxx": extra_compile_args,},
55+
)
5056
)
51-
)
57+
return ext_modules
58+
59+
def get_cmdclass():
60+
if HAS_TORCH:
61+
return {"build_ext": BuildExtension}
62+
else:
63+
return {}
5264

5365
requirements = ["torch>=1.1.0"]
5466

@@ -62,8 +74,8 @@
6274
url=url,
6375
download_url='{}/archive/{}.tar.gz'.format(url, __version__),
6476
install_requires=requirements,
65-
ext_modules=ext_modules,
66-
cmdclass={"build_ext": BuildExtension},
77+
ext_modules=get_ext_modules(),
78+
cmdclass=get_cmdclass(),
6779
long_description=long_description,
6880
long_description_content_type='text/markdown',
6981
classifiers=[

0 commit comments

Comments
 (0)