Skip to content

Commit 9a651d9

Browse files
authored
enable ROCm build; add BF16 for ROCm and CUDA (#325)
* First step: everything compiles. * Fix rebuilds; skip the CUDA version check for ROCm. * Use a macro for __shfl_up_sync / __shfl_down_sync. * Add BFloat16 support for ROCm and CUDA. * Add the USE_ROCM definition to setup.py. * flake8 fixes.
1 parent 18d3759 commit 9a651d9

File tree

9 files changed

+93
-17
lines changed

9 files changed

+93
-17
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ dist/
99
*.aux
1010
*.log
1111
*.pdf
12+
*.hip
13+
*_hip.cpp
14+
hip

csrc/cuda/atomics.cuh

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@
6868
\
6969
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
7070
\
71-
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 2> { \
72-
inline __device__ void operator()(scalar *address, scalar val) { \
71+
template <> struct Atomic##NAME##DecimalImpl<at::Half, 2> { \
72+
inline __device__ void operator()(at::Half *address, at::Half val) { \
7373
unsigned int *address_as_ui = \
7474
(unsigned int *)((char *)address - ((size_t)address & 2)); \
7575
unsigned int old = *address_as_ui; \
@@ -87,6 +87,25 @@
8787
} \
8888
}; \
8989
\
90+
template <> struct Atomic##NAME##DecimalImpl<at::BFloat16, 2> { \
91+
inline __device__ void operator()(at::BFloat16 *address, at::BFloat16 val){\
92+
unsigned int *address_as_ui = \
93+
(unsigned int *)((char *)address - ((size_t)address & 2)); \
94+
unsigned int old = *address_as_ui; \
95+
unsigned int assumed; \
96+
\
97+
do { \
98+
assumed = old; \
99+
at::BFloat16 hsum; \
100+
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); \
101+
hsum = OP(hsum, val); \
102+
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) \
103+
: (old & 0xffff0000) | hsum.x; \
104+
old = atomicCAS(address_as_ui, assumed, old); \
105+
} while (assumed != old); \
106+
} \
107+
}; \
108+
\
90109
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
91110
inline __device__ void operator()(scalar *address, scalar val) { \
92111
int *address_as_i = (int *)address; \
@@ -135,7 +154,7 @@ static inline __device__ void atomAdd(int32_t *address, int32_t val) {
135154
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
136155
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
137156
}
138-
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000)
157+
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000))
139158
static inline __device__ void atomAdd(at::Half *address, at::Half val) {
140159
AtomicAddDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
141160
}
@@ -156,6 +175,9 @@ static inline __device__ void atomAdd(double *address, double val) {
156175
atomicAdd(address, val);
157176
}
158177
#endif
178+
static inline __device__ void atomAdd(at::BFloat16 *address, at::BFloat16 val) {
179+
AtomicAddDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
180+
}
159181

