Skip to content

Commit 20a7cd3

Browse files
committed
multi gpu update
1 parent b1072a5 commit 20a7cd3

File tree

4 files changed

+19
-2
lines changed

4 files changed

+19
-2
lines changed

cuda/scatter_kernel.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
4343

4444
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
4545
int64_t dim) {
46+
cudaSetDevice(src.get_device());
4647
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul_kernel", [&] {
4748
KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
4849
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
@@ -69,6 +70,7 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
6970

7071
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
7172
int64_t dim) {
73+
cudaSetDevice(src.get_device());
7274
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div_kernel", [&] {
7375
KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
7476
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
@@ -114,6 +116,7 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
114116

115117
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
116118
at::Tensor arg, int64_t dim) {
119+
cudaSetDevice(src.get_device());
117120
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max_kernel", [&] {
118121
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
119122
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
@@ -144,6 +147,7 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
144147

145148
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
146149
at::Tensor arg, int64_t dim) {
150+
cudaSetDevice(src.get_device());
147151
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min_kernel", [&] {
148152
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
149153
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
@@ -179,6 +183,7 @@ index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
179183

180184
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
181185
at::Tensor out, int64_t dim) {
186+
cudaSetDevice(grad.get_device());
182187
AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward_kernel", [&] {
183188
KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(),
184189
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad),

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
2121
]
2222

23-
__version__ = '1.1.1'
23+
__version__ = '1.1.2'
2424
url = 'https://github.com/rusty1s/pytorch_scatter'
2525

2626
install_requires = []

test/test_multi_gpu.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
import pytest
import torch
from torch_scatter import scatter_max


# Regression test for the multi-GPU fix: the CUDA kernels now call
# cudaSetDevice(src.get_device()) before launching, so ops on a tensor
# that lives on a non-default device (cuda:1) must work correctly.
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs')
def test_multi_gpu():
    """scatter_max runs on a secondary GPU and returns correct maxima."""
    device = torch.device('cuda:1')  # deliberately NOT the default device
    src = torch.tensor([2.0, 3.0, 4.0, 5.0], device=device)
    index = torch.tensor([0, 0, 1, 1], device=device)
    # scatter_max returns (values, argmax); segment 0 -> max(2, 3) = 3,
    # segment 1 -> max(4, 5) = 5.
    assert scatter_max(src, index)[0].tolist() == [3, 5]

torch_scatter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .max import scatter_max
88
from .min import scatter_min
99

10-
__version__ = '1.1.1'
10+
__version__ = '1.1.2'
1111

1212
__all__ = [
1313
'scatter_add',

0 commit comments

Comments
 (0)