Skip to content

Commit cea1dbe

Browse files
committed
Updated build flags
1 parent 88dfbd5 commit cea1dbe

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

setup.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from setuptools import setup, find_packages
2+
import os
3+
import glob
24

35
try:
46
import torch
@@ -11,14 +13,16 @@
1113
except:
1214
raise ModuleNotFoundError("Please install pytorch >= 1.1 before proceeding.")
1315

14-
import glob
15-
16-
from os import path
17-
18-
this_directory = path.abspath(path.dirname(__file__))
19-
with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
20-
long_description = f.read()
21-
16+
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
17+
WITH_CPU = True
18+
if os.getenv('FORCE_CUDA', '0') == '1':
19+
WITH_CUDA = True
20+
if os.getenv('FORCE_ONLY_CUDA', '0') == '1':
21+
WITH_CUDA = True
22+
WITH_CPU = False
23+
if os.getenv('FORCE_ONLY_CPU', '0') == '1':
24+
WITH_CPU = True
25+
WITH_CUDA = False
2226

2327
def get_ext_modules():
2428
TORCH_MAJOR = int(torch.__version__.split(".")[0])
@@ -31,7 +35,7 @@ def get_ext_modules():
3135
ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root))
3236

3337
ext_modules = []
34-
if CUDA_HOME:
38+
if WITH_CUDA:
3539
ext_modules.append(
3640
CUDAExtension(
3741
name="torch_points_kernels.points_cuda",
@@ -47,16 +51,17 @@ def get_ext_modules():
4751
cpu_ext_src_root = "cpu"
4852
cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root))
4953

50-
ext_modules.append(
51-
CppExtension(
52-
name="torch_points_kernels.points_cpu",
53-
sources=cpu_ext_sources,
54-
include_dirs=["{}/include".format(cpu_ext_src_root)],
55-
extra_compile_args={
56-
"cxx": extra_compile_args,
57-
},
54+
if WITH_CPU:
55+
ext_modules.append(
56+
CppExtension(
57+
name="torch_points_kernels.points_cpu",
58+
sources=cpu_ext_sources,
59+
include_dirs=["{}/include".format(cpu_ext_src_root)],
60+
extra_compile_args={
61+
"cxx": extra_compile_args,
62+
},
63+
)
5864
)
59-
)
6065
return ext_modules
6166

6267

@@ -68,6 +73,9 @@ def __init__(self, *args, **kwargs):
6873
def get_cmdclass():
6974
return {"build_ext": CustomBuildExtension}
7075

76+
this_directory = os.path.abspath(os.path.dirname(__file__))
77+
with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
78+
long_description = f.read()
7179

7280
requirements = ["torch>=1.1.0", "numba", "scikit-learn"]
7381

0 commit comments

Comments
 (0)