6868 \
6969 template <typename scalar, size_t size> struct Atomic ##NAME##DecimalImpl; \
7070 \
71- template <typename scalar > struct Atomic ##NAME##DecimalImpl<scalar , 2 > { \
72- inline __device__ void operator ()(scalar *address, scalar val) { \
71+ template <> struct Atomic ##NAME##DecimalImpl<at::Half , 2 > { \
72+ inline __device__ void operator ()(at::Half *address, at::Half val) { \
7373 unsigned int *address_as_ui = \
7474 (unsigned int *)((char *)address - ((size_t )address & 2 )); \
7575 unsigned int old = *address_as_ui; \
8787 } \
8888 }; \
8989 \
90+ template <> struct Atomic ##NAME##DecimalImpl<at::BFloat16, 2 > { \
91+ inline __device__ void operator ()(at::BFloat16 *address, at::BFloat16 val){\
92+ unsigned int *address_as_ui = \
93+ (unsigned int *)((char *)address - ((size_t )address & 2 )); \
94+ unsigned int old = *address_as_ui; \
95+ unsigned int assumed; \
96+ \
97+ do { \
98+ assumed = old; \
99+ at::BFloat16 hsum; \
100+ hsum.x = (size_t )address & 2 ? (old >> 16 ) : (old & 0xffff ); \
101+ hsum = OP (hsum, val); \
102+ old = (size_t )address & 2 ? (old & 0xffff ) | (hsum.x << 16 ) \
103+ : (old & 0xffff0000 ) | hsum.x ; \
104+ old = atomicCAS (address_as_ui, assumed, old); \
105+ } while (assumed != old); \
106+ } \
107+ }; \
108+ \
90109 template <typename scalar> struct Atomic ##NAME##DecimalImpl<scalar, 4 > { \
91110 inline __device__ void operator ()(scalar *address, scalar val) { \
92111 int *address_as_i = (int *)address; \
@@ -135,7 +154,7 @@ static inline __device__ void atomAdd(int32_t *address, int32_t val) {
135154static inline __device__ void atomAdd (int64_t *address, int64_t val) {
136155 AtomicAddIntegerImpl<int64_t , sizeof (int64_t )>()(address, val);
137156}
138- #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000)
157+ #if defined(USE_ROCM) || (defined( __CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000) )
139158static inline __device__ void atomAdd (at::Half *address, at::Half val) {
140159 AtomicAddDecimalImpl<at::Half, sizeof (at::Half)>()(address, val);
141160}
@@ -156,6 +175,9 @@ static inline __device__ void atomAdd(double *address, double val) {
156175 atomicAdd (address, val);
157176}
158177#endif
178+ static inline __device__ void atomAdd (at::BFloat16 *address, at::BFloat16 val) {
179+ AtomicAddDecimalImpl<at::BFloat16, sizeof (at::BFloat16)>()(address, val);
180+ }
159181
160182#define OP (X, Y ) Y *X
161183ATOMIC (Mul)
@@ -184,6 +206,9 @@ static inline __device__ void atomMul(at::Half *address, at::Half val) {
184206static inline __device__ void atomMul (double *address, double val) {
185207 AtomicMulDecimalImpl<double , sizeof (double )>()(address, val);
186208}
209+ static inline __device__ void atomMul (at::BFloat16 *address, at::BFloat16 val) {
210+ AtomicMulDecimalImpl<at::BFloat16, sizeof (at::BFloat16)>()(address, val);
211+ }
187212
188213#define OP (X, Y ) Y / X
189214ATOMIC (Div)
@@ -212,6 +237,9 @@ static inline __device__ void atomDiv(float *address, float val) {
212237static inline __device__ void atomDiv (double *address, double val) {
213238 AtomicDivDecimalImpl<double , sizeof (double )>()(address, val);
214239}
240+ static inline __device__ void atomDiv (at::BFloat16 *address, at::BFloat16 val) {
241+ AtomicDivDecimalImpl<at::BFloat16, sizeof (at::BFloat16)>()(address, val);
242+ }
215243
216244#define OP (X, Y ) max(Y, X)
217245ATOMIC (Max)
@@ -240,6 +268,9 @@ static inline __device__ void atomMax(float *address, float val) {
240268static inline __device__ void atomMax (double *address, double val) {
241269 AtomicMaxDecimalImpl<double , sizeof (double )>()(address, val);
242270}
271+ static inline __device__ void atomMax (at::BFloat16 *address, at::BFloat16 val) {
272+ AtomicMaxDecimalImpl<at::BFloat16, sizeof (at::BFloat16)>()(address, val);
273+ }
243274
244275#define OP (X, Y ) min(Y, X)
245276ATOMIC (Min)
@@ -268,3 +299,6 @@ static inline __device__ void atomMin(float *address, float val) {
268299static inline __device__ void atomMin (double *address, double val) {
269300 AtomicMinDecimalImpl<double , sizeof (double )>()(address, val);
270301}
302+ static inline __device__ void atomMin (at::BFloat16 *address, at::BFloat16 val) {
303+ AtomicMinDecimalImpl<at::BFloat16, sizeof (at::BFloat16)>()(address, val);
304+ }
0 commit comments