Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub trait Float:
+ ArcTan2
+ Powf
+ Powi<i32>
+ Hypot
+ Rhypot
+ Sqrt
+ InverseSqrt
+ Round
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/frontend/element/float/typemap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ impl<const POS: u8> Radians for ElemExpand<POS> {}
impl<const POS: u8> ArcTan2 for ElemExpand<POS> {}
impl<const POS: u8> Powf for ElemExpand<POS> {}
impl<const POS: u8, I: CubePrimitive> Powi<I> for ElemExpand<POS> {}
// `Hypot` and `Rhypot` get empty impls for the `ElemExpand` placeholder type, in the same
// way as the neighbouring float operation traits in this list.
impl<const POS: u8> Hypot for ElemExpand<POS> {}
impl<const POS: u8> Rhypot for ElemExpand<POS> {}
impl<const POS: u8> Sqrt for ElemExpand<POS> {}
impl<const POS: u8> InverseSqrt for ElemExpand<POS> {}
impl<const POS: u8> Round for ElemExpand<POS> {}
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod options;
mod plane;
mod polyfills;
mod topology;
mod trigonometry;
mod validation;

pub use branch::{RangeExpand, SteppedRangeExpand, range, range_stepped};
Expand All @@ -30,6 +31,7 @@ pub use options::*;
pub use plane::*;
pub use polyfills::*;
pub use topology::*;
pub use trigonometry::*;
pub use validation::*;

pub use crate::{debug_print, debug_print_expand};
25 changes: 25 additions & 0 deletions crates/cubecl-core/src/frontend/operation/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,31 @@ impl_binary_func!(
f32,
f64
);

// Generates the `Hypot` frontend trait with its `hypot` function for the float types listed
// below, tied to the `Arithmetic::Hypot` IR operation.
// NOTE(review): described from the macro arguments only; the body of `impl_binary_func!`
// is not visible in this hunk — confirm against the macro definition.
impl_binary_func!(
Hypot,
hypot,
Arithmetic::Hypot,
f16,
bf16,
flex32,
tf32,
f32,
f64
);

// Generates the `Rhypot` frontend trait with its `rhypot` function for the float types
// listed below, tied to the `Arithmetic::Rhypot` IR operation.
// NOTE(review): described from the macro arguments only; the body of `impl_binary_func!`
// is not visible in this hunk — confirm against the macro definition.
impl_binary_func!(
Rhypot,
rhypot,
Arithmetic::Rhypot,
f16,
bf16,
flex32,
tf32,
f32,
f64
);

impl_binary_func!(
ArcTan2,
atan2,
Expand Down
62 changes: 62 additions & 0 deletions crates/cubecl-core/src/frontend/trigonometry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use cubecl_ir::{ExpandElement, Variable};

use crate::prelude::*;
use crate::{self as cubecl};

/// Returns the Euclidean norm `sqrt(lhs² + rhs²)` for each pair of lanes.
///
/// The larger magnitude is factored out of the square root, so the value actually evaluated is
/// `max * sqrt(1 + (min / max)²)`. The squared term is a ratio that never exceeds one, which
/// avoids the overflow and underflow a direct `x² + y²` can run into.
#[cube]
pub fn hypot<F: Float>(lhs: Line<F>, rhs: Line<F>) -> Line<F> {
    let one = Line::empty(lhs.size()).fill(F::from_int(1));
    let abs_lhs = Abs::abs(lhs);
    let abs_rhs = Abs::abs(rhs);
    let larger = Max::max(abs_lhs, abs_rhs);
    // A zero maximum means both inputs are zero; divide by one instead so the ratio is zero
    // rather than the result of `0 / 0`.
    let zero = Line::empty(lhs.size()).fill(F::from_int(0));
    let larger_is_zero = larger.equal(zero);
    let divisor = select_many(larger_is_zero, one, larger);
    let smaller = Min::min(abs_lhs, abs_rhs);
    let ratio = smaller / divisor;

    larger * Sqrt::sqrt(fma(ratio, ratio, one))
}

/// Expands the `hypot` cube function into `scope` for the operands `lhs` and `rhs`, then
/// assigns the expanded result to `out`.
///
/// NOTE(review): presumably the lowering entry point for targets that lack a native hypot
/// instruction — confirm against the callers, which are not visible here.
#[allow(missing_docs)]
pub fn expand_hypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
// Register the storage type of `lhs` for the `FloatExpand<0>` placeholder before the
// generic expansion below is instantiated with that placeholder.
scope.register_type::<FloatExpand<0>>(lhs.ty.storage_type());
let res = hypot::expand::<FloatExpand<0>>(
scope,
ExpandElement::Plain(lhs).into(),
ExpandElement::Plain(rhs).into(),
);
// Write the expanded value into the caller-provided output variable.
assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
}

