@@ -43,6 +43,7 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
4343
4444void scatter_mul_cuda (at::Tensor src, at::Tensor index, at::Tensor out,
4545 int64_t dim) {
46+ cudaSetDevice (src.get_device ());
4647 AT_DISPATCH_ALL_TYPES (src.type (), " scatter_mul_kernel" , [&] {
4748 KERNEL_RUN (scatter_mul_kernel, index.dim (), index.numel (),
4849 at::cuda::detail::getTensorInfo<scalar_t , int64_t >(src),
@@ -69,6 +70,7 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
6970
7071void scatter_div_cuda (at::Tensor src, at::Tensor index, at::Tensor out,
7172 int64_t dim) {
73+ cudaSetDevice (src.get_device ());
7274 AT_DISPATCH_ALL_TYPES (src.type (), " scatter_div_kernel" , [&] {
7375 KERNEL_RUN (scatter_div_kernel, index.dim (), index.numel (),
7476 at::cuda::detail::getTensorInfo<scalar_t , int64_t >(src),
@@ -114,6 +116,7 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
114116
115117void scatter_max_cuda (at::Tensor src, at::Tensor index, at::Tensor out,
116118 at::Tensor arg, int64_t dim) {
119+ cudaSetDevice (src.get_device ());
117120 AT_DISPATCH_ALL_TYPES (src.type (), " scatter_max_kernel" , [&] {
118121 auto src_info = at::cuda::detail::getTensorInfo<scalar_t , int64_t >(src);
119122 auto index_info = at::cuda::detail::getTensorInfo<int64_t , int64_t >(index);
@@ -144,6 +147,7 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
144147
145148void scatter_min_cuda (at::Tensor src, at::Tensor index, at::Tensor out,
146149 at::Tensor arg, int64_t dim) {
150+ cudaSetDevice (src.get_device ());
147151 AT_DISPATCH_ALL_TYPES (src.type (), " scatter_min_kernel" , [&] {
148152 auto src_info = at::cuda::detail::getTensorInfo<scalar_t , int64_t >(src);
149153 auto index_info = at::cuda::detail::getTensorInfo<int64_t , int64_t >(index);
@@ -179,6 +183,7 @@ index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
179183
180184void index_backward_cuda (at::Tensor grad, at::Tensor index, at::Tensor arg,
181185 at::Tensor out, int64_t dim) {
186+ cudaSetDevice (grad.get_device ());
182187 AT_DISPATCH_ALL_TYPES (grad.type (), " index_backward_kernel" , [&] {
183188 KERNEL_RUN (index_backward_kernel, index.dim (), index.numel (),
184189 at::cuda::detail::getTensorInfo<scalar_t , int64_t >(grad),
0 commit comments