Skip to content

Commit 13ebbd5

Browse files
authored
Add IsNan and IsInf ops (#937)
* Add IsNan op * Add IsInf op * Remove leftover * Line implementation * Add polyfill w/ flexible bit masking * More flexible wgsl * Cleanup * Fix u64 * Fix clippy * Fix no-std targets w/o pointer-sized atomics * Fix const sanitize
1 parent 68dca01 commit 13ebbd5

File tree

24 files changed

+536
-7
lines changed

24 files changed

+536
-7
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ tracel-xtask = { version = "=2.1.8" }
8989
portable-atomic = { version = "1.10", default-features = false, features = [
9090
"serde",
9191
] }
92+
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
9293
pretty_assertions = "1.4"
9394

9495
# Async

crates/cubecl-common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ spin = { workspace = true, features = ["mutex", "spin_mutex"] }
6262

6363
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
6464
portable-atomic = { workspace = true }
65+
portable-atomic-util = { workspace = true }
6566
spin = { workspace = true, features = [
6667
"mutex",
6768
"spin_mutex",

crates/cubecl-common/src/map.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
use crate::stub::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
2+
3+
#[cfg(target_has_atomic = "ptr")]
24
use alloc::sync::Arc;
5+
6+
#[cfg(not(target_has_atomic = "ptr"))]
7+
use portable_atomic_util::Arc;
8+
39
use hashbrown::HashMap;
410

511
/// A thread-safe map that allows concurrent access to values using read-write locks.

crates/cubecl-core/src/frontend/container/line/ops.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use num_traits::{NumCast, ToPrimitive};
44

55
use crate::{
66
self as cubecl,
7-
prelude::{Powi, SaturatingAdd, SaturatingSub},
7+
prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub},
88
};
99
use crate::{
1010
frontend::{
@@ -259,6 +259,8 @@ impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
259259
impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
260260
impl<P: CubePrimitive + SaturatingAdd> SaturatingAdd for Line<P> {}
261261
impl<P: CubePrimitive + SaturatingSub> SaturatingSub for Line<P> {}
262+
impl<P: CubePrimitive + IsNan> IsNan for Line<P> {}
263+
impl<P: CubePrimitive + IsInf> IsInf for Line<P> {}
262264

263265
#[cube]
264266
impl<P: CountOnes> Line<P> {

crates/cubecl-core/src/frontend/element/float.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ pub trait Float:
3737
+ Magnitude
3838
+ Normalize
3939
+ Dot
40+
+ IsNan
41+
+ IsInf
4042
+ Into<Self::ExpandType>
4143
+ core::ops::Neg<Output = Self>
4244
+ core::ops::Add<Output = Self>

crates/cubecl-core/src/frontend/element/float/typemap.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ impl<const POS: u8> Sqrt for ElemExpand<POS> {}
250250
impl<const POS: u8> Round for ElemExpand<POS> {}
251251
impl<const POS: u8> Floor for ElemExpand<POS> {}
252252
impl<const POS: u8> Ceil for ElemExpand<POS> {}
253+
impl<const POS: u8> IsNan for ElemExpand<POS> {}
254+
impl<const POS: u8> IsInf for ElemExpand<POS> {}
253255

254256
impl<const POS: u8> Float for ElemExpand<POS> {
255257
const DIGITS: u32 = 32;

crates/cubecl-core/src/frontend/operation/cmp.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use crate::frontend::operation::base::cmp_expand;
33
use crate::ir::{Comparison, Scope};
44
use crate::prelude::CubePrimitive;
55

6+
// NOTE: Unary comparison tests are in the unary module
7+
68
pub mod ne {
79
use super::*;
810

crates/cubecl-core/src/frontend/operation/unary.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
2-
use cubecl_ir::{Bitwise, Operator, Type};
2+
use cubecl_ir::{Bitwise, Comparison, Operator, Type};
33
use half::{bf16, f16};
44

55
use crate::{
@@ -358,3 +358,29 @@ impl_unary_func_fixed_out_ty!(
358358
u64,
359359
i64
360360
);
361+
impl_unary_func_fixed_out_ty!(
362+
IsNan,
363+
is_nan,
364+
__expand_is_nan,
365+
bool,
366+
Comparison::IsNan,
367+
f16,
368+
bf16,
369+
flex32,
370+
tf32,
371+
f32,
372+
f64
373+
);
374+
impl_unary_func_fixed_out_ty!(
375+
IsInf,
376+
is_inf,
377+
__expand_is_inf,
378+
bool,
379+
Comparison::IsInf,
380+
f16,
381+
bf16,
382+
flex32,
383+
tf32,
384+
f32,
385+
f64
386+
);
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pub mod checked_io;
2+
pub mod predicate;
23
pub mod saturating;
34
pub mod unroll;
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
use core::{f32, f64};
2+
3+
use crate as cubecl;
4+
use cubecl_ir::{
5+
Allocator, Comparison, ElemType, ExpandElement, FloatKind, Instruction, Operation, Processor,
6+
Scope, ScopeProcessing, UIntKind, Variable,
7+
};
8+
use half::{bf16, f16};
9+
10+
use crate::prelude::*;
11+
12+
/// Processor pass that rewrites `Comparison::IsNan` and `Comparison::IsInf`
/// instructions into the bit-masking polyfills (`is_nan` / `is_inf`) defined
/// in this file. All other instructions are passed through unchanged.
#[derive(Debug, Default)]
13+
pub struct PredicateProcessor;
14+
15+
impl Processor for PredicateProcessor {
16+
/// Walks the instructions of `processing` and replaces every `IsNan` /
/// `IsInf` comparison with the instructions produced by the matching
/// polyfill; every other instruction is pushed back in its original order.
///
/// `allocator` is forwarded to `run_polyfill`, which uses it for the
/// temporary scope the polyfill is expanded in.
fn transform(
17+
&self,
18+
mut processing: cubecl_ir::ScopeProcessing,
19+
allocator: Allocator,
20+
) -> cubecl_ir::ScopeProcessing {
21+
// Take ownership of the current instruction list, leaving an empty list
// in `processing` that is rebuilt below.
let mut instructions = Vec::new();
22+
core::mem::swap(&mut processing.instructions, &mut instructions);
23+
24+
for instruction in instructions {
25+
if let Operation::Comparison(comparison) = &instruction.operation {
26+
match comparison {
27+
Comparison::IsNan(op) => {
28+
// Emit the polyfill in place of the original instruction.
run_polyfill(
29+
&mut processing,
30+
op.input,
31+
instruction.out(),
32+
&allocator,
33+
is_nan::expand::<FloatExpand<0>, IntExpand<1>>,
34+
);
35+
// Skip the push below: the original instruction is dropped.
continue;
36+
}
37+
Comparison::IsInf(op) => {
38+
// Emit the polyfill in place of the original instruction.
run_polyfill(
39+
&mut processing,
40+
op.input,
41+
instruction.out(),
42+
&allocator,
43+
is_inf::expand::<FloatExpand<0>, IntExpand<1>>,
44+
);
45+
// Skip the push below: the original instruction is dropped.
continue;
46+
}
47+
// Any other comparison falls through and is kept as-is.
_ => {}
48+
}
49+
}
50+
processing.instructions.push(instruction);
51+
}
52+
processing
53+
}
54+
}
55+
56+
/// Expands `polyfill` for `input` in a temporary scope and appends the
/// resulting instructions and variables to `processing`, followed by a copy
/// of the polyfill result into `out`.
///
/// `polyfill` receives the scope, the input element, the number of mantissa
/// bits and the number of exponent bits of the input's float type.
///
/// # Panics
///
/// Panics with "Should be float" when the element type of `input` is not a
/// float, and hits `unreachable!()` for any float kind other than `F64`,
/// `F32`, `F16` or `BF16`.
fn run_polyfill<T: CubePrimitive, O: CubePrimitive>(
57+
processing: &mut ScopeProcessing,
58+
input: Variable,
59+
out: Variable,
60+
allocator: &Allocator,
61+
mut polyfill: impl FnMut(&mut Scope, ExpandElementTyped<T>, u32, u32) -> ExpandElementTyped<O>,
62+
) {
63+
let input = ExpandElement::Plain(input);
64+
// Temporary root scope built on a clone of the caller's allocator.
let mut scope = Scope::root(false).with_allocator(allocator.clone());
65+
// Bind the generic float placeholder (position 0) to the input's storage type.
scope.register_type::<FloatExpand<0>>(input.storage_type());
66+
67+
let out_poly = if let ElemType::Float(kind) = input.elem_type() {
68+
// For each supported float kind: the unsigned integer type used to
// view its bits, its total bit width, and its mantissa bit count.
// `MANTISSA_DIGITS` counts the implicit leading bit, so subtracting
// one gives the number of stored mantissa bits.
let (unsigned_ty, bit_width, mantissa_bits) = match kind {
69+
FloatKind::F64 => (
70+
UIntKind::U64,
71+
f64::size_bits().unwrap(),
72+
f64::MANTISSA_DIGITS - 1,
73+
),
74+
FloatKind::F32 => (
75+
UIntKind::U32,
76+
f32::size_bits().unwrap(),
77+
f32::MANTISSA_DIGITS - 1,
78+
),
79+
FloatKind::F16 => (
80+
UIntKind::U16,
81+
f16::size_bits().unwrap(),
82+
f16::MANTISSA_DIGITS - 1,
83+
),
84+
FloatKind::BF16 => (
85+
UIntKind::U16,
86+
bf16::size_bits().unwrap(),
87+
bf16::MANTISSA_DIGITS - 1,
88+
),
89+
// NOTE(review): every other float kind lands here. Confirm that no
// other kind (e.g. flex32 / tf32) can reach this polyfill.
_ => unreachable!(),
90+
};
91+
// Bind the generic integer placeholder (position 1) to the unsigned type.
scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into());
92+
93+
// Whatever is left after the mantissa and the single sign bit is the exponent.
let exp_bits = bit_width as u32 - mantissa_bits - 1;
94+
95+
polyfill(&mut scope, input.into(), mantissa_bits, exp_bits).expand
96+
} else {
97+
panic!("Should be float")
98+
};
99+
100+
let tmp_processing = scope.process([]);
101+
102+
// Merge what the temporary scope produced into the caller's processing state.
processing.instructions.extend(tmp_processing.instructions);
103+
processing.variables.extend(tmp_processing.variables);
104+
105+
// Finally copy the polyfill result into the original output variable.
processing
106+
.instructions
107+
.push(Instruction::new(Operation::Copy(*out_poly), out));
108+
}
109+
110+
// NaN test on the raw bits of `x`: with the sign bit masked off, a value is
// NaN when its bits are strictly greater than the infinity bit pattern
// (exponent all ones, mantissa all zeros).
//
// `mantissa_bits` and `exp_bits` are the comptime mantissa and exponent widths
// of `F`; `U` is the integer type the bits of `x` are reinterpreted as.
#[cube]
111+
fn is_nan<F: Float, U: Int>(
112+
x: Line<F>,
113+
#[comptime] mantissa_bits: u32,
114+
#[comptime] exp_bits: u32,
115+
) -> Line<bool> {
116+
// Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
117+
// Exponent field all ones, mantissa field all zeros.
let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
118+
// Ones over the exponent and mantissa fields, i.e. everything below the sign bit.
let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
119+
120+
let bits: Line<U> = Line::<U>::reinterpret(x);
121+
122+
// Drop the sign bit.
let abs_bits = bits & Line::new(U::cast_from(abs_mask));
123+
124+
abs_bits.greater_than(Line::new(U::cast_from(inf_bits)))
125+
}
126+
127+
// Same masking trick as the NaN detection following IEEE 754, but tests for
// equality: with the sign bit masked off, a value is infinite when its bits
// equal the infinity bit pattern (exponent all ones, mantissa all zeros).
//
// `mantissa_bits` and `exp_bits` are the comptime mantissa and exponent widths
// of `F`; `U` is the integer type the bits of `x` are reinterpreted as.
128+
#[cube]
129+
fn is_inf<F: Float, U: Int>(
130+
x: Line<F>,
131+
#[comptime] mantissa_bits: u32,
132+
#[comptime] exp_bits: u32,
133+
) -> Line<bool> {
134+
// Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
135+
// Exponent field all ones, mantissa field all zeros.
let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
136+
// Ones over the exponent and mantissa fields, i.e. everything below the sign bit.
let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
137+
138+
let bits: Line<U> = Line::<U>::reinterpret(x);
139+
140+
// Drop the sign bit so that +inf and -inf compare the same.
let abs_bits = bits & Line::new(U::cast_from(abs_mask));
141+
142+
abs_bits.equal(Line::new(U::cast_from(inf_bits)))
143+
}

0 commit comments

Comments
 (0)