/// Returns the reciprocal Euclidean norm `1 / sqrt(lhs² + rhs²)` for each pair of lanes.
///
/// The larger magnitude is factored out, so the value actually evaluated is
/// `inverse_sqrt(1 + (min / max)²) / max`. The squared term is a ratio that never exceeds one,
/// which avoids the overflow and underflow a direct `x² + y²` can run into.
#[cube]
pub fn rhypot<F: Float>(lhs: Line<F>, rhs: Line<F>) -> Line<F> {
    let one = Line::empty(lhs.size()).fill(F::from_int(1));
    let abs_lhs = Abs::abs(lhs);
    let abs_rhs = Abs::abs(rhs);
    let larger = Max::max(abs_lhs, abs_rhs);
    // Only the ratio uses the guarded divisor; the final division still uses the raw maximum,
    // so two zero inputs end in a division by zero.
    let zero = Line::empty(lhs.size()).fill(F::from_int(0));
    let larger_is_zero = larger.equal(zero);
    let divisor = select_many(larger_is_zero, one, larger);
    let smaller = Min::min(abs_lhs, abs_rhs);
    let ratio = smaller / divisor;

    InverseSqrt::inverse_sqrt(fma(ratio, ratio, one)) / larger
}

/// Expands the `rhypot` cube function into `scope` for the operands `lhs` and `rhs`, then
/// assigns the expanded result to `out`.
///
/// NOTE(review): presumably the lowering entry point for targets that lack a native rhypot
/// instruction — confirm against the callers, which are not visible here.
#[allow(missing_docs)]
pub fn expand_rhypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
// Register the storage type of `lhs` for the `FloatExpand<0>` placeholder before the
// generic expansion below is instantiated with that placeholder.
scope.register_type::<FloatExpand<0>>(lhs.ty.storage_type());
let res = rhypot::expand::<FloatExpand<0>>(
scope,
ExpandElement::Plain(lhs).into(),
ExpandElement::Plain(rhs).into(),
);
// Write the expanded value into the caller-provided output variable.
assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
}
68 changes: 67 additions & 1 deletion crates/cubecl-core/src/runtime_tests/binary.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(clippy::approx_constant)]

use core::f32;

use std::{fmt::Display, sync::LazyLock};

use crate::{self as cubecl, as_type};
Expand Down Expand Up @@ -32,7 +34,11 @@ pub(crate) fn assert_equals_approx<
// account for lower precision at higher values
let allowed_error = F::new((epsilon * e.to_f32().unwrap().abs()).max(epsilon));
assert!(
(*a - *e).abs() < allowed_error || (a.is_nan() && e.is_nan()),
(*a - *e).abs() < allowed_error
|| (a.is_nan() && e.is_nan())
|| (a.is_infinite()
&& e.is_infinite()
&& a.is_sign_positive() == e.is_sign_positive()),
"Values differ more than epsilon: actual={}, expected={}, difference={}, epsilon={}
index: {}
actual: {:?}
Expand Down Expand Up @@ -186,6 +192,64 @@ test_binary_impl!(
]
);

