diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 54d8884aa..da09849a6 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -40,6 +40,8 @@ pub trait Float: + ArcTan2 + Powf + Powi + + Hypot + + Rhypot + Sqrt + InverseSqrt + Round diff --git a/crates/cubecl-core/src/frontend/element/float/typemap.rs b/crates/cubecl-core/src/frontend/element/float/typemap.rs index 5bdb234b1..4ba2cd820 100644 --- a/crates/cubecl-core/src/frontend/element/float/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -262,6 +262,8 @@ impl Radians for ElemExpand {} impl ArcTan2 for ElemExpand {} impl Powf for ElemExpand {} impl Powi for ElemExpand {} +impl Hypot for ElemExpand {} +impl Rhypot for ElemExpand {} impl Sqrt for ElemExpand {} impl InverseSqrt for ElemExpand {} impl Round for ElemExpand {} diff --git a/crates/cubecl-core/src/frontend/mod.rs b/crates/cubecl-core/src/frontend/mod.rs index cefd84180..f51152dc5 100644 --- a/crates/cubecl-core/src/frontend/mod.rs +++ b/crates/cubecl-core/src/frontend/mod.rs @@ -16,6 +16,7 @@ mod options; mod plane; mod polyfills; mod topology; +mod trigonometry; mod validation; pub use branch::{RangeExpand, SteppedRangeExpand, range, range_stepped}; @@ -30,6 +31,7 @@ pub use options::*; pub use plane::*; pub use polyfills::*; pub use topology::*; +pub use trigonometry::*; pub use validation::*; pub use crate::{debug_print, debug_print_expand}; diff --git a/crates/cubecl-core/src/frontend/operation/binary.rs b/crates/cubecl-core/src/frontend/operation/binary.rs index 53be63601..caaf09306 100644 --- a/crates/cubecl-core/src/frontend/operation/binary.rs +++ b/crates/cubecl-core/src/frontend/operation/binary.rs @@ -252,6 +252,31 @@ impl_binary_func!( f32, f64 ); + +impl_binary_func!( + Hypot, + hypot, + Arithmetic::Hypot, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); + +impl_binary_func!( + Rhypot, + rhypot, + Arithmetic::Rhypot, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); + impl_binary_func!( ArcTan2, atan2, diff --git a/crates/cubecl-core/src/frontend/trigonometry.rs b/crates/cubecl-core/src/frontend/trigonometry.rs new file mode 100644 index 000000000..ee2ee069b --- /dev/null +++ b/crates/cubecl-core/src/frontend/trigonometry.rs @@ -0,0 +1,62 @@ +use cubecl_ir::{ExpandElement, Variable}; + +use crate::prelude::*; +use crate::{self as cubecl}; + +/// Computes the hypotenuse of a right triangle given the lengths of the other two sides. +/// +/// This function computes `sqrt(x² + y²)` in a numerically stable way that avoids +/// overflow and underflow issues. +#[cube] +pub fn hypot(lhs: Line, rhs: Line) -> Line { + let one = Line::empty(lhs.size()).fill(F::from_int(1)); + let a = Abs::abs(lhs); + let b = Abs::abs(rhs); + let max_val = Max::max(a, b); + let max_val_is_zero = max_val.equal(Line::empty(lhs.size()).fill(F::from_int(0))); + let max_val_safe = select_many(max_val_is_zero, one, max_val); + let min_val = Min::min(a, b); + let t = min_val / max_val_safe; + + max_val * Sqrt::sqrt(fma(t, t, one)) +} + +#[allow(missing_docs)] +pub fn expand_hypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) { + scope.register_type::>(lhs.ty.storage_type()); + let res = hypot::expand::>( + scope, + ExpandElement::Plain(lhs).into(), + ExpandElement::Plain(rhs).into(), + ); + assign::expand_no_check(scope, res, ExpandElement::Plain(out).into()); +} + +/// Computes the reciprocal of the hypotenuse of a right triangle given the lengths of the other two sides. +/// +/// This function computes `1 / sqrt(x² + y²)` in a numerically stable way that avoids +/// overflow and underflow issues. +#[cube] +pub fn rhypot(lhs: Line, rhs: Line) -> Line { + let one = Line::empty(lhs.size()).fill(F::from_int(1)); + let a = Abs::abs(lhs); + let b = Abs::abs(rhs); + let max_val = Max::max(a, b); + let max_val_is_zero = max_val.equal(Line::empty(lhs.size()).fill(F::from_int(0))); + let max_val_safe = select_many(max_val_is_zero, one, max_val); + let min_val = Min::min(a, b); + let t = min_val / max_val_safe; + + InverseSqrt::inverse_sqrt(fma(t, t, one)) / max_val +} + +#[allow(missing_docs)] +pub fn expand_rhypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) { + scope.register_type::>(lhs.ty.storage_type()); + let res = rhypot::expand::>( + scope, + ExpandElement::Plain(lhs).into(), + ExpandElement::Plain(rhs).into(), + ); + assign::expand_no_check(scope, res, ExpandElement::Plain(out).into()); +} diff --git a/crates/cubecl-core/src/runtime_tests/binary.rs b/crates/cubecl-core/src/runtime_tests/binary.rs index dabd338ce..0143e1e92 100644 --- a/crates/cubecl-core/src/runtime_tests/binary.rs +++ b/crates/cubecl-core/src/runtime_tests/binary.rs @@ -1,5 +1,7 @@ #![allow(clippy::approx_constant)] +use core::f32; + use std::{fmt::Display, sync::LazyLock}; use crate::{self as cubecl, as_type}; @@ -32,7 +34,11 @@ pub(crate) fn assert_equals_approx< // account for lower precision at higher values let allowed_error = F::new((epsilon * e.to_f32().unwrap().abs()).max(epsilon)); assert!( - (*a - *e).abs() < allowed_error || (a.is_nan() && e.is_nan()), + (*a - *e).abs() < allowed_error + || (a.is_nan() && e.is_nan()) + || (a.is_infinite() + && e.is_infinite() + && a.is_sign_positive() == e.is_sign_positive()), "Values differ more than epsilon: actual={}, expected={}, difference={}, epsilon={} index: {} actual: {:?} @@ -186,6 +192,64 @@ test_binary_impl!( ] ); +test_binary_impl!( + test_hypot, + F, + F::hypot, + [ + { + input_vectorization: 1, + out_vectorization: 1, + lhs: as_type![F: 3., 0., 5., 0.], + rhs: as_type![F: 4., 5., 0., 0.], + expected: as_type![F: 5., 5., 5., 0.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + lhs: as_type![F: 3., 0., 5., 8.], + rhs: as_type![F: 4., 5., 0., 15.], + expected: as_type![F: 5., 5., 5., 17.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + lhs: as_type![F: -3., 0., -5., -8.], + rhs: as_type![F: -4., -5., 0., 15.], + expected: as_type![F: 5., 5., 5., 17.] + } + ] +); + +test_binary_impl!( + test_rhypot, + F, + F::rhypot, + [ + { + input_vectorization: 1, + out_vectorization: 1, + lhs: as_type![F: 3., 0., 5., 0.], + rhs: as_type![F: 4., 5., 0., 0.], + expected: &[F::new(0.2), F::new(0.2), F::new(0.2), F::INFINITY] + }, + { + input_vectorization: 2, + out_vectorization: 2, + lhs: as_type![F: 3., 0., 5., 0.3], + rhs: as_type![F: 4., 5., 0., 0.4], + expected: as_type![F: 0.2, 0.2, 0.2, 2.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + lhs: as_type![F: 0., 0., -5., -0.3], + rhs: as_type![F: -1., -5., 0., -0.4], + expected: as_type![F: 1., 0.2, 0.2, 2.] + } + ] +); + #[cube(launch_unchecked)] fn test_powi_kernel( lhs: &Array>, @@ -356,6 +420,8 @@ macro_rules! testgen_binary { add_test!(test_dot); add_test!(test_powf); + add_test!(test_hypot); + add_test!(test_rhypot); add_test!(test_powi); add_test!(test_atan2); } diff --git a/crates/cubecl-core/src/runtime_tests/cmma.rs b/crates/cubecl-core/src/runtime_tests/cmma.rs index f97776f1f..ca84ef2a2 100644 --- a/crates/cubecl-core/src/runtime_tests/cmma.rs +++ b/crates/cubecl-core/src/runtime_tests/cmma.rs @@ -480,8 +480,8 @@ pub fn test_simple_1_expected() -> Vec { // let lhs: Vec = (0..64).map(|i| f16::from_f32(i as f32)).collect(); // let rhs: Vec = (0..64).map(|i| f16::from_f32((i % 8) as f32)).collect(); -// let lhs = client.create(f16::as_bytes(&lhs)); -// let rhs = client.create(f16::as_bytes(&rhs)); +// let lhs = client.create_from_slice(f16::as_bytes(&lhs)); +// let rhs = client.create_from_slice(f16::as_bytes(&rhs)); // let out = client.empty(core::mem::size_of::() * 64); // unsafe { diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index bd5682e37..e9125b66b 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -1,7 +1,9 @@ #![allow(clippy::approx_constant)] -use std::f32::consts::PI; -use std::fmt::Display; +use core::f32; +use core::f32::consts::PI; + +use core::fmt::Display; use crate::{self as cubecl, as_type}; diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 5420ab88e..a532dcce7 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -1250,6 +1250,12 @@ impl CppCompiler { gpu::Arithmetic::Powi(op) => { instructions.push(Instruction::Powi(self.compile_binary(op, out))) } + gpu::Arithmetic::Hypot(op) => { + instructions.push(Instruction::Hypot(self.compile_binary(op, out))) + } + gpu::Arithmetic::Rhypot(op) => { + instructions.push(Instruction::Rhypot(self.compile_binary(op, out))) + } gpu::Arithmetic::Sqrt(op) => { let op = self.compile_unary(op, out); instructions.push(self.select_fast_float( diff --git a/crates/cubecl-cpp/src/shared/binary.rs b/crates/cubecl-cpp/src/shared/binary.rs index b7648afc6..a21a8b8da 100644 --- a/crates/cubecl-cpp/src/shared/binary.rs +++ b/crates/cubecl-cpp/src/shared/binary.rs @@ -324,7 +324,6 @@ impl Binary for Powi { f.write_str("};\n") } } - pub struct ArcTan2; impl Binary for ArcTan2 { @@ -370,6 +369,112 @@ impl Binary for ArcTan2 { } } +pub struct Hypot; + +impl Binary for Hypot { + // Hypot doesn't support half and no half equivalent exists + fn format_scalar( + f: &mut Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + item: Item, + ) -> std::fmt::Result + where + Lhs: Component, + Rhs: Component, + { + let elem = item.elem; + let lhs = lhs.to_string(); + let rhs = rhs.to_string(); + match elem { + Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => { + let lhs = format!("float({lhs})"); + let rhs = format!("float({rhs})"); + write!(f, "{elem}(")?; + D::compile_instruction_hypot(f, &lhs, &rhs, Elem::F32)?; + write!(f, ")") + } + _ => D::compile_instruction_hypot(f, &lhs, &rhs, elem), + } + } + + // Hypot doesn't support half and no half equivalent exists + fn unroll_vec( + f: &mut Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + ) -> core::fmt::Result { + let item_out = out.item(); + let index = out.item().vectorization; + + let out = out.fmt_left(); + writeln!(f, "{out} = {item_out}{{")?; + for i in 0..index { + let lhsi = lhs.index(i); + let rhsi = rhs.index(i); + + Self::format_scalar(f, lhsi, rhsi, item_out)?; + f.write_str(", ")?; + } + + f.write_str("};\n") + } +} + +pub struct Rhypot; + +impl Binary for Rhypot { + // Rhypot doesn't support half and no half equivalent exists + fn format_scalar( + f: &mut Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + item: Item, + ) -> std::fmt::Result + where + Lhs: Component, + Rhs: Component, + { + let elem = item.elem; + let lhs = lhs.to_string(); + let rhs = rhs.to_string(); + match elem { + Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => { + let lhs = format!("float({lhs})"); + let rhs = format!("float({rhs})"); + write!(f, "{elem}(")?; + D::compile_instruction_rhypot(f, &lhs, &rhs, Elem::F32)?; + write!(f, ")") + } + _ => D::compile_instruction_rhypot(f, &lhs, &rhs, elem), + } + } + + // Rhypot doesn't support half and no half equivalent exists + fn unroll_vec( + f: &mut Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + ) -> core::fmt::Result { + let item_out = out.item(); + let index = out.item().vectorization; + + let out = out.fmt_left(); + writeln!(f, "{out} = {item_out}{{")?; + for i in 0..index { + let lhsi = lhs.index(i); + let rhsi = rhs.index(i); + + Self::format_scalar(f, lhsi, rhsi, item_out)?; + f.write_str(", ")?; + } + + f.write_str("};\n") + } +} + pub struct Max; impl Binary for Max { diff --git a/crates/cubecl-cpp/src/shared/dialect.rs b/crates/cubecl-cpp/src/shared/dialect.rs index 7870a576b..3691a26ad 100644 --- a/crates/cubecl-cpp/src/shared/dialect.rs +++ b/crates/cubecl-cpp/src/shared/dialect.rs @@ -652,6 +652,32 @@ pub trait DialectInstructions { } } + fn compile_instruction_hypot( + f: &mut std::fmt::Formatter<'_>, + lhs: &str, + rhs: &str, + elem: Elem, + ) -> std::fmt::Result { + match elem { + Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"), + Elem::F64 => write!(f, "hypot({lhs}, {rhs})"), + _ => panic!("Unsupported type for hypot"), + } + } + + fn compile_instruction_rhypot( + f: &mut std::fmt::Formatter<'_>, + lhs: &str, + rhs: &str, + elem: Elem, + ) -> std::fmt::Result { + match elem { + Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"), + Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"), + _ => panic!("Unsupported type for hypot"), + } + } + fn compile_instruction_half_function_name_prefix() -> &'static str { "h" } diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index d926c1cf1..ec70c1b62 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -187,6 +187,8 @@ pub enum Instruction { Powf(BinaryInstruction), FastPowf(BinaryInstruction), Powi(BinaryInstruction), + Hypot(BinaryInstruction), + Rhypot(BinaryInstruction), Sqrt(UnaryInstruction), FastSqrt(UnaryInstruction), InverseSqrt(UnaryInstruction), @@ -564,6 +566,8 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out), Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out), diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index 419e8d2a4..54c40c8da 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -155,7 +155,6 @@ function!(Cos, "cos"); function!(Tan, "tan", false); function!(Sinh, "sinh", false); function!(Cosh, "cosh", false); -// Tanh is separate below, idk why function!(ArcCos, "acos", false); function!(ArcSin, "asin", false); function!(ArcTan, "atan", false); diff --git a/crates/cubecl-cpu/src/compiler/mod.rs b/crates/cubecl-cpu/src/compiler/mod.rs index a5ea3f239..a3f4d49a7 100644 --- a/crates/cubecl-cpu/src/compiler/mod.rs +++ b/crates/cubecl-cpu/src/compiler/mod.rs @@ -24,7 +24,10 @@ use cubecl_core::{ use cubecl_opt::OptimizerBuilder; use mlir_engine::MlirEngine; -use crate::compiler::passes::erf_transform::ErfTransform; +use crate::compiler::passes::{ + erf_transform::ErfTransform, + trigonometries_transform::{HypotTransform, RhypotTransform}, +}; #[derive(Clone, Debug, Default)] pub struct MlirCompiler {} @@ -61,6 +64,8 @@ impl Compiler for MlirCompiler { dump_scope(&kernel.body, &kernel.options.kernel_name); let opt = OptimizerBuilder::default() .with_transformer(ErfTransform) + .with_transformer(HypotTransform) + .with_transformer(RhypotTransform) .with_processor(CheckedIoProcessor::new(mode)) .with_processor(SaturatingArithmeticProcessor::new(true)) .with_processor(PredicateProcessor) diff --git a/crates/cubecl-cpu/src/compiler/passes/mod.rs b/crates/cubecl-cpu/src/compiler/passes/mod.rs index b094dac0e..bd386e5d7 100644 --- a/crates/cubecl-cpu/src/compiler/passes/mod.rs +++ b/crates/cubecl-cpu/src/compiler/passes/mod.rs @@ -1,2 +1,3 @@ pub mod erf_transform; pub mod shared_memories; +pub mod trigonometries_transform; diff --git a/crates/cubecl-cpu/src/compiler/passes/trigonometries_transform.rs b/crates/cubecl-cpu/src/compiler/passes/trigonometries_transform.rs new file mode 100644 index 000000000..1f77db951 --- /dev/null +++ b/crates/cubecl-cpu/src/compiler/passes/trigonometries_transform.rs @@ -0,0 +1,37 @@ +use cubecl_core::{ + ir::{Arithmetic, Instruction, Operation, Scope}, + prelude::*, +}; +use cubecl_opt::{IrTransformer, TransformAction}; + +#[derive(Debug)] +pub(crate) struct HypotTransform; + +impl IrTransformer for HypotTransform { + fn maybe_transform(&self, scope: &mut Scope, inst: &Instruction) -> TransformAction { + match &inst.operation { + Operation::Arithmetic(Arithmetic::Hypot(op)) => { + let mut scope = scope.child(); + expand_hypot(&mut scope, op.lhs, op.rhs, inst.out.unwrap()); + TransformAction::Replace(scope.process([]).instructions) + } + _ => TransformAction::Ignore, + } + } +} + +#[derive(Debug)] +pub(crate) struct RhypotTransform; + +impl IrTransformer for RhypotTransform { + fn maybe_transform(&self, scope: &mut Scope, inst: &Instruction) -> TransformAction { + match &inst.operation { + Operation::Arithmetic(Arithmetic::Rhypot(op)) => { + let mut scope = scope.child(); + expand_rhypot(&mut scope, op.lhs, op.rhs, inst.out.unwrap()); + TransformAction::Replace(scope.process([]).instructions) + } + _ => TransformAction::Ignore, + } + } +} diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index ab456cf3a..c233af3cd 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -223,7 +223,7 @@ impl<'a> Visitor<'a> { self.insert_variable(out, operation); } Arithmetic::Erf(_) => { - unreachable!("Should have been transformed in primitive in a previous passe"); + unreachable!("Should have been transformed in primitive in a previous phase"); } Arithmetic::Exp(exp) => { let value = self.get_variable(exp.input); @@ -248,10 +248,8 @@ impl<'a> Visitor<'a> { let b = self.get_variable(fma.b); let c = self.get_variable(fma.c); - let result_type = fma.a.ty.to_type(self.context); - let result = self.append_operation_with_result(vector::fma( + let result = self.append_operation_with_result(llvm_ods::intr_fma( self.context, - result_type, a, b, c, @@ -446,6 +444,12 @@ impl<'a> Visitor<'a> { self.append_operation_with_result(arith::mulf(value, f, self.location)); self.insert_variable(out, result); } + Arithmetic::Hypot(_hypot) => { + unreachable!("Should have been transformed in primitive in a previous phase"); + } + Arithmetic::Rhypot(_rhypot) => { + unreachable!("Should have been transformed in primitive in a previous phase"); + } Arithmetic::Recip(recip) => { let value = self.get_variable(recip.input); let one = self.create_float_constant_from_item(recip.input.ty, 1.0); @@ -506,17 +510,14 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result); } - Arithmetic::InverseSqrt(sqrt) => { - let input = self.get_variable(sqrt.input); - let value = self.append_operation_with_result(llvm_ods::intr_sqrt( + Arithmetic::InverseSqrt(rsqrt) => { + let input = self.get_variable(rsqrt.input); + let output = self.append_operation_with_result(math_ods::rsqrt( self.context, input, self.location, )); - let one = self.create_float_constant_from_item(sqrt.input.ty, 1.0); - let recip = - self.append_operation_with_result(arith::divf(one, value, self.location)); - self.insert_variable(out, recip); + self.insert_variable(out, output); } Arithmetic::Sqrt(sqrt) => { let input = self.get_variable(sqrt.input); diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index f4ff66568..e502502a3 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -40,6 +40,8 @@ pub enum Arithmetic { ArcTan2(BinaryOperator), Powf(BinaryOperator), Powi(BinaryOperator), + Hypot(BinaryOperator), + Rhypot(BinaryOperator), Sqrt(UnaryOperator), InverseSqrt(UnaryOperator), Round(UnaryOperator), @@ -95,6 +97,8 @@ impl Display for Arithmetic { Arithmetic::ArcTan2(op) => write!(f, "{}.atan2({})", op.lhs, op.rhs), Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs), Arithmetic::Powi(op) => write!(f, "{}.powi({})", op.lhs, op.rhs), + Arithmetic::Hypot(op) => write!(f, "{}.hypot({})", op.lhs, op.rhs), + Arithmetic::Rhypot(op) => write!(f, "{}.rhypot({})", op.lhs, op.rhs), Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input), Arithmetic::InverseSqrt(op) => write!(f, "{}.inverse_sqrt()", op.input), Arithmetic::Round(op) => write!(f, "{}.round()", op.input), diff --git a/crates/cubecl-ir/src/processing.rs b/crates/cubecl-ir/src/processing.rs index 08dfe76b1..04ce26a59 100644 --- a/crates/cubecl-ir/src/processing.rs +++ b/crates/cubecl-ir/src/processing.rs @@ -156,6 +156,14 @@ impl ScopeProcessing { Arithmetic::Powi(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); } + Arithmetic::Hypot(op) => { + sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); + sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); + } + Arithmetic::Rhypot(op) => { + sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); + sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); + } Arithmetic::Sqrt(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 50e05a615..b98f5c13e 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -79,6 +79,8 @@ impl Optimizer { | Arithmetic::Div(binary_operator) | Arithmetic::Powf(binary_operator) | Arithmetic::Powi(binary_operator) + | Arithmetic::Hypot(binary_operator) + | Arithmetic::Rhypot(binary_operator) | Arithmetic::Modulo(binary_operator) | Arithmetic::Max(binary_operator) | Arithmetic::Min(binary_operator) diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index fc4bb31d4..a927915bf 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -504,7 +504,11 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option } }) } - Arithmetic::Erf(_) | Arithmetic::Magnitude(_) | Arithmetic::Normalize(_) => None, + Arithmetic::Erf(_) + | Arithmetic::Hypot(_) + | Arithmetic::Rhypot(_) + | Arithmetic::Magnitude(_) + | Arithmetic::Normalize(_) => None, } } diff --git a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index 1b93e4525..b74afbc06 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -535,6 +535,12 @@ impl SpirvCompiler { b.select(ty, Some(out), is_zero, even, sel1).unwrap(); }) } + Arithmetic::Hypot(_op) => { + unreachable!("Replaced by transformer"); + } + Arithmetic::Rhypot(_op) => { + unreachable!("Replaced by transformer"); + } Arithmetic::Sqrt(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { b.declare_math_mode(modes, out); diff --git a/crates/cubecl-spirv/src/compiler.rs b/crates/cubecl-spirv/src/compiler.rs index 6a8aedce1..198164746 100644 --- a/crates/cubecl-spirv/src/compiler.rs +++ b/crates/cubecl-spirv/src/compiler.rs @@ -34,7 +34,7 @@ use crate::{ item::Item, lookups::LookupTables, target::{GLCompute, SpirvTarget}, - transformers::{BitwiseTransform, ErfTransform}, + transformers::{BitwiseTransform, ErfTransform, HypotTransform, RhypotTransform}, }; pub const MAX_VECTORIZATION: u32 = 4; @@ -229,6 +229,8 @@ impl SpirvCompiler { let mut opt = OptimizerBuilder::default() .with_transformer(ErfTransform) .with_transformer(BitwiseTransform) + .with_transformer(HypotTransform) + .with_transformer(RhypotTransform) .with_processor(CheckedIoProcessor::new(self.mode)) .with_processor(UnrollProcessor::new(MAX_VECTORIZATION)) .with_processor(SaturatingArithmeticProcessor::new(true)) diff --git a/crates/cubecl-spirv/src/extensions.rs b/crates/cubecl-spirv/src/extensions.rs index f0e775b4e..8d387b422 100644 --- a/crates/cubecl-spirv/src/extensions.rs +++ b/crates/cubecl-spirv/src/extensions.rs @@ -139,7 +139,6 @@ pub mod glcompute { fn pow(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) { b.gl_pow_id(ty, Some(out), lhs, rhs).unwrap(); } - fn exp(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { b.gl_exp_id(ty, Some(out), input).unwrap(); } diff --git a/crates/cubecl-spirv/src/transformers.rs b/crates/cubecl-spirv/src/transformers.rs index d0b0812b4..260a1578d 100644 --- a/crates/cubecl-spirv/src/transformers.rs +++ b/crates/cubecl-spirv/src/transformers.rs @@ -3,7 +3,7 @@ use cubecl_core::{ Arithmetic, Bitwise, ElemType, ExpandElement, Instruction, IntKind, Operation, Scope, UIntKind, Variable, }, - prelude::{IntExpand, assign, expand_erf}, + prelude::{IntExpand, assign, expand_erf, expand_hypot, expand_rhypot}, }; use cubecl_opt::{IrTransformer, TransformAction}; @@ -26,6 +26,40 @@ impl IrTransformer for ErfTransform { } } +/// Expand hypot +#[derive(Debug)] +pub(crate) struct HypotTransform; + +impl IrTransformer for HypotTransform { + fn maybe_transform(&self, scope: &mut Scope, inst: &Instruction) -> TransformAction { + match &inst.operation { + Operation::Arithmetic(Arithmetic::Hypot(op)) => { + let mut scope = scope.child(); + expand_hypot(&mut scope, op.lhs, op.rhs, inst.out.unwrap()); + TransformAction::Replace(into_instructions(scope)) + } + _ => TransformAction::Ignore, + } + } +} + +/// Expand hypot +#[derive(Debug)] +pub(crate) struct RhypotTransform; + +impl IrTransformer for RhypotTransform { + fn maybe_transform(&self, scope: &mut Scope, inst: &Instruction) -> TransformAction { + match &inst.operation { + Operation::Arithmetic(Arithmetic::Rhypot(op)) => { + let mut scope = scope.child(); + expand_rhypot(&mut scope, op.lhs, op.rhs, inst.out.unwrap()); + TransformAction::Replace(into_instructions(scope)) + } + _ => TransformAction::Ignore, + } + } +} + /// Transform operations that only support 32 bits using polyfills #[derive(Debug)] pub(crate) struct BitwiseTransform; diff --git a/crates/cubecl-std/src/tensor/identity.rs b/crates/cubecl-std/src/tensor/identity.rs index 9cf09c2a3..61d3f9608 100644 --- a/crates/cubecl-std/src/tensor/identity.rs +++ b/crates/cubecl-std/src/tensor/identity.rs @@ -20,7 +20,7 @@ fn identity_kernel( let mut offset = 0; while offset < output.line_size() { let remainder = (start_pos + offset) % gap; - if remainder.is_multiple_of(gap) { + if remainder == 0 { line[offset] = C::from_int(1); offset += gap; } else { diff --git a/crates/cubecl-std/src/tests/trigonometry.rs b/crates/cubecl-std/src/tests/trigonometry.rs index aa7acb9b2..63e76b973 100644 --- a/crates/cubecl-std/src/tests/trigonometry.rs +++ b/crates/cubecl-std/src/tests/trigonometry.rs @@ -1,6 +1,6 @@ +use core::f32::consts::{PI, TAU}; use cubecl::prelude::*; use cubecl_core as cubecl; -use std::f32::consts::{PI, TAU}; use crate::trigonometry::*; @@ -82,49 +82,6 @@ pub fn test_to_radians(client: ComputeClient) { } } -#[cube(launch_unchecked)] -fn kernel_hypot(x: &Array, y: &Array, output: &mut Array) { - if UNIT_POS < x.len() { - output[UNIT_POS] = hypot::(x[UNIT_POS], y[UNIT_POS]); - } -} - -#[allow(clippy::approx_constant)] -pub fn test_hypot(client: ComputeClient) { - let x_data = vec![3.0, 0.0, 1.0, 5.0, 0.0]; - let y_data = vec![4.0, 1.0, 1.0, 12.0, 0.0]; - let expected = [5.0, 1.0, 1.414_213_5, 13.0, 0.0]; - - let x = client.create_from_slice(f32::as_bytes(&x_data)); - let y = client.create_from_slice(f32::as_bytes(&y_data)); - let output = client.empty(x_data.len() * core::mem::size_of::()); - - unsafe { - kernel_hypot::launch_unchecked( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(x_data.len() as u32, 1, 1), - ArrayArg::from_raw_parts::(&x, x_data.len(), 1), - ArrayArg::from_raw_parts::(&y, y_data.len(), 1), - ArrayArg::from_raw_parts::(&output, x_data.len(), 1), - ) - .unwrap(); - } - - let actual = client.read_one(output); - let actual = f32::from_bytes(&actual); - - for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { - assert!( - (expected_val - actual_val).abs() < 1e-5, - "Hypot test {} failed: expected {}, got {}", - i, - expected_val, - actual_val - ); - } -} - #[macro_export] macro_rules! testgen_trigonometry { () => { @@ -143,12 +100,6 @@ macro_rules! testgen_trigonometry { let client = TestRuntime::client(&Default::default()); test_to_radians::(client); } - - #[test] - fn test_hypot_computation() { - let client = TestRuntime::client(&Default::default()); - test_hypot::(client); - } } }; } diff --git a/crates/cubecl-std/src/trigonometry.rs b/crates/cubecl-std/src/trigonometry.rs index 355bc101b..631251015 100644 --- a/crates/cubecl-std/src/trigonometry.rs +++ b/crates/cubecl-std/src/trigonometry.rs @@ -34,28 +34,3 @@ pub fn to_degrees(val: F) -> F { pub fn to_radians(val: F) -> F { val * F::new(f32::consts::PI / 180.0) } - -/// Computes the hypotenuse of a right triangle given the lengths of the other two sides. -/// -/// This function computes `sqrt(x² + y²)` in a numerically stable way that avoids -/// overflow and underflow issues. -/// -/// # Arguments -/// -/// * `x` - Length of one side -/// * `y` - Length of the other side -/// -/// # Returns -/// -/// The length of the hypotenuse -/// -/// # Example -/// -/// ```rust,ignore -/// let hyp = hypot(F::new(3.0), F::new(4.0)); -/// assert!((hyp - F::new(5.0)).abs() < F::new(1e-6)); -/// ``` -#[cube] -pub fn hypot(x: F, y: F) -> F { - F::sqrt(x * x + y * y) -} diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 57e0fd895..7592a8c3c 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -874,6 +874,17 @@ impl WgslCompiler { out: self.compile_variable(out), }) } + cube::Arithmetic::Hypot(op) => { + let mut scope = scope.child(); + expand_hypot(&mut scope, op.lhs, op.rhs, out); + instructions.extend(self.compile_scope(&mut scope)); + } + cube::Arithmetic::Rhypot(op) => { + let mut scope = scope.child(); + expand_rhypot(&mut scope, op.lhs, op.rhs, out); + instructions.extend(self.compile_scope(&mut scope)); + } + cube::Arithmetic::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt { input: self.compile_variable(op.input), out: self.compile_variable(out), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs b/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs index 4ccfa92a3..0fa79de79 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs @@ -126,7 +126,7 @@ pub fn call_powf( let lhs = lhs.to_string(); (lhs, rhs, POWF_SCALAR) } else { - // When vecotized, we make sure the function inputs shared the same vectorization factor as + // When vectorized, we make sure the function inputs shared the same vectorization factor as // the output. let rhs = rhs.fmt_cast_to(out.item()); let lhs = lhs.fmt_cast_to(out.item());