11from setuptools import setup , find_packages
2+ import os
3+ import glob
24
35try :
46 import torch
1113except :
1214 raise ModuleNotFoundError ("Please install pytorch >= 1.1 before proceeding." )
1315
14- import glob
15-
16- from os import path
17-
18- this_directory = path .abspath (path .dirname (__file__ ))
19- with open (path .join (this_directory , "README.md" ), encoding = "utf-8" ) as f :
20- long_description = f .read ()
21-
16+ WITH_CUDA = torch .cuda .is_available () and CUDA_HOME is not None
17+ WITH_CPU = True
18+ if os .getenv ('FORCE_CUDA' , '0' ) == '1' :
19+ WITH_CUDA = True
20+ if os .getenv ('FORCE_ONLY_CUDA' , '0' ) == '1' :
21+ WITH_CUDA = True
22+ WITH_CPU = False
23+ if os .getenv ('FORCE_ONLY_CPU' , '0' ) == '1' :
24+ WITH_CUDA = False
25+ WITH_CPU = True
2226
2327def get_ext_modules ():
2428 TORCH_MAJOR = int (torch .__version__ .split ("." )[0 ])
@@ -31,7 +35,7 @@ def get_ext_modules():
3135 ext_sources = glob .glob ("{}/src/*.cpp" .format (ext_src_root )) + glob .glob ("{}/src/*.cu" .format (ext_src_root ))
3236
3337 ext_modules = []
34- if CUDA_HOME :
38+ if WITH_CUDA :
3539 ext_modules .append (
3640 CUDAExtension (
3741 name = "torch_points_kernels.points_cuda" ,
@@ -47,16 +51,17 @@ def get_ext_modules():
4751 cpu_ext_src_root = "cpu"
4852 cpu_ext_sources = glob .glob ("{}/src/*.cpp" .format (cpu_ext_src_root ))
4953
50- ext_modules .append (
51- CppExtension (
52- name = "torch_points_kernels.points_cpu" ,
53- sources = cpu_ext_sources ,
54- include_dirs = ["{}/include" .format (cpu_ext_src_root )],
55- extra_compile_args = {
56- "cxx" : extra_compile_args ,
57- },
54+ if WITH_CPU :
55+ ext_modules .append (
56+ CppExtension (
57+ name = "torch_points_kernels.points_cpu" ,
58+ sources = cpu_ext_sources ,
59+ include_dirs = ["{}/include" .format (cpu_ext_src_root )],
60+ extra_compile_args = {
61+ "cxx" : extra_compile_args ,
62+ },
63+ )
5864 )
59- )
6065 return ext_modules
6166
6267
@@ -68,8 +73,11 @@ def __init__(self, *args, **kwargs):
6873def get_cmdclass ():
6974 return {"build_ext" : CustomBuildExtension }
7075
76+ this_directory = os .path .abspath (os .path .dirname (__file__ ))
77+ with open (os .path .join (this_directory , "README.md" ), encoding = "utf-8" ) as f :
78+ long_description = f .read ()
7179
72- requirements = ["torch>=1.1.0" , "numba" , "scikit-learn" ]
80+ requirements = ["torch>=1.1.0" , "numba" , "numpy<1.20" , " scikit-learn" ]
7381
7482url = "https://github.com/nicolas-chaulet/torch-points-kernels"
7583__version__ = "0.6.10"
0 commit comments