|
| 1 | +use cubecl::prelude::*; |
| 2 | +use cubecl_common::quant::scheme::*; |
| 3 | +use cubecl_common::{e2m1x2, e4m3, e5m2}; |
| 4 | +use cubecl_core::{self as cubecl, intrinsic}; |
| 5 | + |
| 6 | +/// Dequantize a line of values, where `line_size * num_quants` is a power of two. |
| 7 | +/// Unaligned values can't be dequantized in place. |
| 8 | +#[cube] |
| 9 | +pub fn dequantize_aligned<Q: CubePrimitive, S: CubePrimitive, F: Float>( |
| 10 | + value: Line<Q>, |
| 11 | + scale: S, |
| 12 | + #[comptime] scheme: QuantScheme, |
| 13 | +) -> Line<F> { |
| 14 | + let q_values = match scheme.store { |
| 15 | + QuantStore::Native => Line::<F>::cast_from(value), |
| 16 | + QuantStore::U32 => unpack_cast_u32::<F>(Line::cast_from(value), scheme), |
| 17 | + }; |
| 18 | + let scale = Line::<F>::cast_from(scale); |
| 19 | + |
| 20 | + match scheme.mode { |
| 21 | + QuantMode::Symmetric => q_values * scale, |
| 22 | + } |
| 23 | +} |
| 24 | + |
| 25 | +/// Unpack a set of values from u32, and convert to the specified floating point format. |
| 26 | +#[cube] |
| 27 | +pub fn unpack_cast_u32<F: Float>(value: Line<u32>, #[comptime] scheme: QuantScheme) -> Line<F> { |
| 28 | + let num_quants = comptime![scheme.num_quants() as u32]; |
| 29 | + let native_packing = comptime![scheme.native_packing() as u32]; |
| 30 | + let out_line_size = comptime![value.line_size() * num_quants]; |
| 31 | + let size_bits = comptime![scheme.size_bits_value() as u32]; |
| 32 | + let mask = comptime![packing_mask(scheme)]; |
| 33 | + |
| 34 | + let mut out = Line::<F>::empty(out_line_size); |
| 35 | + |
| 36 | + #[unroll] |
| 37 | + for line_idx in 0..value.line_size() { |
| 38 | + let line_idx = unwrap(line_idx); |
| 39 | + let packed_val = value[line_idx]; |
| 40 | + let out_offset = comptime![line_idx * num_quants]; |
| 41 | + #[unroll] |
| 42 | + for packed_idx in range_stepped(0, num_quants, native_packing) { |
| 43 | + let packed_idx = unwrap(packed_idx); |
| 44 | + let shift = packed_idx * size_bits; |
| 45 | + let value = (packed_val >> shift) & mask; |
| 46 | + |
| 47 | + let float_value = cast_masked::<F>(value, scheme); |
| 48 | + |
| 49 | + #[unroll] |
| 50 | + for native_idx in 0..native_packing { |
| 51 | + let native_idx = unwrap(native_idx); |
| 52 | + let out_offset = comptime![out_offset + packed_idx + native_idx]; |
| 53 | + out[out_offset] = float_value[native_idx]; |
| 54 | + } |
| 55 | + } |
| 56 | + } |
| 57 | + |
| 58 | + out |
| 59 | +} |
| 60 | + |
| 61 | +/// The mask required for each packed value, taking into account the native packing required for |
| 62 | +/// `e2m1`. |
| 63 | +fn packing_mask(scheme: QuantScheme) -> u32 { |
| 64 | + let bits = match scheme.value { |
| 65 | + QuantValue::E2M1 => 8, // Packed conversion |
| 66 | + other => other.size_bits(), |
| 67 | + }; |
| 68 | + (1u32 << bits) - 1 |
| 69 | +} |
| 70 | + |
| 71 | +/// Cast a masked-out value in the low `n` bits of a `u32` to the specified float type. |
| 72 | +/// Applies sign conversion for integer quantization before casting to the float type, |
| 73 | +/// while minifloats are simply truncated to `u8`, reinterpreted and then cast. |
| 74 | +/// For `e2m1`, casting is done on the packed `e2m1x2` representation. |
| 75 | +/// |
| 76 | +/// # Returns |
| 77 | +/// Two floating point numbers for `e2m1`, one for all other formats. |
| 78 | +#[cube] |
| 79 | +fn cast_masked<F: Float>(value: u32, #[comptime] scheme: QuantScheme) -> Line<F> { |
| 80 | + match scheme.value { |
| 81 | + // For minifloat we can assume if they're supported then u8 is supported |
| 82 | + QuantValue::E5M2 => Line::<F>::cast_from(e5m2::reinterpret(value as u8)), |
| 83 | + QuantValue::E4M3 => Line::<F>::cast_from(e4m3::reinterpret(value as u8)), |
| 84 | + QuantValue::E2M1 => Line::<F>::cast_from(e2m1x2::reinterpret(value as u8)), |
| 85 | + QuantValue::Q8F |
| 86 | + | QuantValue::Q4F |
| 87 | + | QuantValue::Q2F |
| 88 | + | QuantValue::Q8S |
| 89 | + | QuantValue::Q4S |
| 90 | + | QuantValue::Q2S => { |
| 91 | + let size_quant = comptime!(scheme.size_bits_value() as u32); |
| 92 | + let sign_bit = comptime!(1u32 << (size_quant - 1)); |
| 93 | + let two_pow_n = comptime!(1 << size_quant); |
| 94 | + |
| 95 | + // Branchless two's complement conversion |
| 96 | + // If raw >= 2^(n-1), then result = raw - 2^n |
| 97 | + let raw_i32 = value as i32; |
| 98 | + let is_negative = (value >= sign_bit) as i32; // 1 if negative, 0 if positive |
| 99 | + let signed_value = raw_i32 - (is_negative * two_pow_n); |
| 100 | + Line::<F>::cast_from(signed_value) |
| 101 | + } |
| 102 | + } |
| 103 | +} |
| 104 | + |
| 105 | +#[allow(unused_variables)] |
| 106 | +#[cube] |
| 107 | +pub(crate) fn unwrap(v: u32) -> comptime_type!(u32) { |
| 108 | + intrinsic!(|_| v.constant().expect("Must be constant").as_u32()) |
| 109 | +} |
0 commit comments