Skip to content

Commit 2340737

Browse files
authored
Add ROCm 6.4.3+ support
1 parent cdfbc7e commit 2340737

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

csrc/cuda/utils.cuh

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,26 @@
66
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
77
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
88

9-
__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask,
9+
// On ROCm (HIP), the __shfl_*_sync intrinsics require a 64-bit lane mask; on CUDA the mask is 32-bit.
10+
#ifdef USE_ROCM
11+
using warp_mask_t = unsigned long long;
12+
#else
13+
using warp_mask_t = unsigned int;
14+
#endif
15+
16+
__device__ __inline__ at::Half __shfl_up_sync(const warp_mask_t mask,
1017
const at::Half var,
1118
const unsigned int delta) {
1219
return __shfl_up_sync(mask, var.operator __half(), delta);
1320
}
1421

15-
__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
22+
__device__ __inline__ at::Half __shfl_down_sync(const warp_mask_t mask,
1623
const at::Half var,
1724
const unsigned int delta) {
1825
return __shfl_down_sync(mask, var.operator __half(), delta);
1926
}
2027

21-
__device__ __inline__ at::Half __shfl_sync(const unsigned mask,
28+
__device__ __inline__ at::Half __shfl_sync(const warp_mask_t mask,
2229
const at::Half var,
2330
const int delta) {
2431
return __shfl_sync(mask, var.operator __half(), delta);

0 commit comments

Comments
 (0)