// Runtime cases for `F::hypot` at vectorization factors 1, 2 and 4.
// The inputs cover a zero operand on either side, both operands zero, and negative operands.
test_binary_impl!(
test_hypot,
F,
F::hypot,
[
{
// Scalar case: 3-4-5 triple, one zero operand on each side, and (0, 0) -> 0.
input_vectorization: 1,
out_vectorization: 1,
lhs: as_type![F: 3., 0., 5., 0.],
rhs: as_type![F: 4., 5., 0., 0.],
expected: as_type![F: 5., 5., 5., 0.]
},
{
// Adds the 8-15-17 triple.
input_vectorization: 2,
out_vectorization: 2,
lhs: as_type![F: 3., 0., 5., 8.],
rhs: as_type![F: 4., 5., 0., 15.],
expected: as_type![F: 5., 5., 5., 17.]
},
{
// Negative operands must give the same non-negative results.
input_vectorization: 4,
out_vectorization: 4,
lhs: as_type![F: -3., 0., -5., -8.],
rhs: as_type![F: -4., -5., 0., 15.],
expected: as_type![F: 5., 5., 5., 17.]
}
]
);

// Runtime cases for `F::rhypot` at vectorization factors 1, 2 and 4.
// The inputs cover a zero operand on either side, both operands zero, and negative operands.
test_binary_impl!(
test_rhypot,
F,
F::rhypot,
[
{
// (0, 0) expects positive infinity; the expected slice is spelled out with `F::new`
// and `F::INFINITY` here instead of `as_type!`.
input_vectorization: 1,
out_vectorization: 1,
lhs: as_type![F: 3., 0., 5., 0.],
rhs: as_type![F: 4., 5., 0., 0.],
expected: &[F::new(0.2), F::new(0.2), F::new(0.2), F::INFINITY]
},
{
// (0.3, 0.4) has hypotenuse 0.5, so the reciprocal is 2.
input_vectorization: 2,
out_vectorization: 2,
lhs: as_type![F: 3., 0., 5., 0.3],
rhs: as_type![F: 4., 5., 0., 0.4],
expected: as_type![F: 0.2, 0.2, 0.2, 2.]
},
{
// Negative operands must give the same positive results.
input_vectorization: 4,
out_vectorization: 4,
lhs: as_type![F: 0., 0., -5., -0.3],
rhs: as_type![F: -1., -5., 0., -0.4],
expected: as_type![F: 1., 0.2, 0.2, 2.]
}
]
);

