+#include <cstring> // for std::memcpy used by BitCast below
 #include <mutex>
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
+#include <stdexcept>
 #include <type_traits>

 namespace py = pybind11;

 namespace {

+struct npy_half {
+  uint16_t value;
+};
+
 enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED };

 std::mutex atomic_op_guard;
@@ -83,6 +89,211 @@ template <typename T> T atomic_fadd(T *loc, T value, std::memory_order order) {
   return old_value;
 }

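Only the tail of the generic `atomic_fadd` is visible in this hunk. Judging from the `npy_half` specialization added further down, the primary template presumably follows the same mutex-guarded read-modify-write pattern; a sketch of what it likely looks like (assumed, not taken from this patch):

template <typename T> T atomic_fadd(T *loc, T value, std::memory_order order) {
  // The global mutex serializes all interpreter atomics, so the requested
  // memory_order is subsumed by the lock (assumed; mirrors the half version).
  const std::lock_guard<std::mutex> lock(atomic_op_guard);
  T old_value = *loc;       // read the current value under the lock
  *loc = old_value + value; // store the updated sum
  return old_value;         // fetch-add semantics: return the previous value
}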
+/** Create a value of type `To` from the bits of `from`.
+ *
+ * Similar to `std::bit_cast`, but compatible with C++17. It should perform
+ * like `*reinterpret_cast<To *>(&from)` or type punning, without invoking
+ * any undefined behavior.
+ *
+ * Note: taken from
+ * https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/utils.hpp#L32
+ * with simplifications.
+ */
+template <typename To, typename From>
+inline To BitCast(const From &from) noexcept {
+  static_assert(sizeof(To) == sizeof(From),
+                "both data types must have the same size");
+
+  static_assert(std::is_trivially_copyable_v<To> &&
+                    std::is_trivially_copyable_v<From>,
+                "both data types must be trivially copyable");
+
+  To to;
+  std::memcpy(&to, &from, sizeof(from));
+  return to;
+}
+
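To make the contract of `BitCast` concrete, here is a small standalone usage sketch (the values and asserts are illustrative, not part of the patch):

#include <cassert>
#include <cstdint>

void bitcast_example() {
  float f = 1.0f;
  // Reinterpret the IEEE-754 single-precision bits without UB.
  uint32_t bits = BitCast<uint32_t>(f);
  assert(bits == 0x3f800000u); // sign 0, biased exponent 127, mantissa 0
  // The cast is lossless in both directions for same-sized types.
  assert(BitCast<float>(bits) == 1.0f);
}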
+// Taken from
+// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L14
+template <bool gen_overflow = true, bool gen_underflow = true,
+          bool round_even = true>
+inline uint16_t FromFloatBits(uint32_t f) {
+  uint32_t f_exp, f_sig;
+  uint16_t h_sgn, h_exp, h_sig;
+
+  h_sgn = (uint16_t)((f & 0x80000000u) >> 16);
+  f_exp = (f & 0x7f800000u);
+
+  /* Exponent overflow/NaN converts to signed inf/NaN */
+  if (f_exp >= 0x47800000u) {
+    if (f_exp == 0x7f800000u) {
+      /* Inf or NaN */
+      f_sig = (f & 0x007fffffu);
+      if (f_sig != 0) {
+        /* NaN - propagate the flag in the significand... */
+        uint16_t ret = (uint16_t)(0x7c00u + (f_sig >> 13));
+        /* ...but make sure it stays a NaN */
+        if (ret == 0x7c00u) {
+          ret++;
+        }
+        return h_sgn + ret;
+      } else {
+        /* signed inf */
+        return (uint16_t)(h_sgn + 0x7c00u);
+      }
+    } else {
+      if constexpr (gen_overflow) {
+        // FloatStatus::RaiseOverflow();
+        throw std::overflow_error("overflow to signed inf");
+      }
+      return (uint16_t)(h_sgn + 0x7c00u);
+    }
+  }
+
+  /* Exponent underflow converts to a subnormal half or signed zero */
+  if (f_exp <= 0x38000000u) {
+    /*
+     * Signed zeros, subnormal floats, and floats with small
+     * exponents all convert to signed zero half-floats.
+     */
+    if (f_exp < 0x33000000u) {
+      if constexpr (gen_underflow) {
+        /* If f != 0, it underflowed to 0 */
+        if ((f & 0x7fffffff) != 0) {
+          // FloatStatus::RaiseUnderflow();
+          throw std::underflow_error("underflow to signed zero");
+        }
+      }
+      return h_sgn;
+    }
+    /* Make the subnormal significand */
+    f_exp >>= 23;
+    f_sig = (0x00800000u + (f & 0x007fffffu));
+    if constexpr (gen_underflow) {
+      /* If it's not exactly represented, it underflowed */
+      if ((f_sig & (((uint32_t)1 << (126 - f_exp)) - 1)) != 0) {
+        // FloatStatus::RaiseUnderflow();
+        throw std::underflow_error("underflow to a subnormal half");
+      }
+    }
+    /*
+     * Usually the significand is shifted by 13. For subnormals an
+     * additional shift needs to occur. This shift is one for the largest
+     * exponent giving a subnormal `f_exp = 0x38000000 >> 23 = 112`, which
+     * offsets the new first bit. At most the shift can be 1+10 bits.
+     */
+    f_sig >>= (113 - f_exp);
+    /* Handle rounding by adding 1 to the bit beyond half precision */
+    if constexpr (round_even) {
+      /*
+       * If the last bit in the half significand is 0 (already even), and
+       * the remaining bit pattern is 1000...0, then we do not add one
+       * to the bit after the half significand. However, the (113 - f_exp)
+       * shift can lose up to 11 bits, so the || checks them in the original.
+       * In all other cases, we can just add one.
+       */
+      if (((f_sig & 0x00003fffu) != 0x00001000u) || (f & 0x000007ffu)) {
+        f_sig += 0x00001000u;
+      }
+    } else {
+      f_sig += 0x00001000u;
+    }
+    h_sig = (uint16_t)(f_sig >> 13);
+    /*
+     * If the rounding causes a bit to spill into h_exp, it will
+     * increment h_exp from zero to one and h_sig will be zero.
+     * This is the correct result.
+     */
+    return (uint16_t)(h_sgn + h_sig);
+  }
+
+  /* Regular case with no overflow or underflow */
+  h_exp = (uint16_t)((f_exp - 0x38000000u) >> 13);
+  /* Handle rounding by adding 1 to the bit beyond half precision */
+  f_sig = (f & 0x007fffffu);
+  if constexpr (round_even) {
+    /*
+     * If the last bit in the half significand is 0 (already even), and
+     * the remaining bit pattern is 1000...0, then we do not add one
+     * to the bit after the half significand. In all other cases, we do.
+     */
+    if ((f_sig & 0x00003fffu) != 0x00001000u) {
+      f_sig += 0x00001000u;
+    }
+  } else {
+    f_sig += 0x00001000u;
+  }
+  h_sig = (uint16_t)(f_sig >> 13);
+  /*
+   * If the rounding causes a bit to spill into h_exp, it will
+   * increment h_exp by one and h_sig will be zero. This is the
+   * correct result. h_exp may increment to 15, at greatest, in
+   * which case the result overflows to a signed inf.
+   */
+  if constexpr (gen_overflow) {
+    h_sig += h_exp;
+    if (h_sig == 0x7c00u) {
+      // FloatStatus::RaiseOverflow();
+      throw std::overflow_error("overflow to signed inf");
+    }
+    return h_sgn + h_sig;
+  } else {
+    return h_sgn + h_exp + h_sig;
+  }
+}
+
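As a sanity check of the rebiasing arithmetic above: for f = 1.0f the input bits are 0x3f800000, so the normal path computes h_exp = (0x3f800000 - 0x38000000) >> 13 = 0x3c00, the zero significand stays zero after round-to-nearest-even, and the result is the half encoding 0x3c00. A standalone check with illustrative values (assumes <cassert>):

void from_float_bits_example() {
  assert(FromFloatBits(0x3f800000u) == 0x3c00u); // 1.0f  -> 1.0 in half
  assert(FromFloatBits(0xc0000000u) == 0xc000u); // -2.0f -> -2.0 in half
}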
+// Taken from
+// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L269
+constexpr uint32_t ToFloatBits(uint16_t h) {
+  uint16_t h_exp = (h & 0x7c00u);
+  uint32_t f_sgn = ((uint32_t)h & 0x8000u) << 16;
+  switch (h_exp) {
+  case 0x0000u: { // 0 or subnormal
+    uint16_t h_sig = (h & 0x03ffu);
+    // Signed zero
+    if (h_sig == 0) {
+      return f_sgn;
+    }
+    // Subnormal
+    h_sig <<= 1;
+    while ((h_sig & 0x0400u) == 0) {
+      h_sig <<= 1;
+      h_exp++;
+    }
+    uint32_t f_exp = ((uint32_t)(127 - 15 - h_exp)) << 23;
+    uint32_t f_sig = ((uint32_t)(h_sig & 0x03ffu)) << 13;
+    return f_sgn + f_exp + f_sig;
+  }
+  case 0x7c00u: // inf or NaN
+    // All-ones exponent and a copy of the significand
+    return f_sgn + 0x7f800000u + (((uint32_t)(h & 0x03ffu)) << 13);
+  default: // normalized
+    // Just need to adjust the exponent and shift
+    return f_sgn + (((uint32_t)(h & 0x7fffu) + 0x1c000u) << 13);
+  }
+}
+
+npy_half npy_float_to_half(float f) {
+  return {FromFloatBits(BitCast<uint32_t>(f))};
+}
+
+float npy_half_to_float(npy_half h) {
+  return BitCast<float>(ToFloatBits(h.value));
+}
+
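Because float's 24-bit significand and wider exponent range strictly contain half's, every finite half value survives a round trip through float exactly; a minimal property check with an illustrative value:

void half_roundtrip_example() {
  npy_half h{0x3555u}; // ~0.33325 in half precision
  float f = npy_half_to_float(h);
  assert(npy_float_to_half(f).value == h.value); // exact round trip
}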
+template <>
+npy_half atomic_fadd<npy_half>(npy_half *loc, npy_half value,
+                               std::memory_order order) {
+  npy_half old_value;
+
+  const std::lock_guard<std::mutex> lock(atomic_op_guard);
+  old_value = *loc;
+  *loc = npy_float_to_half(npy_half_to_float(old_value) +
+                           npy_half_to_float(value));
+
+  return old_value;
+}
+
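There is no native 16-bit float fetch-add, so this specialization reuses the same global `atomic_op_guard` as the generic template: widen to float, add, and narrow back, all under the lock (the memory_order argument is again subsumed by the mutex). A usage sketch with illustrative values:

void fadd_half_example() {
  npy_half x = npy_float_to_half(1.5f);
  npy_half prev =
      atomic_fadd(&x, npy_float_to_half(0.25f), std::memory_order_seq_cst);
  assert(npy_half_to_float(prev) == 1.5f); // the previous value is returned
  assert(npy_half_to_float(x) == 1.75f);   // 1.5 + 0.25 is exact in half
}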
 class AtomicOp {
 public:
   AtomicOp(const uint64_t *ptr, size_t numel, std::memory_order order)
@@ -370,6 +581,15 @@ template <RMWOp Op> struct OpCreator {
   }
 };

+template <> template <> void OpCreator<RMWOp::FADD>::create<npy_half>() {
+  if (!atomic_op && dtype.char_() == 'e') { // float16
+    // workaround until https://github.com/pybind/pybind11/issues/4061 is
+    // implemented
+    atomic_op = std::make_unique<AtomicRMWOp<npy_half, RMWOp::FADD>>(
+        ptr, val, ret, mask, numel, order);
+  }
+}
+
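The check against 'e' relies on NumPy's one-character type codes, where 'e' denotes float16; since pybind11 has no C++ type that maps to float16 yet (the linked issue #4061), the raw-bits `npy_half` wrapper stands in for it here. A hedged sketch of the correspondence (assumes a running Python interpreter):

void float16_typecode_example() {
  py::dtype dt("float16");
  assert(dt.char_() == 'e');  // NumPy type code for half precision
  assert(dt.itemsize() == 2); // 16 bits, matching sizeof(npy_half)
}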
 template <RMWOp Op, typename... SupportedDTypes>
 std::unique_ptr<AtomicOp>
 makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val,
@@ -476,7 +696,7 @@ void init_triton_interpreter(py::module &&m) {

   switch (rmw_op) {
     MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t)
-    MAKE_ATOMIC_RMW_OP(RMWOp::FADD, float, double)
+    MAKE_ATOMIC_RMW_OP(RMWOp::FADD, npy_half, float, double)
     MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t)
     MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t)
     MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t)