160182
#define OP(X, Y) Y *X
161183
ATOMIC(Mul)
@@ -184,6 +206,9 @@ static inline __device__ void atomMul(at::Half *address, at::Half val) {
184206
static inline __device__ void atomMul(double *address, double val) {
185207
AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
186208
}
209+
static inline __device__ void atomMul(at::BFloat16 *address, at::BFloat16 val) {
210+
AtomicMulDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
211+
}
187212

188213
#define OP(X, Y) Y / X
189214
ATOMIC(Div)
@@ -212,6 +237,9 @@ static inline __device__ void atomDiv(float *address, float val) {
212237
static inline __device__ void atomDiv(double *address, double val) {
213238
AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
214239
}
240+
static inline __device__ void atomDiv(at::BFloat16 *address, at::BFloat16 val) {
241+
AtomicDivDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
242+
}
215243

216244
#define OP(X, Y) max(Y, X)
217245
ATOMIC(Max)
@@ -240,6 +268,9 @@ static inline __device__ void atomMax(float *address, float val) {
240268
static inline __device__ void atomMax(double *address, double val) {
241269
AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
242270
}
271+
static inline __device__ void atomMax(at::BFloat16 *address, at::BFloat16 val) {
272+
AtomicMaxDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
273+
}
243274

244275
#define OP(X, Y) min(Y, X)
245276
ATOMIC(Min)
@@ -268,3 +299,6 @@ static inline __device__ void atomMin(float *address, float val) {
268299
static inline __device__ void atomMin(double *address, double val) {
269300
AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
270301
}
302+
static inline __device__ void atomMin(at::BFloat16 *address, at::BFloat16 val) {
303+
AtomicMinDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
304+
}

csrc/cuda/scatter_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
111111

112112
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
113113
auto stream = at::cuda::getCurrentCUDAStream();
114-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
114+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
115115
auto src_data = src.data_ptr<scalar_t>();
116116
auto out_data = out.data_ptr<scalar_t>();
117117

csrc/cuda/segment_coo_cuda.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,16 @@ segment_coo_kernel(const scalar_t *src_data,
3636
#pragma unroll
3737
for (int i = 1; i < 32; i *= 2) {
3838
// Parallel reduction inside a single warp.
39-
tmp = __shfl_up_sync(FULL_MASK, val, i);
40-
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
39+
tmp = SHFL_UP_SYNC(FULL_MASK, val, i);
40+
next_idx = SHFL_UP_SYNC(FULL_MASK, idx, i);
4141
if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
4242
assert(idx >= next_idx);
4343
if (idx == next_idx)
4444
Reducer<scalar_t, REDUCE>::update(&val, tmp);
4545
}
4646
}
4747

48-
next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
48+
next_idx = SHFL_DOWN_SYNC(FULL_MASK, idx, 1);
4949
if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
5050
idx != next_idx)
5151
Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
@@ -214,7 +214,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
214214

215215
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
216216
auto stream = at::cuda::getCurrentCUDAStream();
217-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
217+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
218218
auto src_data = src.data_ptr<scalar_t>();
219219
auto out_data = out.data_ptr<scalar_t>();
220220

@@ -365,7 +365,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
365365
366366
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
367367
auto stream = at::cuda::getCurrentCUDAStream();
368-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
368+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
369369
auto src_data = src.data_ptr<scalar_t>();
370370
auto out_data = out.data_ptr<scalar_t>();
371371

csrc/cuda/segment_csr_cuda.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ segment_csr_kernel(const scalar_t *src_data,
4646
for (int i = TB / 2; i > 0; i /= 2) {
4747
// Parallel reduction inside a single warp.
4848
if (REDUCE == MIN || REDUCE == MAX)
49-
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
49+
arg_tmp = SHFL_DOWN_SYNC(FULL_MASK, arg, i);
5050
Reducer<scalar_t, REDUCE>::update(
51-
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
51+
&val, SHFL_DOWN_SYNC(FULL_MASK, val, i), &arg, arg_tmp);
5252
}
5353

5454
if (lane_idx == 0) {
@@ -147,7 +147,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
147147

148148
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
149149
auto stream = at::cuda::getCurrentCUDAStream();
150-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
150+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
151151
auto src_data = src.data_ptr<scalar_t>();
152152
auto out_data = out.data_ptr<scalar_t>();
153153

@@ -264,7 +264,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
264264

265265
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
266266
auto stream = at::cuda::getCurrentCUDAStream();
267-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
267+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
268268
auto src_data = src.data_ptr<scalar_t>();
269269
auto out_data = out.data_ptr<scalar_t>();
270270

csrc/cuda/utils.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,14 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
1717
const unsigned int delta) {
1818
return __shfl_down_sync(mask, var.operator __half(), delta);
1919
}
20+
21+
#ifdef USE_ROCM
22+
__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
23+
return __ldg(reinterpret_cast<const __half*>(ptr));
24+
}
25+
#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
26+
#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
27+
#else
28+
#define SHFL_UP_SYNC __shfl_up_sync
29+
#define SHFL_DOWN_SYNC __shfl_down_sync
30+
#endif

csrc/version.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77
#include "macros.h"
88

99
#ifdef WITH_CUDA
10+
#ifdef USE_ROCM
11+
#include <hip/hip_version.h>
12+
#else
1013
#include <cuda.h>
1114
#endif
15+
#endif
1216

1317
#ifdef _WIN32
1418
#ifdef WITH_PYTHON
@@ -23,7 +27,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
2327
namespace scatter {
2428
SCATTER_API int64_t cuda_version() noexcept {
2529
#ifdef WITH_CUDA
30+
#ifdef USE_ROCM
31+
return HIP_VERSION;
32+
#else
2633
return CUDA_VERSION;
34+
#endif
2735
#else
2836
return -1;
2937
#endif

setup.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
__version__ = '2.0.9'
1515
URL = 'https://github.com/rusty1s/pytorch_scatter'
1616

17-
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
17+
WITH_CUDA = False
18+
if torch.cuda.is_available():
19+
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
1820
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
1921
if os.getenv('FORCE_CUDA', '0') == '1':
2022
suffices = ['cuda', 'cpu']
@@ -32,9 +34,12 @@ def get_extensions():
3234

3335
extensions_dir = osp.join('csrc')
3436
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
37+
# remove generated 'hip' files, in case of rebuilds
38+
main_files = [path for path in main_files if 'hip' not in path]
3539

3640
for main, suffix in product(main_files, suffices):
3741
define_macros = [('WITH_PYTHON', None)]
42+
undef_macros = []
3843

3944
if sys.platform == 'win32':
4045
define_macros += [('torchscatter_EXPORTS', None)]
@@ -64,7 +69,14 @@ def get_extensions():
6469
define_macros += [('WITH_CUDA', None)]
6570
nvcc_flags = os.getenv('NVCC_FLAGS', '')
6671
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
67-
nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
72+
if torch.version.hip:
73+
nvcc_flags += ['-O3']
74+
# USE_ROCM was added to later versions of ROCm PyTorch
75+
# define it here to support older PyTorch versions
76+
define_macros += [('USE_ROCM', None)]
77+
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
78+
else:
79+
nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
6880
extra_compile_args['nvcc'] = nvcc_flags
6981

7082
name = main.split(os.sep)[-1][:-4]
@@ -84,6 +96,7 @@ def get_extensions():
8496
sources,
8597
include_dirs=[extensions_dir],
8698
define_macros=define_macros,
99+
undef_macros=undef_macros,
87100
extra_compile_args=extra_compile_args,
88101
extra_link_args=extra_link_args,
89102
)
@@ -99,6 +112,11 @@ def get_extensions():
99112
'pytest-cov',
100113
]
101114

115+
# work around hipify generating absolute paths
116+
include_package_data = True
117+
if torch.cuda.is_available() and torch.version.hip:
118+
include_package_data = False
119+
102120
setup(
103121
name='torch_scatter',
104122
version=__version__,
@@ -119,5 +137,5 @@ def get_extensions():
119137
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
120138
},
121139
packages=find_packages(),
122-
include_package_data=True,
140+
include_package_data=include_package_data,
123141
)

torch_scatter/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@
4747
torch.ops.torch_scatter.gather_coo = gather_coo_placeholder
4848

4949
cuda_version = torch.ops.torch_scatter.cuda_version()
50-
if torch.version.cuda is not None and cuda_version != -1: # pragma: no cover
50+
is_not_hip = torch.version.hip is None
51+
is_cuda = torch.version.cuda is not None
52+
if is_not_hip and is_cuda and cuda_version != -1: # pragma: no cover
5153
if cuda_version < 10000:
5254
major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
5355
else:

0 commit comments

Comments
 (0)