#[cube(launch_unchecked)]
fn test_powi_kernel<F: Float>(
lhs: &Array<Line<F>>,
Expand Down Expand Up @@ -356,6 +420,8 @@ macro_rules! testgen_binary {

add_test!(test_dot);
add_test!(test_powf);
add_test!(test_hypot);
add_test!(test_rhypot);
add_test!(test_powi);
add_test!(test_atan2);
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/runtime_tests/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,8 @@ pub fn test_simple_1_expected() -> Vec<f32> {
// let lhs: Vec<f16> = (0..64).map(|i| f16::from_f32(i as f32)).collect();
// let rhs: Vec<f16> = (0..64).map(|i| f16::from_f32((i % 8) as f32)).collect();

// let lhs = client.create(f16::as_bytes(&lhs));
// let rhs = client.create(f16::as_bytes(&rhs));
// let lhs = client.create_from_slice(f16::as_bytes(&lhs));
// let rhs = client.create_from_slice(f16::as_bytes(&rhs));
// let out = client.empty(core::mem::size_of::<f16>() * 64);

// unsafe {
Expand Down
6 changes: 4 additions & 2 deletions crates/cubecl-core/src/runtime_tests/unary.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#![allow(clippy::approx_constant)]

use std::f32::consts::PI;
use std::fmt::Display;
use core::f32;
use core::f32::consts::PI;

use core::fmt::Display;

use crate::{self as cubecl, as_type};

Expand Down
6 changes: 6 additions & 0 deletions crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,12 @@ impl<D: Dialect> CppCompiler<D> {
gpu::Arithmetic::Powi(op) => {
instructions.push(Instruction::Powi(self.compile_binary(op, out)))
}
gpu::Arithmetic::Hypot(op) => {
instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
}
gpu::Arithmetic::Rhypot(op) => {
instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
}
gpu::Arithmetic::Sqrt(op) => {
let op = self.compile_unary(op, out);
instructions.push(self.select_fast_float(
Expand Down
107 changes: 106 additions & 1 deletion crates/cubecl-cpp/src/shared/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ impl<D: Dialect> Binary<D> for Powi {
f.write_str("};\n")
}
}

pub struct ArcTan2;

impl<D: Dialect> Binary<D> for ArcTan2 {
Expand Down Expand Up @@ -370,6 +369,112 @@ impl<D: Dialect> Binary<D> for ArcTan2 {
}
}

/// Binary instruction that emits the dialect's hypot call.
pub struct Hypot;

impl<D: Dialect> Binary<D> for Hypot {
    /// Writes the scalar hypot expression for `lhs` and `rhs` into `f`.
    ///
    /// Hypot doesn't support half and no half equivalent exists, so for `F16`, `F16x2`, `BF16`
    /// and `BF16x2` both operands are wrapped in `float(...)`, the instruction is emitted for
    /// `Elem::F32`, and the result is wrapped in a conversion back to the original element.
    fn format_scalar<Lhs, Rhs>(
        f: &mut Formatter<'_>,
        lhs: Lhs,
        rhs: Rhs,
        item: Item<D>,
    ) -> std::fmt::Result
    where
        Lhs: Component<D>,
        Rhs: Component<D>,
    {
        let elem = item.elem;
        let (lhs, rhs) = (lhs.to_string(), rhs.to_string());

        let is_half = matches!(elem, Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2);
        if !is_half {
            return D::compile_instruction_hypot(f, &lhs, &rhs, elem);
        }

        let wide_lhs = format!("float({lhs})");
        let wide_rhs = format!("float({rhs})");
        write!(f, "{elem}(")?;
        D::compile_instruction_hypot(f, &wide_lhs, &wide_rhs, Elem::F32)?;
        f.write_str(")")
    }

    /// Writes `out` as a brace initializer with one scalar hypot expression per lane.
    ///
    /// Always unrolled lane by lane because hypot doesn't support half and no half equivalent
    /// exists; every element, including the last, is followed by `", "`.
    fn unroll_vec(
        f: &mut Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let item_out = out.item();
        let lanes = item_out.vectorization;

        writeln!(f, "{} = {item_out}{{", out.fmt_left())?;
        (0..lanes).try_for_each(|lane| {
            Self::format_scalar(f, lhs.index(lane), rhs.index(lane), item_out)?;
            f.write_str(", ")
        })?;

        f.write_str("};\n")
    }
}

/// Binary instruction that emits the dialect's reciprocal-hypot call.
pub struct Rhypot;

impl<D: Dialect> Binary<D> for Rhypot {
    /// Writes the scalar rhypot expression for `lhs` and `rhs` into `f`.
    ///
    /// Rhypot doesn't support half and no half equivalent exists, so the half-precision
    /// element types compute in `float` and convert the result back to the element type.
    fn format_scalar<Lhs, Rhs>(
        f: &mut Formatter<'_>,
        lhs: Lhs,
        rhs: Rhs,
        item: Item<D>,
    ) -> std::fmt::Result
    where
        Lhs: Component<D>,
        Rhs: Component<D>,
    {
        let elem = item.elem;
        let lhs_text = lhs.to_string();
        let rhs_text = rhs.to_string();

        if let Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 = elem {
            // Promote both operands to float, then wrap the call in the element constructor.
            let lhs_float = format!("float({lhs_text})");
            let rhs_float = format!("float({rhs_text})");
            write!(f, "{elem}(")?;
            D::compile_instruction_rhypot(f, &lhs_float, &rhs_float, Elem::F32)?;
            write!(f, ")")
        } else {
            D::compile_instruction_rhypot(f, &lhs_text, &rhs_text, elem)
        }
    }

    /// Writes `out` as a brace initializer with one scalar rhypot expression per lane.
    ///
    /// Always unrolled lane by lane because rhypot doesn't support half and no half equivalent
    /// exists; every element, including the last, is followed by `", "`.
    fn unroll_vec(
        f: &mut Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let item_out = out.item();
        let lane_count = out.item().vectorization;
        let target = out.fmt_left();

        writeln!(f, "{target} = {item_out}{{")?;
        for lane in 0..lane_count {
            let lhs_lane = lhs.index(lane);
            let rhs_lane = rhs.index(lane);
            Self::format_scalar(f, lhs_lane, rhs_lane, item_out)?;
            f.write_str(", ")?;
        }

        f.write_str("};\n")
    }
}

pub struct Max;

impl<D: Dialect> Binary<D> for Max {
Expand Down
Loading