File tree Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change 66 AT_ASSERTM (x.device().is_cuda(), #x " must be CUDA tensor" )
77#define CHECK_INPUT (x ) AT_ASSERTM(x, " Input mismatch" )
88
9- __device__ __inline__ at::Half __shfl_up_sync (const unsigned mask,
9+ // On ROCm, __shfl_*_sync requires a 64-bit mask; on CUDA it's 32-bit.
10+ #ifdef USE_ROCM
11+ using warp_mask_t = unsigned long long ;
12+ #else
13+ using warp_mask_t = unsigned int ;
14+ #endif
15+
16+ __device__ __inline__ at::Half __shfl_up_sync (const warp_mask_t mask,
1017 const at::Half var,
1118 const unsigned int delta) {
1219 return __shfl_up_sync (mask, var.operator __half (), delta);
1320}
1421
15- __device__ __inline__ at::Half __shfl_down_sync (const unsigned mask,
22+ __device__ __inline__ at::Half __shfl_down_sync (const warp_mask_t mask,
1623 const at::Half var,
1724 const unsigned int delta) {
1825 return __shfl_down_sync (mask, var.operator __half (), delta);
1926}
2027
21- __device__ __inline__ at::Half __shfl_sync (const unsigned mask,
28+ __device__ __inline__ at::Half __shfl_sync (const warp_mask_t mask,
2229 const at::Half var,
2330 const int delta) {
2431 return __shfl_sync (mask, var.operator __half (), delta);
You can’t perform that action at this time.
0 commit comments