Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub trait Float:
+ ArcTan2
+ Powf
+ Powi<i32>
+ Hypot
+ Rhypot
+ Sqrt
+ InverseSqrt
+ Round
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/frontend/element/float/typemap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ impl<const POS: u8> Radians for ElemExpand<POS> {}
impl<const POS: u8> ArcTan2 for ElemExpand<POS> {}
impl<const POS: u8> Powf for ElemExpand<POS> {}
impl<const POS: u8, I: CubePrimitive> Powi<I> for ElemExpand<POS> {}
impl<const POS: u8> Hypot for ElemExpand<POS> {}
impl<const POS: u8> Rhypot for ElemExpand<POS> {}
impl<const POS: u8> Sqrt for ElemExpand<POS> {}
impl<const POS: u8> InverseSqrt for ElemExpand<POS> {}
impl<const POS: u8> Round for ElemExpand<POS> {}
Expand Down
25 changes: 25 additions & 0 deletions crates/cubecl-core/src/frontend/operation/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,31 @@ impl_binary_func!(
f32,
f64
);

impl_binary_func!(
Hypot,
hypot,
Arithmetic::Hypot,
f16,
bf16,
flex32,
tf32,
f32,
f64
);

impl_binary_func!(
Rhypot,
rhypot,
Arithmetic::Rhypot,
f16,
bf16,
flex32,
tf32,
f32,
f64
);

impl_binary_func!(
ArcTan2,
atan2,
Expand Down
68 changes: 67 additions & 1 deletion crates/cubecl-core/src/runtime_tests/binary.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(clippy::approx_constant)]

use core::f32;

use std::{fmt::Display, sync::LazyLock};

use crate::{self as cubecl, as_type};
Expand Down Expand Up @@ -32,7 +34,11 @@ pub(crate) fn assert_equals_approx<
// account for lower precision at higher values
let allowed_error = F::new((epsilon * e.to_f32().unwrap().abs()).max(epsilon));
assert!(
(*a - *e).abs() < allowed_error || (a.is_nan() && e.is_nan()),
(*a - *e).abs() < allowed_error
|| (a.is_nan() && e.is_nan())
|| (a.is_infinite()
&& e.is_infinite()
&& a.is_sign_positive() == e.is_sign_positive()),
"Values differ more than epsilon: actual={}, expected={}, difference={}, epsilon={}
index: {}
actual: {:?}
Expand Down Expand Up @@ -186,6 +192,64 @@ test_binary_impl!(
]
);

test_binary_impl!(
test_hypot,
F,
F::hypot,
[
{
input_vectorization: 1,
out_vectorization: 1,
lhs: as_type![F: 3., 0., 5., 0.],
rhs: as_type![F: 4., 5., 0., 0.],
expected: as_type![F: 5., 5., 5., 0.]
},
{
input_vectorization: 2,
out_vectorization: 2,
lhs: as_type![F: 3., 0., 5., 8.],
rhs: as_type![F: 4., 5., 0., 15.],
expected: as_type![F: 5., 5., 5., 17.]
},
{
input_vectorization: 4,
out_vectorization: 4,
lhs: as_type![F: -3., 0., -5., -8.],
rhs: as_type![F: -4., -5., 0., 15.],
expected: as_type![F: 5., 5., 5., 17.]
}
]
);

test_binary_impl!(
test_rhypot,
F,
F::rhypot,
[
{
input_vectorization: 1,
out_vectorization: 1,
lhs: as_type![F: 3., 0., 5., 0.],
rhs: as_type![F: 4., 5., 0., 0.],
expected: &[F::new(0.2), F::new(0.2), F::new(0.2), F::INFINITY]
},
{
input_vectorization: 2,
out_vectorization: 2,
lhs: as_type![F: 3., 0., 5., 0.3],
rhs: as_type![F: 4., 5., 0., 0.4],
expected: as_type![F: 0.2, 0.2, 0.2, 2.]
},
{
input_vectorization: 4,
out_vectorization: 4,
lhs: as_type![F: 0., 0., -5., -0.3],
rhs: as_type![F: -1., -5., 0., -0.4],
expected: as_type![F: 1., 0.2, 0.2, 2.]
}
]
);

