Skip to content

Commit ee3287c

Browse files
ejoepes ejoepes
authored and committed
Move hypot/rhypot to a function.
1 parent 721b01f commit ee3287c

File tree

14 files changed

+183
-351
lines changed

14 files changed

+183
-351
lines changed

crates/cubecl-core/src/frontend/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ mod options;
1616
mod plane;
1717
mod polyfills;
1818
mod topology;
19+
mod trigonometry;
1920

2021
pub use branch::{RangeExpand, SteppedRangeExpand, range, range_stepped};
2122
pub use const_expand::*;
@@ -29,5 +30,6 @@ pub use options::*;
2930
pub use plane::*;
3031
pub use polyfills::*;
3132
pub use topology::*;
33+
pub use trigonometry::*;
3234

3335
pub use crate::{debug_print, debug_print_expand};
crates/cubecl-core/src/frontend/trigonometry.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use cubecl_ir::{ExpandElement, Variable};
2+
3+
use crate::prelude::*;
4+
use crate::{self as cubecl};
5+
6+
/// Computes the hypotenuse of a right triangle given the lengths of the other two sides.
7+
///
8+
/// This function computes `sqrt(lhs² + rhs²)` in a numerically stable way that avoids
9+
/// overflow and underflow issues.
10+
///
11+
/// # Arguments
12+
///
13+
/// * `lhs` - Length of one side
14+
/// * `rhs` - Length of the other side
15+
///
16+
/// # Returns
17+
///
18+
/// The length of the hypotenuse
19+
///
20+
/// # Example
21+
///
22+
/// ```rust,ignore
23+
/// let hyp = hypot(F::new(3.0), F::new(4.0));
24+
/// assert!((hyp - F::new(5.0)).abs() < F::new(1e-6));
25+
/// ```
26+
#[cube]
27+
pub fn hypot<F: Float>(lhs: F, rhs: F) -> F {
28+
let one = F::from_int(1);
29+
let a = F::abs(lhs);
30+
let b = F::abs(rhs);
31+
let max_val = F::max(a, b);
32+
let max_val_is_zero = max_val == F::from_int(0);
33+
let max_val_safe = select(max_val_is_zero, one, max_val);
34+
let min_val = F::min(a, b);
35+
let t = min_val / max_val_safe;
36+
37+
max_val * F::sqrt(one + (t * t))
38+
}
39+
40+
#[allow(missing_docs)]
41+
pub fn expand_hypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
42+
scope.register_type::<FloatExpand<0>>(lhs.ty.storage_type());
43+
let res = hypot::expand::<FloatExpand<0>>(
44+
scope,
45+
ExpandElement::Plain(lhs).into(),
46+
ExpandElement::Plain(rhs).into(),
47+
);
48+
assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
49+
}
50+
51+
#[cube]
52+
pub fn rhypot<F: Float>(lhs: F, rhs: F) -> F {
53+
let one = F::from_int(1);
54+
let a = F::abs(lhs);
55+
let b = F::abs(rhs);
56+
let max_val = F::max(a, b);
57+
let max_val_is_zero = max_val == F::from_int(0);
58+
let max_val_safe = select(max_val_is_zero, one, max_val);
59+
let min_val = F::min(a, b);
60+
let t = min_val / max_val_safe;
61+
62+
F::inverse_sqrt(one + (t * t)) / max_val
63+
}
64+
65+
#[allow(missing_docs)]
66+
pub fn expand_rhypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
67+
scope.register_type::<FloatExpand<0>>(lhs.ty.storage_type());
68+
let res = rhypot::expand::<FloatExpand<0>>(
69+
scope,
70+
ExpandElement::Plain(lhs).into(),
71+
ExpandElement::Plain(rhs).into(),
72+
);
73+
assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
74+
}

