11import os
2- import os . path as osp
2+ import sys
33import glob
4+ import os .path as osp
5+ from itertools import product
46from setuptools import setup , find_packages
57
68import torch
9+ from torch .__config__ import parallel_info
710from torch .utils .cpp_extension import BuildExtension
811from torch .utils .cpp_extension import CppExtension , CUDAExtension , CUDA_HOME
912
10- WITH_CUDA = CUDA_HOME is not None
13+ WITH_CUDA = torch .cuda .is_available () and CUDA_HOME is not None
14+ suffices = ['cpu' , 'cuda' ] if WITH_CUDA else ['cpu' ]
1115if os .getenv ('FORCE_CUDA' , '0' ) == '1' :
12- WITH_CUDA = True
13- if os .getenv ('FORCE_CPU' , '0' ) == '1' :
14- WITH_CUDA = False
16+ suffices = ['cuda' , 'cpu' ]
17+ if os .getenv ('FORCE_ONLY_CUDA' , '0' ) == '1' :
18+ suffices = ['cuda' ]
19+ if os .getenv ('FORCE_ONLY_CPU' , '0' ) == '1' :
20+ suffices = ['cpu' ]
1521
1622BUILD_DOCS = os .getenv ('BUILD_DOCS' , '0' ) == '1'
1723
1824
1925def get_extensions ():
2026 extensions = []
21- for with_cuda , supername in [
22- (False , "cpu" ),
23- (True , "gpu" ),
24- ]:
25- if with_cuda and not WITH_CUDA :
26- continue
27- Extension = CppExtension
27+
28+ extensions_dir = osp .join (osp .dirname (osp .abspath (__file__ )), 'csrc' )
29+ main_files = glob .glob (osp .join (extensions_dir , '*.cpp' ))
30+
31+ for main , suffix in product (main_files , suffices ):
2832 define_macros = []
29- extra_compile_args = {'cxx' : []}
33+ extra_compile_args = {'cxx' : ['-O2' ]}
34+ extra_link_args = ['-s' ]
35+
36+ info = parallel_info ()
37+ if 'backend: OpenMP' in info and 'OpenMP not found' not in info :
38+ extra_compile_args ['cxx' ] += ['-DAT_PARALLEL_OPENMP' ]
39+ if sys .platform == 'win32' :
40+ extra_compile_args ['cxx' ] += ['/openmp' ]
41+ else :
42+ extra_compile_args ['cxx' ] += ['-fopenmp' ]
43+ else :
44+ print ('Compiling without OpenMP...' )
3045
31- if with_cuda :
32- Extension = CUDAExtension
46+ if suffix == 'cuda' :
3347 define_macros += [('WITH_CUDA' , None )]
3448 nvcc_flags = os .getenv ('NVCC_FLAGS' , '' )
3549 nvcc_flags = [] if nvcc_flags == '' else nvcc_flags .split (' ' )
3650 nvcc_flags += ['-arch=sm_35' , '--expt-relaxed-constexpr' ]
3751 extra_compile_args ['nvcc' ] = nvcc_flags
3852
39- extensions_dir = osp .join (osp .dirname (osp .abspath (__file__ )), 'csrc' )
40- main_files = glob .glob (osp .join (extensions_dir , '*.cpp' ))
41- for main in main_files :
42- name = main .split (os .sep )[- 1 ][:- 4 ]
43-
44- sources = [main ]
53+ name = main .split (os .sep )[- 1 ][:- 4 ]
54+ sources = [main ]
4555
46- path = osp .join (extensions_dir , 'cpu' , f'{ name } _cpu.cpp' )
47- if osp .exists (path ):
48- sources += [path ]
56+ path = osp .join (extensions_dir , 'cpu' , f'{ name } _cpu.cpp' )
57+ if osp .exists (path ):
58+ sources += [path ]
4959
50- path = osp .join (extensions_dir , 'cuda' , f'{ name } _cuda.cu' )
51- if with_cuda and osp .exists (path ):
52- sources += [path ]
60+ path = osp .join (extensions_dir , 'cuda' , f'{ name } _cuda.cu' )
61+ if suffix == 'cuda' and osp .exists (path ):
62+ sources += [path ]
5363
54- extension = Extension (
55- 'torch_scatter._%s_%s' % (name , supername ),
56- sources ,
57- include_dirs = [extensions_dir ],
58- define_macros = define_macros ,
59- extra_compile_args = extra_compile_args ,
60- )
61- extensions += [extension ]
64+ Extension = CppExtension if suffix == 'cpu' else CUDAExtension
65+ extension = Extension (
66+ f'torch_scatter._{ name } _{ suffix } ' ,
67+ sources ,
68+ include_dirs = [extensions_dir ],
69+ define_macros = define_macros ,
70+ extra_compile_args = extra_compile_args ,
71+ extra_link_args = extra_link_args ,
72+ )
73+ extensions += [extension ]
6274
6375 return extensions
6476
@@ -69,7 +81,7 @@ def get_extensions():
6981
7082setup (
7183 name = 'torch_scatter' ,
72- version = '2.0.5 ' ,
84+ version = '2.0.6 ' ,
7385 author = 'Matthias Fey' ,
74867587 url = 'https://github.com/rusty1s/pytorch_scatter' ,
0 commit comments