#[cube(launch_unchecked)]
fn test_powi_kernel<F: Float>(
lhs: &Array<Line<F>>,
Expand Down Expand Up @@ -356,6 +420,8 @@ macro_rules! testgen_binary {

add_test!(test_dot);
add_test!(test_powf);
add_test!(test_hypot);
add_test!(test_rhypot);
add_test!(test_powi);
add_test!(test_atan2);
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/runtime_tests/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,8 @@ pub fn test_simple_1_expected() -> Vec<f32> {
// let lhs: Vec<f16> = (0..64).map(|i| f16::from_f32(i as f32)).collect();
// let rhs: Vec<f16> = (0..64).map(|i| f16::from_f32((i % 8) as f32)).collect();

// let lhs = client.create(f16::as_bytes(&lhs));
// let rhs = client.create(f16::as_bytes(&rhs));
// let lhs = client.create_from_slice(f16::as_bytes(&lhs));
// let rhs = client.create_from_slice(f16::as_bytes(&rhs));
// let out = client.empty(core::mem::size_of::<f16>() * 64);

// unsafe {
Expand Down
6 changes: 4 additions & 2 deletions crates/cubecl-core/src/runtime_tests/unary.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#![allow(clippy::approx_constant)]

use std::f32::consts::PI;
use std::fmt::Display;
use core::f32;
use core::f32::consts::PI;

use core::fmt::Display;

use crate::{self as cubecl, as_type};

Expand Down
6 changes: 6 additions & 0 deletions crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,12 @@ impl<D: Dialect> CppCompiler<D> {
gpu::Arithmetic::Powi(op) => {
instructions.push(Instruction::Powi(self.compile_binary(op, out)))
}
gpu::Arithmetic::Hypot(op) => {
instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
}
gpu::Arithmetic::Rhypot(op) => {
instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
}
gpu::Arithmetic::Sqrt(op) => {
let op = self.compile_unary(op, out);
instructions.push(self.select_fast_float(
Expand Down
107 changes: 106 additions & 1 deletion crates/cubecl-cpp/src/shared/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ impl<D: Dialect> Binary<D> for Powi {
f.write_str("};\n")
}
}

pub struct ArcTan2;

impl<D: Dialect> Binary<D> for ArcTan2 {
Expand Down Expand Up @@ -370,6 +369,112 @@ impl<D: Dialect> Binary<D> for ArcTan2 {
}
}

pub struct Hypot;

impl<D: Dialect> Binary<D> for Hypot {
// Hypot doesn't support half and no half equivalent exists
fn format_scalar<Lhs, Rhs>(
f: &mut Formatter<'_>,
lhs: Lhs,
rhs: Rhs,
item: Item<D>,
) -> std::fmt::Result
where
Lhs: Component<D>,
Rhs: Component<D>,
{
let elem = item.elem;
let lhs = lhs.to_string();
let rhs = rhs.to_string();
match elem {
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
let lhs = format!("float({lhs})");
let rhs = format!("float({rhs})");
write!(f, "{elem}(")?;
D::compile_instruction_hypot(f, &lhs, &rhs, Elem::F32)?;
write!(f, ")")
}
_ => D::compile_instruction_hypot(f, &lhs, &rhs, elem),
}
}

// Hypot doesn't support half and no half equivalent exists
fn unroll_vec(
f: &mut Formatter<'_>,
lhs: &Variable<D>,
rhs: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
let item_out = out.item();
let index = out.item().vectorization;

let out = out.fmt_left();
writeln!(f, "{out} = {item_out}{{")?;
for i in 0..index {
let lhsi = lhs.index(i);
let rhsi = rhs.index(i);

Self::format_scalar(f, lhsi, rhsi, item_out)?;
f.write_str(", ")?;
}

f.write_str("};\n")
}
}

pub struct Rhypot;

impl<D: Dialect> Binary<D> for Rhypot {
// Rhypot doesn't support half and no half equivalent exists
fn format_scalar<Lhs, Rhs>(
f: &mut Formatter<'_>,
lhs: Lhs,
rhs: Rhs,
item: Item<D>,
) -> std::fmt::Result
where
Lhs: Component<D>,
Rhs: Component<D>,
{
let elem = item.elem;
let lhs = lhs.to_string();
let rhs = rhs.to_string();
match elem {
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
let lhs = format!("float({lhs})");
let rhs = format!("float({rhs})");
write!(f, "{elem}(")?;
D::compile_instruction_rhypot(f, &lhs, &rhs, Elem::F32)?;
write!(f, ")")
}
_ => D::compile_instruction_rhypot(f, &lhs, &rhs, elem),
}
}

// Rhypot doesn't support half and no half equivalent exists
fn unroll_vec(
f: &mut Formatter<'_>,
lhs: &Variable<D>,
rhs: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
let item_out = out.item();
let index = out.item().vectorization;

let out = out.fmt_left();
writeln!(f, "{out} = {item_out}{{")?;
for i in 0..index {
let lhsi = lhs.index(i);
let rhsi = rhs.index(i);

Self::format_scalar(f, lhsi, rhsi, item_out)?;
f.write_str(", ")?;
}

f.write_str("};\n")
}
}

pub struct Max;

impl<D: Dialect> Binary<D> for Max {
Expand Down
26 changes: 26 additions & 0 deletions crates/cubecl-cpp/src/shared/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,32 @@ pub trait DialectInstructions<D: Dialect> {
}
}

fn compile_instruction_hypot(
f: &mut std::fmt::Formatter<'_>,
lhs: &str,
rhs: &str,
elem: Elem<D>,
) -> std::fmt::Result {
match elem {
Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"),
Elem::F64 => write!(f, "hypot({lhs}, {rhs})"),
_ => panic!("Unsupported type for hypot"),
}
}

fn compile_instruction_rhypot(
f: &mut std::fmt::Formatter<'_>,
lhs: &str,
rhs: &str,
elem: Elem<D>,
) -> std::fmt::Result {
match elem {
Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"),
Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"),
_ => panic!("Unsupported type for hypot"),
}
}

fn compile_instruction_half_function_name_prefix() -> &'static str {
"h"
}
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-cpp/src/shared/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ pub enum Instruction<D: Dialect> {
Powf(BinaryInstruction<D>),
FastPowf(BinaryInstruction<D>),
Powi(BinaryInstruction<D>),
Hypot(BinaryInstruction<D>),
Rhypot(BinaryInstruction<D>),
Sqrt(UnaryInstruction<D>),
FastSqrt(UnaryInstruction<D>),
InverseSqrt(UnaryInstruction<D>),
Expand Down Expand Up @@ -564,6 +566,8 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
Expand Down
1 change: 0 additions & 1 deletion crates/cubecl-cpp/src/shared/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ function!(Cos, "cos");
function!(Tan, "tan", false);
function!(Sinh, "sinh", false);
function!(Cosh, "cosh", false);
// Tanh is separate below, idk why
function!(ArcCos, "acos", false);
function!(ArcSin, "asin", false);
function!(ArcTan, "atan", false);
Expand Down
Loading