Skip to content

Commit 5c93eb3

Browse files
Merge pull request #63 from CCInc/master
Update build flags to support other enviornments
2 parents 88dfbd5 + 78b9f66 commit 5c93eb3

File tree

3 files changed

+29
-21
lines changed

3 files changed

+29
-21
lines changed

.github/workflows/deploy.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
- name: Install dependencies
1818
run: |
1919
python -m pip install --upgrade pip
20-
pip install torch numpy scikit-learn flake8 setuptools wheel twine numba
20+
pip install torch "numpy<1.20" scikit-learn flake8 setuptools wheel twine numba
2121
- name: Build package
2222
run: |
2323
python setup.py build_ext --inplace

.github/workflows/tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
- name: Install dependencies
2323
run: |
2424
python -m pip install --upgrade pip
25-
pip install numpy scikit-learn flake8 setuptools numba==0.49.1
25+
pip install "numpy<1.20" scikit-learn flake8 setuptools numba==0.49.1
2626
2727
- name: Install torch windows + linux
2828
if: ${{matrix.os != 'macos-latest'}}

setup.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from setuptools import setup, find_packages
2+
import os
3+
import glob
24

35
try:
46
import torch
@@ -11,14 +13,16 @@
1113
except:
1214
raise ModuleNotFoundError("Please install pytorch >= 1.1 before proceeding.")
1315

14-
import glob
15-
16-
from os import path
17-
18-
this_directory = path.abspath(path.dirname(__file__))
19-
with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
20-
long_description = f.read()
21-
16+
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
17+
WITH_CPU = True
18+
if os.getenv('FORCE_CUDA', '0') == '1':
19+
WITH_CUDA = True
20+
if os.getenv('FORCE_ONLY_CUDA', '0') == '1':
21+
WITH_CUDA = True
22+
WITH_CPU = False
23+
if os.getenv('FORCE_ONLY_CPU', '0') == '1':
24+
WITH_CUDA = False
25+
WITH_CPU = True
2226

2327
def get_ext_modules():
2428
TORCH_MAJOR = int(torch.__version__.split(".")[0])
@@ -31,7 +35,7 @@ def get_ext_modules():
3135
ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root))
3236

3337
ext_modules = []
34-
if CUDA_HOME:
38+
if WITH_CUDA:
3539
ext_modules.append(
3640
CUDAExtension(
3741
name="torch_points_kernels.points_cuda",
@@ -47,16 +51,17 @@ def get_ext_modules():
4751
cpu_ext_src_root = "cpu"
4852
cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root))
4953

50-
ext_modules.append(
51-
CppExtension(
52-
name="torch_points_kernels.points_cpu",
53-
sources=cpu_ext_sources,
54-
include_dirs=["{}/include".format(cpu_ext_src_root)],
55-
extra_compile_args={
56-
"cxx": extra_compile_args,
57-
},
54+
if WITH_CPU:
55+
ext_modules.append(
56+
CppExtension(
57+
name="torch_points_kernels.points_cpu",
58+
sources=cpu_ext_sources,
59+
include_dirs=["{}/include".format(cpu_ext_src_root)],
60+
extra_compile_args={
61+
"cxx": extra_compile_args,
62+
},
63+
)
5864
)
59-
)
6065
return ext_modules
6166

6267

@@ -68,8 +73,11 @@ def __init__(self, *args, **kwargs):
6873
def get_cmdclass():
6974
return {"build_ext": CustomBuildExtension}
7075

76+
this_directory = os.path.abspath(os.path.dirname(__file__))
77+
with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
78+
long_description = f.read()
7179

72-
requirements = ["torch>=1.1.0", "numba", "scikit-learn"]
80+
requirements = ["torch>=1.1.0", "numba", "numpy<1.20", "scikit-learn"]
7381

7482
url = "https://github.com/nicolas-chaulet/torch-points-kernels"
7583
__version__ = "0.6.10"

0 commit comments

Comments
 (0)