11import os
2- import os . path as osp
2+ import sys
33import glob
4+ import os .path as osp
5+ from itertools import product
46from setuptools import setup , find_packages
57
68import torch
9+ from torch .__config__ import parallel_info
710from torch .utils .cpp_extension import BuildExtension
811from torch .utils .cpp_extension import CppExtension , CUDAExtension , CUDA_HOME
912
1013WITH_CUDA = torch .cuda .is_available () and CUDA_HOME is not None
14+ suffices = ['cpu' , 'cuda' ] if WITH_CUDA else ['cpu' ]
1115if os .getenv ('FORCE_CUDA' , '0' ) == '1' :
12- WITH_CUDA = True
13- if os .getenv ('FORCE_CPU' , '0' ) == '1' :
14- WITH_CUDA = False
16+ suffices = ['cuda' , 'cpu' ]
17+ if os .getenv ('FORCE_ONLY_CUDA' , '0' ) == '1' :
18+ suffices = ['cuda' ]
19+ if os .getenv ('FORCE_ONLY_CPU' , '0' ) == '1' :
20+ suffices = ['cpu' ]
1521
1622BUILD_DOCS = os .getenv ('BUILD_DOCS' , '0' ) == '1'
1723
1824
1925def get_extensions ():
20- Extension = CppExtension
21- define_macros = []
22- extra_compile_args = {'cxx' : ['-O2' ]}
23- extra_link_args = ['-s' ]
24-
25- if WITH_CUDA :
26- Extension = CUDAExtension
27- define_macros += [('WITH_CUDA' , None )]
28- nvcc_flags = os .getenv ('NVCC_FLAGS' , '' )
29- nvcc_flags = [] if nvcc_flags == '' else nvcc_flags .split (' ' )
30- nvcc_flags += ['-arch=sm_35' , '--expt-relaxed-constexpr' , '-O2' ]
31- extra_compile_args ['nvcc' ] = nvcc_flags
26+ extensions = []
3227
3328 extensions_dir = osp .join (osp .dirname (osp .abspath (__file__ )), 'csrc' )
3429 main_files = glob .glob (osp .join (extensions_dir , '*.cpp' ))
35- extensions = []
36- for main in main_files :
37- name = main .split (os .sep )[- 1 ][:- 4 ]
3830
31+ for main , suffix in product (main_files , suffices ):
32+ define_macros = []
33+ extra_compile_args = {'cxx' : ['-O2' ]}
34+ extra_link_args = ['-s' ]
35+
36+ info = parallel_info ()
37+ if 'backend: OpenMP' in info and 'OpenMP not found' not in info :
38+ extra_compile_args ['cxx' ] += ['-DAT_PARALLEL_OPENMP' ]
39+ if sys .platform == 'win32' :
40+ extra_compile_args ['cxx' ] += ['/openmp' ]
41+ else :
42+ extra_compile_args ['cxx' ] += ['-fopenmp' ]
43+ else :
44+ print ('Compiling without OpenMP...' )
45+
46+ if suffix == 'cuda' :
47+ define_macros += [('WITH_CUDA' , None )]
48+ nvcc_flags = os .getenv ('NVCC_FLAGS' , '' )
49+ nvcc_flags = [] if nvcc_flags == '' else nvcc_flags .split (' ' )
50+ nvcc_flags += ['-arch=sm_35' , '--expt-relaxed-constexpr' ]
51+ extra_compile_args ['nvcc' ] = nvcc_flags
52+
53+ name = main .split (os .sep )[- 1 ][:- 4 ]
3954 sources = [main ]
4055
4156 path = osp .join (extensions_dir , 'cpu' , f'{ name } _cpu.cpp' )
4257 if osp .exists (path ):
4358 sources += [path ]
4459
4560 path = osp .join (extensions_dir , 'cuda' , f'{ name } _cuda.cu' )
46- if WITH_CUDA and osp .exists (path ):
61+ if suffix == 'cuda' and osp .exists (path ):
4762 sources += [path ]
4863
64+ Extension = CppExtension if suffix == 'cpu' else CUDAExtension
4965 extension = Extension (
50- 'torch_scatter._' + name ,
66+ f 'torch_scatter._{ name } _ { suffix } ' ,
5167 sources ,
5268 include_dirs = [extensions_dir ],
5369 define_macros = define_macros ,
@@ -65,7 +81,7 @@ def get_extensions():
6581
6682setup (
6783 name = 'torch_scatter' ,
68- version = '2.0.5 ' ,
84+ version = '2.0.6 ' ,
6985 author = 'Matthias Fey' ,
70867187 url = 'https://github.com/rusty1s/pytorch_scatter' ,
0 commit comments