6868 \
6969 template <typename scalar, size_t size> struct Atomic ##NAME##DecimalImpl; \
7070 \
71+ template <typename scalar> struct Atomic ##NAME##DecimalImpl<scalar, 2 > { \
72+ inline __device__ void operator ()(scalar *address, scalar val) { \
73+ unsigned int *address_as_ui = \
74+ (unsigned int *)((char *)address - ((size_t )address & 2 )); \
75+ unsigned int old = *address_as_ui; \
76+ unsigned int assumed; \
77+ \
78+ do { \
79+ assumed = old; \
80+ at::Half hsum; \
81+ hsum.x = (size_t )address & 2 ? (old >> 16 ) : (old & 0xffff ); \
82+ hsum = OP (hsum, val); \
83+ old = (size_t )address & 2 ? (old & 0xffff ) | (hsum.x << 16 ) \
84+ : (old & 0xffff0000 ) | hsum.x ; \
85+ old = atomicCAS (address_as_ui, assumed, old); \
86+ } while (assumed != old); \
87+ } \
88+ }; \
89+ \
7190 template <typename scalar> struct Atomic ##NAME##DecimalImpl<scalar, 4 > { \
7291 inline __device__ void operator ()(scalar *address, scalar val) { \
7392 int *address_as_i = (int *)address; \
@@ -116,6 +135,15 @@ static inline __device__ void atomAdd(int32_t *address, int32_t val) {
116135static inline __device__ void atomAdd (int64_t *address, int64_t val) {
117136 AtomicAddIntegerImpl<int64_t , sizeof (int64_t )>()(address, val);
118137}
138+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000)
139+ static inline __device__ void atomAdd (at::Half *address, at::Half val) {
140+ AtomicAddDecimalImpl<at::Half, sizeof (at::Half)>()(address, val);
141+ }
142+ #else
143+ static inline __device__ void atomAdd (at::Half *address, at::Half val) {
144+ atomicAdd (reinterpret_cast <__half *>(address), val);
145+ }
146+ #endif
119147static inline __device__ void atomAdd (float *address, float val) {
120148 atomicAdd (address, val);
121149}
@@ -150,6 +178,9 @@ static inline __device__ void atomMul(int64_t *address, int64_t val) {
150178static inline __device__ void atomMul (float *address, float val) {
151179 AtomicMulDecimalImpl<float , sizeof (float )>()(address, val);
152180}
181+ static inline __device__ void atomMul (at::Half *address, at::Half val) {
182+ AtomicMulDecimalImpl<at::Half, sizeof (at::Half)>()(address, val);
183+ }
153184static inline __device__ void atomMul (double *address, double val) {
154185 AtomicMulDecimalImpl<double , sizeof (double )>()(address, val);
155186}
@@ -172,6 +203,9 @@ static inline __device__ void atomDiv(int32_t *address, int32_t val) {
172203static inline __device__ void atomDiv (int64_t *address, int64_t val) {
173204 AtomicDivIntegerImpl<int64_t , sizeof (int64_t )>()(address, val);
174205}
206+ static inline __device__ void atomDiv (at::Half *address, at::Half val) {
207+ AtomicDivDecimalImpl<at::Half, sizeof (at::Half)>()(address, val);
208+ }
175209static inline __device__ void atomDiv (float *address, float val) {
176210 AtomicDivDecimalImpl<float , sizeof (float )>()(address, val);
177211}
@@ -197,6 +231,9 @@ static inline __device__ void atomMax(int32_t *address, int32_t val) {
197231static inline __device__ void atomMax (int64_t *address, int64_t val) {
198232 AtomicMaxIntegerImpl<int64_t , sizeof (int64_t )>()(address, val);
199233}
234+ static inline __device__ void atomMax (at::Half *address, at::Half val) {
235+ AtomicMaxDecimalImpl<at::Half, sizeof (at::Half)>()(address, val);
236+ }
200237static inline __device__ void atomMax (float *address, float val) {
201238 AtomicMaxDecimalImpl<float , sizeof (float )>()(address, val);
202239}
@@ -222,6 +259,9 @@ static inline __device__ void atomMin(int32_t *address, int32_t val) {
222259static inline __device__ void atomMin (int64_t *address, int64_t val) {
223260 AtomicMinIntegerImpl<int64_t , sizeof (int64_t )>()(address, val);
224261}
262+ static inline __device__ void atomMin (at::Half *address, at::Half val) {
263+ AtomicMinDecimalImpl<at::Half, sizeof (at::Half)>()(address, val);
264+ }
225265static inline __device__ void atomMin (float *address, float val) {
226266 AtomicMinDecimalImpl<float , sizeof (float )>()(address, val);
227267}
0 commit comments