Skip to content

Commit 19df643

Browse files
committed
build both cpu and gpu binaries so same package can run on both CPU and GPU machines
1 parent 981731f commit 19df643

File tree

2 files changed

+43
-31
lines changed

2 files changed

+43
-31
lines changed

setup.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch.utils.cpp_extension import BuildExtension
88
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
99

10-
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
10+
WITH_CUDA = CUDA_HOME is not None
1111
if os.getenv('FORCE_CUDA', '0') == '1':
1212
WITH_CUDA = True
1313
if os.getenv('FORCE_CPU', '0') == '1':
@@ -17,42 +17,48 @@
1717

1818

1919
def get_extensions():
20-
Extension = CppExtension
21-
define_macros = []
22-
extra_compile_args = {'cxx': []}
20+
extensions = []
21+
for with_cuda, supername in [
22+
(False, "cpu"),
23+
(True, "gpu"),
24+
]:
25+
if with_cuda and not WITH_CUDA:
26+
continue
27+
Extension = CppExtension
28+
define_macros = []
29+
extra_compile_args = {'cxx': []}
2330

24-
if WITH_CUDA:
25-
Extension = CUDAExtension
26-
define_macros += [('WITH_CUDA', None)]
27-
nvcc_flags = os.getenv('NVCC_FLAGS', '')
28-
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
29-
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
30-
extra_compile_args['nvcc'] = nvcc_flags
31+
if with_cuda:
32+
Extension = CUDAExtension
33+
define_macros += [('WITH_CUDA', None)]
34+
nvcc_flags = os.getenv('NVCC_FLAGS', '')
35+
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
36+
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
37+
extra_compile_args['nvcc'] = nvcc_flags
3138

32-
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
33-
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
34-
extensions = []
35-
for main in main_files:
36-
name = main.split(os.sep)[-1][:-4]
39+
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
40+
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
41+
for main in main_files:
42+
name = main.split(os.sep)[-1][:-4]
3743

38-
sources = [main]
44+
sources = [main]
3945

40-
path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
41-
if osp.exists(path):
42-
sources += [path]
46+
path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
47+
if osp.exists(path):
48+
sources += [path]
4349

44-
path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
45-
if WITH_CUDA and osp.exists(path):
46-
sources += [path]
50+
path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
51+
if with_cuda and osp.exists(path):
52+
sources += [path]
4753

48-
extension = Extension(
49-
'torch_scatter._' + name,
50-
sources,
51-
include_dirs=[extensions_dir],
52-
define_macros=define_macros,
53-
extra_compile_args=extra_compile_args,
54-
)
55-
extensions += [extension]
54+
extension = Extension(
55+
'torch_scatter._%s_%s' % (name, supername),
56+
sources,
57+
include_dirs=[extensions_dir],
58+
define_macros=define_macros,
59+
extra_compile_args=extra_compile_args,
60+
)
61+
extensions += [extension]
5662

5763
return extensions
5864

torch_scatter/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66

77
__version__ = '2.0.5'
88

9+
if torch.cuda.is_available():
10+
sublib = "gpu"
11+
else:
12+
sublib = "cpu"
13+
914
try:
1015
for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
16+
library = "%s_%s" % (library, sublib)
1117
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
1218
library, [osp.dirname(__file__)]).origin)
1319
except AttributeError as e:

0 commit comments

Comments
 (0)