We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 38289bf commit 7b17897Copy full SHA for 7b17897
csrc/cuda/utils.cuh
@@ -6,6 +6,7 @@
6
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
7
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
8
9
+#ifndef USE_ROCM
10
__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask,
11
const at::Half var,
12
const unsigned int delta) {
@@ -17,6 +18,7 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
17
18
19
return __shfl_down_sync(mask, var.operator __half(), delta);
20
}
21
+#endif
22
23
__device__ __inline__ at::Half __shfl_up(const at::Half var,
24
0 commit comments