|
| 1 | +use core::{f32, f64}; |
| 2 | + |
| 3 | +use crate as cubecl; |
| 4 | +use cubecl_ir::{ |
| 5 | + Allocator, Comparison, ElemType, ExpandElement, FloatKind, Instruction, Operation, Processor, |
| 6 | + Scope, ScopeProcessing, UIntKind, Variable, |
| 7 | +}; |
| 8 | +use half::{bf16, f16}; |
| 9 | + |
| 10 | +use crate::prelude::*; |
| 11 | + |
/// IR processor that rewrites `IsNan` / `IsInf` comparison instructions into
/// bit-manipulation polyfills (see [`is_nan`] and [`is_inf`] below).
#[derive(Debug, Default)]
pub struct PredicateProcessor;

| 15 | +impl Processor for PredicateProcessor { |
| 16 | + fn transform( |
| 17 | + &self, |
| 18 | + mut processing: cubecl_ir::ScopeProcessing, |
| 19 | + allocator: Allocator, |
| 20 | + ) -> cubecl_ir::ScopeProcessing { |
| 21 | + let mut instructions = Vec::new(); |
| 22 | + core::mem::swap(&mut processing.instructions, &mut instructions); |
| 23 | + |
| 24 | + for instruction in instructions { |
| 25 | + if let Operation::Comparison(comparison) = &instruction.operation { |
| 26 | + match comparison { |
| 27 | + Comparison::IsNan(op) => { |
| 28 | + run_polyfill( |
| 29 | + &mut processing, |
| 30 | + op.input, |
| 31 | + instruction.out(), |
| 32 | + &allocator, |
| 33 | + is_nan::expand::<FloatExpand<0>, IntExpand<1>>, |
| 34 | + ); |
| 35 | + continue; |
| 36 | + } |
| 37 | + Comparison::IsInf(op) => { |
| 38 | + run_polyfill( |
| 39 | + &mut processing, |
| 40 | + op.input, |
| 41 | + instruction.out(), |
| 42 | + &allocator, |
| 43 | + is_inf::expand::<FloatExpand<0>, IntExpand<1>>, |
| 44 | + ); |
| 45 | + continue; |
| 46 | + } |
| 47 | + _ => {} |
| 48 | + } |
| 49 | + } |
| 50 | + processing.instructions.push(instruction); |
| 51 | + } |
| 52 | + processing |
| 53 | + } |
| 54 | +} |
| 55 | + |
/// Expands `polyfill` into a fresh scope and splices the generated
/// instructions into `processing`, finishing with a `Copy` of the polyfill's
/// result into `out`.
///
/// `input` must be a float variable — panics otherwise. The polyfill receives
/// the explicit mantissa bit count and exponent bit count of the input's
/// float kind; `FloatExpand<0>` is bound to the input's storage type and
/// `IntExpand<1>` to the unsigned integer type of the same bit width.
fn run_polyfill<T: CubePrimitive, O: CubePrimitive>(
    processing: &mut ScopeProcessing,
    input: Variable,
    out: Variable,
    allocator: &Allocator,
    mut polyfill: impl FnMut(&mut Scope, ExpandElementTyped<T>, u32, u32) -> ExpandElementTyped<O>,
) {
    let input = ExpandElement::Plain(input);
    // Expand into a detached root scope that shares the caller's allocator, so
    // variables created by the polyfill don't collide with existing ones.
    let mut scope = Scope::root(false).with_allocator(allocator.clone());
    scope.register_type::<FloatExpand<0>>(input.storage_type());

    let out_poly = if let ElemType::Float(kind) = input.elem_type() {
        // (unsigned type of same width, total bit width, explicit mantissa bits).
        // MANTISSA_DIGITS counts the implicit leading bit, hence the `- 1`.
        let (unsigned_ty, bit_width, mantissa_bits) = match kind {
            FloatKind::F64 => (
                UIntKind::U64,
                f64::size_bits().unwrap(),
                f64::MANTISSA_DIGITS - 1,
            ),
            FloatKind::F32 => (
                UIntKind::U32,
                f32::size_bits().unwrap(),
                f32::MANTISSA_DIGITS - 1,
            ),
            FloatKind::F16 => (
                UIntKind::U16,
                f16::size_bits().unwrap(),
                f16::MANTISSA_DIGITS - 1,
            ),
            FloatKind::BF16 => (
                UIntKind::U16,
                bf16::size_bits().unwrap(),
                bf16::MANTISSA_DIGITS - 1,
            ),
            // NOTE(review): assumes no other FloatKind variant can reach this
            // polyfill — confirm upstream guarantees, otherwise this panics.
            _ => unreachable!(),
        };
        scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into());

        // Exponent width = total bits minus mantissa bits minus the sign bit.
        let exp_bits = bit_width as u32 - mantissa_bits - 1;

        polyfill(&mut scope, input.into(), mantissa_bits, exp_bits).expand
    } else {
        panic!("Should be float")
    };

    // Flush the polyfill's scope and merge its instructions/variables into the
    // caller's processing.
    let tmp_processing = scope.process([]);

    processing.instructions.extend(tmp_processing.instructions);
    processing.variables.extend(tmp_processing.variables);

    // Route the polyfill's boolean result into the original instruction's output.
    processing
        .instructions
        .push(Instruction::new(Operation::Copy(*out_poly), out));
}
| 109 | + |
/// Polyfill for `IsNan` on a line of floats, via integer bit manipulation.
///
/// Per IEEE 754, a value is NaN iff its exponent bits are all ones and its
/// mantissa is non-zero — equivalently, after clearing the sign bit, its bit
/// pattern is strictly greater than that of +infinity.
#[cube]
fn is_nan<F: Float, U: Int>(
    x: Line<F>,
    #[comptime] mantissa_bits: u32,
    #[comptime] exp_bits: u32,
) -> Line<bool> {
    // Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
    // Bit pattern of +infinity: all exponent bits set, mantissa zero.
    let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
    // Mask that clears the sign bit (keeps exponent + mantissa).
    let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];

    // Reinterpret the float bits as same-width unsigned integers.
    let bits: Line<U> = Line::<U>::reinterpret(x);

    let abs_bits = bits & Line::new(U::cast_from(abs_mask));

    // NaN iff the absolute bit pattern exceeds that of +infinity.
    abs_bits.greater_than(Line::new(U::cast_from(inf_bits)))
}
| 126 | + |
/// Polyfill for `IsInf` on a line of floats, via integer bit manipulation.
///
/// Same trick as the NaN detection above (IEEE 754), but infinity requires the
/// mantissa to be exactly zero, so the sign-cleared bit pattern must *equal*
/// that of +infinity rather than exceed it.
#[cube]
fn is_inf<F: Float, U: Int>(
    x: Line<F>,
    #[comptime] mantissa_bits: u32,
    #[comptime] exp_bits: u32,
) -> Line<bool> {
    // Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
    // Bit pattern of +infinity: all exponent bits set, mantissa zero.
    let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
    // Mask that clears the sign bit (keeps exponent + mantissa).
    let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];

    // Reinterpret the float bits as same-width unsigned integers.
    let bits: Line<U> = Line::<U>::reinterpret(x);

    let abs_bits = bits & Line::new(U::cast_from(abs_mask));

    // ±infinity iff the absolute bit pattern equals that of +infinity.
    abs_bits.equal(Line::new(U::cast_from(inf_bits)))
}
0 commit comments