crates/cubecl-cpu/src/compiler/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ use cubecl_core::{
2323
use cubecl_opt::OptimizerBuilder;
2424
use mlir_engine::MlirEngine;
2525

26-
use crate::compiler::passes::erf_transform::ErfTransform;
26+
use crate::compiler::passes::{
27+
erf_transform::ErfTransform,
28+
trigonometries_transform::{HypotTransform, RhypotTransform},
29+
};
2730

2831
#[derive(Clone, Debug, Default)]
2932
pub struct MlirCompiler {}
@@ -46,6 +49,8 @@ impl Compiler for MlirCompiler {
4649
dump_scope(&kernel.body, &kernel.options.kernel_name);
4750
let opt = OptimizerBuilder::default()
4851
.with_transformer(ErfTransform)
52+
.with_transformer(HypotTransform)
53+
.with_transformer(RhypotTransform)
4954
.with_processor(CheckedIoProcessor::new(mode))
5055
.with_processor(SaturatingArithmeticProcessor::new(true))
5156
.with_processor(PredicateProcessor)
crates/cubecl-cpu/src/compiler/passes/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod erf_transform;
22
pub mod shared_memories;
3+
pub mod trigonometries_transform;
crates/cubecl-cpu/src/compiler/passes/trigonometries_transform.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use cubecl_core::{
2+
ir::{Arithmetic, Instruction, Operation, Scope},
3+
prelude::*,
4+
};
5+
use cubecl_opt::{IrTransformer, TransformAction};
6+
7+
#[derive(Debug)]
8+
pub(crate) struct HypotTransform;
9+
10+
impl IrTransformer for HypotTransform {
11+
fn maybe_transform(&self, scope: &mut Scope, inst: &Instruction) -> TransformAction {
12+
match &inst.operation {
13+
Operation::Arithmetic(Arithmetic::Hypot(op)) => {
14+
let mut scope = scope.child();
15+
expand_hypot(&mut scope, op.lhs, op.rhs, inst.out.unwrap());
16+
TransformAction::Replace(scope.process([]).instructions)
17+
}
18+
_ => TransformAction::Ignore,
19+
}
20+
}
21+
}
22+
23+
#[derive(Debug)]
24+
pub(crate) struct RhypotTransform;
25+
26+
impl IrTransformer for RhypotTransform {
27+
fn maybe_transform(&self, scope: &mut Scope, inst: &Instruction) -> TransformAction {
28+
match &inst.operation {
29+
Operation::Arithmetic(Arithmetic::Rhypot(op)) => {
30+
let mut scope = scope.child();
31+
expand_rhypot(&mut scope, op.lhs, op.rhs, inst.out.unwrap());
32+
TransformAction::Replace(scope.process([]).instructions)
33+
}
34+
_ => TransformAction::Ignore,
35+
}
36+
}
37+
}

crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs

Lines changed: 5 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ impl<'a> Visitor<'a> {
223223
self.insert_variable(out, operation);
224224
}
225225
Arithmetic::Erf(_) => {
226-
unreachable!("Should have been transformed in primitive in a previous passe");
226+
unreachable!("Should have been transformed in primitive in a previous phase");
227227
}
228228
Arithmetic::Exp(exp) => {
229229
let value = self.get_variable(exp.input);
@@ -446,84 +446,11 @@ impl<'a> Visitor<'a> {
446446
self.append_operation_with_result(arith::mulf(value, f, self.location));
447447
self.insert_variable(out, result);
448448
}
449-
Arithmetic::Hypot(hypot) => {
450-
let (a, b) = self.get_binary_op_variable(hypot.lhs, hypot.rhs);
451-
let abs_a = self.get_absolute_val(hypot.lhs.ty, a);
452-
let abs_b = self.get_absolute_val(hypot.rhs.ty, b);
453-
let zero = self.create_float_constant_from_item(hypot.lhs.ty, 0.0);
454-
let one = self.create_float_constant_from_item(hypot.lhs.ty, 1.0);
455-
let max =
456-
self.append_operation_with_result(arith::maxnumf(abs_a, abs_b, self.location));
457-
let is_max_zero = self.append_operation_with_result(arith::cmpf(
458-
self.context,
459-
arith::CmpfPredicate::Oeq,
460-
max,
461-
zero,
462-
self.location,
463-
));
464-
let max_safe = self.append_operation_with_result(arith::select(
465-
is_max_zero,
466-
one,
467-
max,
468-
self.location,
469-
));
470-
let min =
471-
self.append_operation_with_result(arith::minimumf(abs_a, abs_b, self.location));
472-
let t =
473-
self.append_operation_with_result(arith::divf(min, max_safe, self.location));
474-
let t_square = self.append_operation_with_result(arith::mulf(t, t, self.location));
475-
let t_square_plus_one =
476-
self.append_operation_with_result(arith::addf(t_square, one, self.location));
477-
let square_root = self.append_operation_with_result(llvm_ods::intr_sqrt(
478-
self.context,
479-
t_square_plus_one,
480-
self.location,
481-
));
482-
let result =
483-
self.append_operation_with_result(arith::mulf(max, square_root, self.location));
484-
485-
self.insert_variable(out, result);
449+
Arithmetic::Hypot(_hypot) => {
450+
unreachable!("Should have been transformed in primitive in a previous phase");
486451
}
487-
Arithmetic::Rhypot(hypot) => {
488-
let (a, b) = self.get_binary_op_variable(hypot.lhs, hypot.rhs);
489-
let abs_a = self.get_absolute_val(hypot.lhs.ty, a);
490-
let abs_b = self.get_absolute_val(hypot.rhs.ty, b);
491-
let zero = self.create_float_constant_from_item(hypot.lhs.ty, 0.0);
492-
let one = self.create_float_constant_from_item(hypot.lhs.ty, 1.0);
493-
let max =
494-
self.append_operation_with_result(arith::maxnumf(abs_a, abs_b, self.location));
495-
let is_max_zero = self.append_operation_with_result(arith::cmpf(
496-
self.context,
497-
arith::CmpfPredicate::Oeq,
498-
max,
499-
zero,
500-
self.location,
501-
));
502-
let max_safe = self.append_operation_with_result(arith::select(
503-
is_max_zero,
504-
one,
505-
max,
506-
self.location,
507-
));
508-
let min =
509-
self.append_operation_with_result(arith::minimumf(abs_a, abs_b, self.location));
510-
let t =
511-
self.append_operation_with_result(arith::divf(min, max_safe, self.location));
512-
let t_square = self.append_operation_with_result(arith::mulf(t, t, self.location));
513-
let t_square_plus_one =
514-
self.append_operation_with_result(arith::addf(t_square, one, self.location));
515-
let inverse_square_root = self.append_operation_with_result(math_ods::rsqrt(
516-
self.context,
517-
t_square_plus_one,
518-
self.location,
519-
));
520-
let result = self.append_operation_with_result(arith::divf(
521-
inverse_square_root,
522-
max,
523-
self.location,
524-
));
525-
526-
self.insert_variable(out, result);
452+
Arithmetic::Rhypot(_rhypot) => {
453+
unreachable!("Should have been transformed in primitive in a previous phase");
527454
}
528455
Arithmetic::Recip(recip) => {
529456
let value = self.get_variable(recip.input);

crates/cubecl-opt/src/passes/constant_prop.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,6 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option<ConstantScalarValue>
344344
Arithmetic::Powi(op) => {
345345
const_eval_float!(op.lhs, op.rhs; num::Float::powf)
346346
}
347-
Arithmetic::Hypot(op) => const_eval_float!(op.lhs, op.rhs; num::Float::hypot),
348-
Arithmetic::Rhypot(op) => const_eval_float!(op.lhs, op.rhs; num::Float::hypot),
349347
Arithmetic::Modulo(op) => const_eval!(% op.lhs, op.rhs),
350348
Arithmetic::Remainder(op) => const_eval!(% op.lhs, op.rhs),
351349
Arithmetic::MulHi(op) => {
@@ -506,7 +504,11 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option<ConstantScalarValue>
506504
}
507505
})
508506
}
509-
Arithmetic::Erf(_) | Arithmetic::Magnitude(_) | Arithmetic::Normalize(_) => None,
507+
Arithmetic::Erf(_)
508+
| Arithmetic::Hypot(_)
509+
| Arithmetic::Rhypot(_)
510+
| Arithmetic::Magnitude(_)
511+
| Arithmetic::Normalize(_) => None,
510512
}
511513
}
512514

crates/cubecl-spirv/src/arithmetic.rs

Lines changed: 4 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -535,87 +535,11 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
535535
b.select(ty, Some(out), is_zero, even, sel1).unwrap();
536536
})
537537
}
538-
Arithmetic::Hypot(op) => {
539-
self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
540-
let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
541-
let zero = b.static_cast(ConstVal::Bit32(0), &Elem::Int(32, false), &out_ty);
542-
let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
543-
let abs_a = b.id();
544-
T::f_abs(b, ty, lhs, abs_a);
545-
let abs_b = b.id();
546-
T::f_abs(b, ty, rhs, abs_b);
547-
let max = b.id();
548-
T::f_max(b, ty, abs_a, abs_b, max);
549-
let min = b.id();
550-
T::f_min(b, ty, abs_a, abs_b, min);
551-
let bool = Elem::Bool.id(b);
552-
let is_max_zero = b.f_ord_equal(bool, None, max, zero).unwrap();
553-
let max_safe = b.id();
554-
b.select(ty, Some(max_safe), is_max_zero, one, max).unwrap();
555-
let t = b.id();
556-
b.f_div(ty, Some(t), min, max_safe).unwrap();
557-
let t_fma = b.gl_fma(ty, t, t, one).unwrap();
558-
let square_root = b.id();
559-
T::sqrt(b, ty, t_fma, square_root);
560-
let ids = [
561-
abs_a,
562-
abs_b,
563-
max,
564-
is_max_zero,
565-
max_safe,
566-
t_fma,
567-
square_root,
568-
out,
569-
];
570-
for id in ids {
571-
b.mark_uniformity(id, uniform);
572-
if relaxed {
573-
b.decorate(id, Decoration::RelaxedPrecision, []);
574-
}
575-
}
576-
b.f_mul(ty, Some(out), square_root, max).unwrap();
577-
})
538+
Arithmetic::Hypot(_op) => {
539+
unreachable!("Replaced by transformer");
578540
}
579-
Arithmetic::Rhypot(op) => {
580-
self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
581-
let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
582-
let zero = b.static_cast(ConstVal::Bit32(0), &Elem::Int(32, false), &out_ty);
583-
let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
584-
let abs_a = b.id();
585-
T::f_abs(b, ty, lhs, abs_a);
586-
let abs_b = b.id();
587-
T::f_abs(b, ty, rhs, abs_b);
588-
let max = b.id();
589-
T::f_max(b, ty, abs_a, abs_b, max);
590-
let min = b.id();
591-
T::f_min(b, ty, abs_a, abs_b, min);
592-
let bool = Elem::Bool.id(b);
593-
let is_max_zero = b.f_ord_equal(bool, None, max, zero).unwrap();
594-
let max_safe = b.id();
595-
b.select(ty, Some(max_safe), is_max_zero, one, max).unwrap();
596-
let t = b.id();
597-
b.f_div(ty, Some(t), min, max_safe).unwrap();
598-
let t_fma = b.gl_fma(ty, t, t, one).unwrap();
599-
let inverse_square_root = b.id();
600-
T::inverse_sqrt(b, ty, t_fma, inverse_square_root);
601-
let ids = [
602-
abs_a,
603-
abs_b,
604-
max,
605-
is_max_zero,
606-
max_safe,
607-
t_fma,
608-
inverse_square_root,
609-
out,
610-
];
611-
for id in ids {
612-
b.mark_uniformity(id, uniform);
613-
if relaxed {
614-
b.decorate(id, Decoration::RelaxedPrecision, []);
615-
}
616-
}
617-
b.f_div(ty, Some(out), inverse_square_root, max).unwrap();
618-
})
541+
Arithmetic::Rhypot(_op) => {
542+
unreachable!("Replaced by transformer");
619543
}
620544
Arithmetic::Sqrt(op) => {
621545
self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {

crates/cubecl-spirv/src/compiler.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::{
3434
item::Item,
3535
lookups::LookupTables,
3636
target::{GLCompute, SpirvTarget},
37-
transformers::{BitwiseTransform, ErfTransform},
37+
transformers::{BitwiseTransform, ErfTransform, HypotTransform, RhypotTransform},
3838
};
3939

4040
pub const MAX_VECTORIZATION: u32 = 4;
@@ -215,6 +215,8 @@ impl<Target: SpirvTarget> SpirvCompiler<Target> {
215215
let mut opt = OptimizerBuilder::default()
216216
.with_transformer(ErfTransform)
217217
.with_transformer(BitwiseTransform)
218+
.with_transformer(HypotTransform)
219+
.with_transformer(RhypotTransform)
218220
.with_processor(CheckedIoProcessor::new(self.mode))
219221
.with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
220222
.with_processor(SaturatingArithmeticProcessor::new(true))

0 commit comments

Comments
 (0)