|
2 | 2 |
|
3 | 3 | try: |
4 | 4 | import torch |
| 5 | + from torch.utils.cpp_extension import ( |
| 6 | + BuildExtension, |
| 7 | + CUDAExtension, |
| 8 | + CUDA_HOME, |
| 9 | + CppExtension, |
| 10 | + ) |
| 11 | + HAS_TORCH=True |
5 | 12 | except: |
6 | | - raise ImportError("Please install pytorch before installing torch-points-kernels") |
7 | | - |
8 | | -from torch.utils.cpp_extension import ( |
9 | | - BuildExtension, |
10 | | - CUDAExtension, |
11 | | - CUDA_HOME, |
12 | | - CppExtension, |
13 | | -) |
| 13 | + HAS_TORCH=False |
| 14 | + |
14 | 15 | import glob |
15 | 16 |
|
16 | 17 | from os import path |
17 | 18 | this_directory = path.abspath(path.dirname(__file__)) |
18 | 19 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: |
19 | 20 | long_description = f.read() |
20 | 21 |
|
21 | | -TORCH_MAJOR = int(torch.__version__.split(".")[0]) |
22 | | -TORCH_MINOR = int(torch.__version__.split(".")[1]) |
23 | | -extra_compile_args = ["-O3"] |
24 | | -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): |
25 | | - extra_compile_args += ["-DVERSION_GE_1_3"] |
26 | 22 |
|
27 | | -ext_src_root = "cuda" |
28 | | -ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root)) |
| 23 | +def get_ext_modules(): |
| 24 | + if not HAS_TORCH: |
| 25 | + return [] |
| 26 | + TORCH_MAJOR = int(torch.__version__.split(".")[0]) |
| 27 | + TORCH_MINOR = int(torch.__version__.split(".")[1]) |
| 28 | + extra_compile_args = ["-O3"] |
| 29 | + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): |
| 30 | + extra_compile_args += ["-DVERSION_GE_1_3"] |
29 | 31 |
|
30 | | -ext_modules = [] |
31 | | -if CUDA_HOME: |
32 | | - ext_modules.append( |
33 | | - CUDAExtension( |
34 | | - name="torch_points_kernels.points_cuda", |
35 | | - sources=ext_sources, |
36 | | - include_dirs=["{}/include".format(ext_src_root)], |
37 | | - extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,}, |
| 32 | + ext_src_root = "cuda" |
| 33 | + ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root)) |
| 34 | + |
| 35 | + ext_modules = [] |
| 36 | + if CUDA_HOME: |
| 37 | + ext_modules.append( |
| 38 | + CUDAExtension( |
| 39 | + name="torch_points_kernels.points_cuda", |
| 40 | + sources=ext_sources, |
| 41 | + include_dirs=["{}/include".format(ext_src_root)], |
| 42 | + extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,}, |
| 43 | + ) |
38 | 44 | ) |
39 | | - ) |
40 | 45 |
|
41 | | -cpu_ext_src_root = "cpu" |
42 | | -cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root)) |
| 46 | + cpu_ext_src_root = "cpu" |
| 47 | + cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root)) |
43 | 48 |
|
44 | | -ext_modules.append( |
45 | | - CppExtension( |
46 | | - name="torch_points_kernels.points_cpu", |
47 | | - sources=cpu_ext_sources, |
48 | | - include_dirs=["{}/include".format(cpu_ext_src_root)], |
49 | | - extra_compile_args={"cxx": extra_compile_args,}, |
| 49 | + ext_modules.append( |
| 50 | + CppExtension( |
| 51 | + name="torch_points_kernels.points_cpu", |
| 52 | + sources=cpu_ext_sources, |
| 53 | + include_dirs=["{}/include".format(cpu_ext_src_root)], |
| 54 | + extra_compile_args={"cxx": extra_compile_args,}, |
| 55 | + ) |
50 | 56 | ) |
51 | | -) |
| 57 | + return ext_modules |
| 58 | + |
| 59 | +def get_cmdclass(): |
| 60 | + if HAS_TORCH: |
| 61 | + return {"build_ext": BuildExtension} |
| 62 | + else: |
| 63 | + return {} |
52 | 64 |
|
53 | 65 | requirements = ["torch>=1.1.0"] |
54 | 66 |
|
|
62 | 74 | url=url, |
63 | 75 | download_url='{}/archive/{}.tar.gz'.format(url, __version__), |
64 | 76 | install_requires=requirements, |
65 | | - ext_modules=ext_modules, |
66 | | - cmdclass={"build_ext": BuildExtension}, |
| 77 | + ext_modules=get_ext_modules(), |
| 78 | + cmdclass=get_cmdclass(), |
67 | 79 | long_description=long_description, |
68 | 80 | long_description_content_type='text/markdown', |
69 | 81 | classifiers=[ |
|
0 commit comments