Skip to content

Commit d305ecc

Browse files
committed
nested extensions
1 parent 15afee0 commit d305ecc

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
ext_modules = [
66
CppExtension(
7-
'scatter_cpu', ['cpu/scatter.cpp'],
7+
'torch_scatter.scatter_cpu', ['cpu/scatter.cpp'],
88
extra_compile_args=['-Wno-unused-variable'])
99
]
1010
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
1111

1212
if CUDA_HOME is not None:
1313
ext_modules += [
14-
CUDAExtension('scatter_cuda',
14+
CUDAExtension('torch_scatter.scatter_cuda',
1515
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
1616
]
1717

torch_scatter/utils/ext.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import torch
2-
import scatter_cpu
2+
import torch_scatter.scatter_cpu
33

44
if torch.cuda.is_available():
5-
import scatter_cuda
5+
import torch_scatter.scatter_cuda
66

77

88
def get_func(name, tensor):
9-
module = scatter_cuda if tensor.is_cuda else scatter_cpu
9+
if tensor.is_cuda:
10+
module = torch_scatter.scatter_cuda
11+
else:
12+
module = torch_scatter.scatter_cpu
1013
return getattr(module, name)

0 commit comments

Comments
 (0)