Skip to content

Commit f5cb51a

Browse files
committed
stream to scatter kernels
1 parent 3cf59da commit f5cb51a

File tree

1 file changed

+8 -8 lines changed

1 file changed

+8 -8 lines changed

cuda/scatter_kernel.cu

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#include <ATen/ATen.h>
2+
#include <ATen/cuda/CUDAContext.h>
23
#include <ATen/cuda/detail/IndexUtils.cuh>
34
#include <ATen/cuda/detail/TensorInfo.cuh>
45

@@ -8,20 +9,23 @@
89
#define THREADS 1024
910
#define BLOCKS(N) (N + THREADS - 1) / THREADS
1011

12+
auto stream = at::cuda::getCurrentCUDAStream();
13+
1114
#define KERNEL_RUN(NAME, DIMS, N, ...) \
1215
[&] { \
16+
auto stream = at::cuda::getCurrentCUDAStream(); \
1317
switch (DIMS) { \
1418
case 1: \
15-
NAME<scalar_t, 1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
19+
NAME<scalar_t, 1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
1620
break; \
1721
case 2: \
18-
NAME<scalar_t, 2><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
22+
NAME<scalar_t, 2><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
1923
break; \
2024
case 3: \
21-
NAME<scalar_t, 3><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
25+
NAME<scalar_t, 3><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
2226
break; \
2327
default: \
24-
NAME<scalar_t, -1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
28+
NAME<scalar_t, -1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
2529
} \
2630
}()
2731

@@ -43,7 +47,6 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
4347

4448
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
4549
int64_t dim) {
46-
cudaSetDevice(src.get_device());
4750
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
4851
KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
4952
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
@@ -70,7 +73,6 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
7073

7174
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
7275
int64_t dim) {
73-
cudaSetDevice(src.get_device());
7476
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
7577
KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
7678
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
@@ -116,7 +118,6 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
116118

117119
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
118120
at::Tensor arg, int64_t dim) {
119-
cudaSetDevice(src.get_device());
120121
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
121122
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
122123
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
@@ -147,7 +148,6 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
147148

148149
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
149150
at::Tensor arg, int64_t dim) {
150-
cudaSetDevice(src.get_device());
151151
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
152152
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
153153
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);

0 commit comments

Comments (0)