Commit 61eb94e

[INTERPRETER] Make sure interpreter works with float16 by reusing NumPy HALF-related code (#5010)
Closes #4992

File tree

2 files changed: +220 −3 lines

python/src/interpreter.cc

Lines changed: 220 additions & 1 deletion
@@ -5,12 +5,17 @@
 #include <mutex>
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
+#include <stdexcept>
 #include <type_traits>
 
 namespace py = pybind11;
 
 namespace {
 
+struct npy_half {
+  uint16_t value;
+};
+
 enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED };
 
 std::mutex atomic_op_guard;
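
The new npy_half struct is nothing but the raw IEEE 754 binary16 bit pattern. Wrapping the bits in a struct rather than aliasing uint16_t matters for the template machinery added below: a bare alias would be the same type as the integer uint16_t, so atomic_fadd could not be given a float16-specific specialization. A minimal standalone sketch of the idea (half_bits and describe are illustrative names, not from the commit):

// Sketch: a one-field struct gives overload resolution and template
// specialization a unique float16 type, which a bare uint16_t alias cannot.
#include <cstdint>
#include <iostream>

struct half_bits { uint16_t value; }; // plays the role of npy_half

const char *describe(uint16_t) { return "integer uint16_t overload"; }
const char *describe(half_bits) { return "distinct float16 overload"; }

int main() {
  std::cout << describe(uint16_t{0x3c00}) << '\n';  // integer path
  std::cout << describe(half_bits{0x3c00}) << '\n'; // float16 path
}
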
@@ -83,6 +88,211 @@ template <typename T> T atomic_fadd(T *loc, T value, std::memory_order order) {
   return old_value;
 }
 
+/** Create a value of type `To` from the bits of `from`.
+ *
+ * similar to `std::bit_cast` but compatible with C++17,
+ * should perform similar to `*reinterpret_cast<To*>(&from)`
+ * or through punning without expecting any undefined behaviors.
+ *
+ * Note: taken from
+ * https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/utils.hpp#L32
+ * with simplification.
+ */
+template <typename To, typename From>
+inline To BitCast(const From &from) noexcept {
+  static_assert(sizeof(To) == sizeof(From),
+                "both data types must have the same size");
+
+  static_assert(std::is_trivially_copyable_v<To> &&
+                    std::is_trivially_copyable_v<From>,
+                "both data types must be trivially copyable");
+
+  To to;
+  memcpy(&to, &from, sizeof(from));
+  return to;
+}
+
+// Taken from
+// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L14
+template <bool gen_overflow = true, bool gen_underflow = true,
+          bool round_even = true>
+inline uint16_t FromFloatBits(uint32_t f) {
+  uint32_t f_exp, f_sig;
+  uint16_t h_sgn, h_exp, h_sig;
+
+  h_sgn = (uint16_t)((f & 0x80000000u) >> 16);
+  f_exp = (f & 0x7f800000u);
+
+  /* Exponent overflow/NaN converts to signed inf/NaN */
+  if (f_exp >= 0x47800000u) {
+    if (f_exp == 0x7f800000u) {
+      /* Inf or NaN */
+      f_sig = (f & 0x007fffffu);
+      if (f_sig != 0) {
+        /* NaN - propagate the flag in the significand... */
+        uint16_t ret = (uint16_t)(0x7c00u + (f_sig >> 13));
+        /* ...but make sure it stays a NaN */
+        if (ret == 0x7c00u) {
+          ret++;
+        }
+        return h_sgn + ret;
+      } else {
+        /* signed inf */
+        return (uint16_t)(h_sgn + 0x7c00u);
+      }
+    } else {
+      if constexpr (gen_overflow) {
+        // FloatStatus::RaiseOverflow();
+        throw std::overflow_error("overflow to signed inf");
+      }
+      return (uint16_t)(h_sgn + 0x7c00u);
+    }
+  }
+
+  /* Exponent underflow converts to a subnormal half or signed zero */
+  if (f_exp <= 0x38000000u) {
+    /*
+     * Signed zeros, subnormal floats, and floats with small
+     * exponents all convert to signed zero half-floats.
+     */
+    if (f_exp < 0x33000000u) {
+      if constexpr (gen_underflow) {
+        /* If f != 0, it underflowed to 0 */
+        if ((f & 0x7fffffff) != 0) {
+          // FloatStatus::RaiseUnderflow();
+          throw std::underflow_error("");
+        }
+      }
+      return h_sgn;
+    }
+    /* Make the subnormal significand */
+    f_exp >>= 23;
+    f_sig = (0x00800000u + (f & 0x007fffffu));
+    if constexpr (gen_underflow) {
+      /* If it's not exactly represented, it underflowed */
+      if ((f_sig & (((uint32_t)1 << (126 - f_exp)) - 1)) != 0) {
+        // FloatStatus::RaiseUnderflow();
+        throw std::underflow_error("");
+      }
+    }
+    /*
+     * Usually the significand is shifted by 13. For subnormals an
+     * additional shift needs to occur. This shift is one for the largest
+     * exponent giving a subnormal `f_exp = 0x38000000 >> 23 = 112`, which
+     * offsets the new first bit. At most the shift can be 1+10 bits.
+     */
+    f_sig >>= (113 - f_exp);
+    /* Handle rounding by adding 1 to the bit beyond half precision */
+    if constexpr (round_even) {
+      /*
+       * If the last bit in the half significand is 0 (already even), and
+       * the remaining bit pattern is 1000...0, then we do not add one
+       * to the bit after the half significand. However, the (113 - f_exp)
+       * shift can lose up to 11 bits, so the || checks them in the original.
+       * In all other cases, we can just add one.
+       */
+      if (((f_sig & 0x00003fffu) != 0x00001000u) || (f & 0x000007ffu)) {
+        f_sig += 0x00001000u;
+      }
+    } else {
+      f_sig += 0x00001000u;
+    }
+    h_sig = (uint16_t)(f_sig >> 13);
+    /*
+     * If the rounding causes a bit to spill into h_exp, it will
+     * increment h_exp from zero to one and h_sig will be zero.
+     * This is the correct result.
+     */
+    return (uint16_t)(h_sgn + h_sig);
+  }
+
+  /* Regular case with no overflow or underflow */
+  h_exp = (uint16_t)((f_exp - 0x38000000u) >> 13);
+  /* Handle rounding by adding 1 to the bit beyond half precision */
+  f_sig = (f & 0x007fffffu);
+  if constexpr (round_even) {
+    /*
+     * If the last bit in the half significand is 0 (already even), and
+     * the remaining bit pattern is 1000...0, then we do not add one
+     * to the bit after the half significand. In all other cases, we do.
+     */
+    if ((f_sig & 0x00003fffu) != 0x00001000u) {
+      f_sig += 0x00001000u;
+    }
+  } else {
+    f_sig += 0x00001000u;
+  }
+  h_sig = (uint16_t)(f_sig >> 13);
+  /*
+   * If the rounding causes a bit to spill into h_exp, it will
+   * increment h_exp by one and h_sig will be zero. This is the
+   * correct result. h_exp may increment to 15, at greatest, in
+   * which case the result overflows to a signed inf.
+   */
+  if constexpr (gen_overflow) {
+    h_sig += h_exp;
+    if (h_sig == 0x7c00u) {
+      // FloatStatus::RaiseOverflow();
+      throw std::overflow_error("");
+    }
+    return h_sgn + h_sig;
+  } else {
+    return h_sgn + h_exp + h_sig;
+  }
+}
+
+// Taken from
+// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L269
+constexpr uint32_t ToFloatBits(uint16_t h) {
+  uint16_t h_exp = (h & 0x7c00u);
+  uint32_t f_sgn = ((uint32_t)h & 0x8000u) << 16;
+  switch (h_exp) {
+  case 0x0000u: { // 0 or subnormal
+    uint16_t h_sig = (h & 0x03ffu);
+    // Signed zero
+    if (h_sig == 0) {
+      return f_sgn;
+    }
+    // Subnormal
+    h_sig <<= 1;
+    while ((h_sig & 0x0400u) == 0) {
+      h_sig <<= 1;
+      h_exp++;
+    }
+    uint32_t f_exp = ((uint32_t)(127 - 15 - h_exp)) << 23;
+    uint32_t f_sig = ((uint32_t)(h_sig & 0x03ffu)) << 13;
+    return f_sgn + f_exp + f_sig;
+  }
+  case 0x7c00u: // inf or NaN
+    // All-ones exponent and a copy of the significand
+    return f_sgn + 0x7f800000u + (((uint32_t)(h & 0x03ffu)) << 13);
+  default: // normalized
+    // Just need to adjust the exponent and shift
+    return f_sgn + (((uint32_t)(h & 0x7fffu) + 0x1c000u) << 13);
+  }
+}
+
+npy_half npy_float_to_half(float f) {
+  return {FromFloatBits(BitCast<uint32_t>(f))};
+}
+
+float npy_half_to_float(npy_half h) {
+  return BitCast<float>(ToFloatBits(h.value));
+}
+
+template <>
+npy_half atomic_fadd<npy_half>(npy_half *loc, npy_half value,
+                               std::memory_order order) {
+  npy_half old_value;
+
+  const std::lock_guard<std::mutex> lock(atomic_op_guard);
+  old_value = *loc;
+  *loc = npy_float_to_half(npy_half_to_float(old_value) +
+                           npy_half_to_float(value));
+
+  return old_value;
+}
+
 class AtomicOp {
 public:
   AtomicOp(const uint64_t *ptr, size_t numel, std::memory_order order)
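
Together these routines give the interpreter a float16 whose conversions and round-to-nearest-even behavior match NumPy bit-for-bit, and atomic_fadd<npy_half> serializes the read-modify-write under the same atomic_op_guard mutex used by the other instantiations: convert both halves to float32, add, convert back. A quick sanity check of the round trip, as a sketch only (it assumes compilation in the same translation unit, so npy_half, npy_float_to_half, and npy_half_to_float are in scope):

#include <cassert>
#include <cstdio>

int main() {
  // 1.5f is exactly representable in binary16: sign 0, biased exponent 15,
  // significand 0b1000000000 -> bits 0x3e00, and the round trip is lossless.
  npy_half h = npy_float_to_half(1.5f);
  assert(h.value == 0x3e00);
  assert(npy_half_to_float(h) == 1.5f);

  // 1 + 2^-11 lies exactly halfway between 1.0 (0x3c00) and the next
  // representable half (0x3c01); round-to-nearest-even picks 0x3c00.
  npy_half mid = npy_float_to_half(1.0f + 0.00048828125f);
  std::printf("0x%04x\n", mid.value); // prints 0x3c00
}

Going through float32 for the addition mirrors how NumPy itself implements half-precision arithmetic, so interpreter results stay consistent with NumPy reference outputs.
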
@@ -370,6 +580,15 @@ template <RMWOp Op> struct OpCreator {
   }
 };
 
+template <> template <> void OpCreator<RMWOp::FADD>::create<npy_half>() {
+  if (!atomic_op && dtype.char_() == 'e') { // float16
+    // workaround until https://github.com/pybind/pybind11/issues/4061 is
+    // implemented
+    atomic_op = std::make_unique<AtomicRMWOp<npy_half, RMWOp::FADD>>(
+        ptr, val, ret, mask, numel, order);
+  }
+};
+
 template <RMWOp Op, typename... SupportedDTypes>
 std::unique_ptr<AtomicOp>
 makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val,
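
The specialization keys on NumPy's single-character type code: 'e' denotes float16 (just as 'f' denotes float32 and 'd' float64). Since pybind11 has no C++ type mapped to float16 yet (the linked pybind11 issue #4061 tracks that), checking the dtype character is the reliable way to detect it, rather than matching on py::dtype::of<T>(). A tiny illustration (is_float16 is an illustrative helper, not part of the commit):

#include <pybind11/numpy.h>
namespace py = pybind11;

// NumPy one-character type codes: 'e' = float16, 'f' = float32, 'd' = float64.
bool is_float16(const py::dtype &dt) { return dt.char_() == 'e'; }
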
@@ -476,7 +695,7 @@ void init_triton_interpreter(py::module &&m) {
 
         switch (rmw_op) {
           MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t)
-          MAKE_ATOMIC_RMW_OP(RMWOp::FADD, float, double)
+          MAKE_ATOMIC_RMW_OP(RMWOp::FADD, npy_half, float, double)
          MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t)
          MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t)
          MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t)
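
Adding npy_half to the FADD type list is the only registration needed, since each listed type presumably gets funneled through OpCreator<Op>::create<T>(), where the float16 specialization above takes over. The bodies of MAKE_ATOMIC_RMW_OP and makeAtomicRMWOp are outside this diff, so the following is only a guessed sketch of that dispatch pattern (makeRMWSketch is a made-up name), not the actual Triton code:

// Guessed sketch of a type-list dispatch in the style of makeAtomicRMWOp:
// try each supported element type in turn. create<T>() only constructs the
// op when the runtime dtype matches, so at most one attempt succeeds.
template <RMWOp Op, typename... SupportedDTypes>
std::unique_ptr<AtomicOp> makeRMWSketch(OpCreator<Op> creator) {
  (creator.template create<SupportedDTypes>(), ...); // C++17 fold expression
  return std::move(creator.atomic_op);               // empty if no type matched
}
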

python/test/unit/language/test_core.py

Lines changed: 0 additions & 2 deletions
@@ -1458,8 +1458,6 @@ def kernel(X):
                           for num_ctas in num_ctas_list
                           for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
-    if is_interpreter() and dtype_x_str == 'float16':
-        pytest.skip('float16 atomic_add does not work in the interpreter mode')
     shape0, shape1 = shape
     # triton kernel
 
