11#include < ATen/ATen.h>
2+ #include < ATen/cuda/CUDAContext.h>
23#include < ATen/cuda/detail/IndexUtils.cuh>
34#include < ATen/cuda/detail/TensorInfo.cuh>
45
89#define THREADS 1024
910#define BLOCKS (N ) (N + THREADS - 1 ) / THREADS
1011
12+ auto stream = at::cuda::getCurrentCUDAStream();
13+
1114#define KERNEL_RUN (NAME, DIMS, N, ...) \
1215 [&] { \
16+ auto stream = at::cuda::getCurrentCUDAStream (); \
1317 switch (DIMS) { \
1418 case 1 : \
15- NAME<scalar_t , 1 ><<<BLOCKS(N), THREADS>>> (__VA_ARGS__, N); \
19+ NAME<scalar_t , 1 ><<<BLOCKS(N), THREADS, 0 , stream >>> (__VA_ARGS__, N); \
1620 break ; \
1721 case 2 : \
18- NAME<scalar_t , 2 ><<<BLOCKS(N), THREADS>>> (__VA_ARGS__, N); \
22+ NAME<scalar_t , 2 ><<<BLOCKS(N), THREADS, 0 , stream >>> (__VA_ARGS__, N); \
1923 break ; \
2024 case 3 : \
21- NAME<scalar_t , 3 ><<<BLOCKS(N), THREADS>>> (__VA_ARGS__, N); \
25+ NAME<scalar_t , 3 ><<<BLOCKS(N), THREADS, 0 , stream >>> (__VA_ARGS__, N); \
2226 break ; \
2327 default : \
24- NAME<scalar_t , -1 ><<<BLOCKS(N), THREADS>>> (__VA_ARGS__, N); \
28+ NAME<scalar_t , -1 ><<<BLOCKS(N), THREADS, 0 , stream >>> (__VA_ARGS__, N); \
2529 } \
2630 }()
2731
@@ -43,7 +47,6 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
4347
4448void scatter_mul_cuda (at::Tensor src, at::Tensor index, at::Tensor out,
4549 int64_t dim) {
46- cudaSetDevice (src.get_device ());
4750 AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_mul_kernel" , [&] {
4851 KERNEL_RUN (scatter_mul_kernel, index.dim (), index.numel (),
4952 at::cuda::detail::getTensorInfo<scalar_t , int64_t >(src),
@@ -70,7 +73,6 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
7073
7174void scatter_div_cuda (at::Tensor src, at::Tensor index, at::Tensor out,
7275 int64_t dim) {
73- cudaSetDevice (src.get_device ());
7476 AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_div_kernel" , [&] {
7577 KERNEL_RUN (scatter_div_kernel, index.dim (), index.numel (),
7678 at::cuda::detail::getTensorInfo<scalar_t , int64_t >(src),
@@ -116,7 +118,6 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
116118
117119void scatter_max_cuda (at::Tensor src, at::Tensor index, at::Tensor out,
118120 at::Tensor arg, int64_t dim) {
119- cudaSetDevice (src.get_device ());
120121 AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_max_kernel" , [&] {
121122 auto src_info = at::cuda::detail::getTensorInfo<scalar_t , int64_t >(src);
122123 auto index_info = at::cuda::detail::getTensorInfo<int64_t , int64_t >(index);
@@ -147,7 +148,6 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
147148
148149void scatter_min_cuda (at::Tensor src, at::Tensor index, at::Tensor out,
149150 at::Tensor arg, int64_t dim) {
150- cudaSetDevice (src.get_device ());
151151 AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_min_kernel" , [&] {
152152 auto src_info = at::cuda::detail::getTensorInfo<scalar_t , int64_t >(src);
153153 auto index_info = at::cuda::detail::getTensorInfo<int64_t , int64_t >(index);
0 commit comments