Skip to content

Commit bd31af7

Browse files
committed
Add support for hypot and reciprocal hypot.
1 parent 3c94174 commit bd31af7

File tree

25 files changed

+629
-66
lines changed

25 files changed

+629
-66
lines changed

crates/cubecl-core/src/frontend/element/float.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ pub trait Float:
4040
+ ArcTan2
4141
+ Powf
4242
+ Powi<i32>
43+
+ Hypot
44+
+ Rhypot
4345
+ Sqrt
4446
+ InverseSqrt
4547
+ Round

crates/cubecl-core/src/frontend/element/float/typemap.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@ impl<const POS: u8> Radians for ElemExpand<POS> {}
262262
impl<const POS: u8> ArcTan2 for ElemExpand<POS> {}
263263
impl<const POS: u8> Powf for ElemExpand<POS> {}
264264
impl<const POS: u8, I: CubePrimitive> Powi<I> for ElemExpand<POS> {}
265+
impl<const POS: u8> Hypot for ElemExpand<POS> {}
266+
impl<const POS: u8> Rhypot for ElemExpand<POS> {}
265267
impl<const POS: u8> Sqrt for ElemExpand<POS> {}
266268
impl<const POS: u8> InverseSqrt for ElemExpand<POS> {}
267269
impl<const POS: u8> Round for ElemExpand<POS> {}

crates/cubecl-core/src/frontend/operation/binary.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,31 @@ impl_binary_func!(
252252
f32,
253253
f64
254254
);
255+
256+
impl_binary_func!(
257+
Hypot,
258+
hypot,
259+
Arithmetic::Hypot,
260+
f16,
261+
bf16,
262+
flex32,
263+
tf32,
264+
f32,
265+
f64
266+
);
267+
268+
impl_binary_func!(
269+
Rhypot,
270+
rhypot,
271+
Arithmetic::Rhypot,
272+
f16,
273+
bf16,
274+
flex32,
275+
tf32,
276+
f32,
277+
f64
278+
);
279+
255280
impl_binary_func!(
256281
ArcTan2,
257282
atan2,

crates/cubecl-core/src/runtime_tests/binary.rs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use core::f32;
12
use std::{fmt::Display, sync::LazyLock};
23

34
use crate::{self as cubecl, as_type};
@@ -30,7 +31,11 @@ pub(crate) fn assert_equals_approx<
3031
// account for lower precision at higher values
3132
let allowed_error = F::new((epsilon * e.to_f32().unwrap().abs()).max(epsilon));
3233
assert!(
33-
(*a - *e).abs() < allowed_error || (a.is_nan() && e.is_nan()),
34+
(*a - *e).abs() < allowed_error
35+
|| (a.is_nan() && e.is_nan())
36+
|| (a.is_infinite()
37+
&& e.is_infinite()
38+
&& a.is_sign_positive() == e.is_sign_positive()),
3439
"Values differ more than epsilon: actual={}, expected={}, difference={}, epsilon={}
3540
index: {}
3641
actual: {:?}
@@ -184,6 +189,64 @@ test_binary_impl!(
184189
]
185190
);
186191

192+
test_binary_impl!(
193+
test_hypot,
194+
F,
195+
F::hypot,
196+
[
197+
{
198+
input_vectorization: 1,
199+
out_vectorization: 1,
200+
lhs: as_type![F: 3., 0., 5., 0.],
201+
rhs: as_type![F: 4., 5., 0., 0.],
202+
expected: as_type![F: 5., 5., 5., 0.]
203+
},
204+
{
205+
input_vectorization: 2,
206+
out_vectorization: 2,
207+
lhs: as_type![F: 3., 0., 5., 8.],
208+
rhs: as_type![F: 4., 5., 0., 15.],
209+
expected: as_type![F: 5., 5., 5., 17.]
210+
},
211+
{
212+
input_vectorization: 4,
213+
out_vectorization: 4,
214+
lhs: as_type![F: -3., 0., -5., -8.],
215+
rhs: as_type![F: -4., -5., 0., 15.],
216+
expected: as_type![F: 5., 5., 5., 17.]
217+
}
218+
]
219+
);
220+
221+
test_binary_impl!(
222+
test_rhypot,
223+
F,
224+
F::rhypot,
225+
[
226+
{
227+
input_vectorization: 1,
228+
out_vectorization: 1,
229+
lhs: as_type![F: 3., 0., 5., 0.],
230+
rhs: as_type![F: 4., 5., 0., 0.],
231+
expected: &[F::new(0.2), F::new(0.2), F::new(0.2), F::INFINITY]
232+
},
233+
{
234+
input_vectorization: 2,
235+
out_vectorization: 2,
236+
lhs: as_type![F: 3., 0., 5., 0.3],
237+
rhs: as_type![F: 4., 5., 0., 0.4],
238+
expected: as_type![F: 0.2, 0.2, 0.2, 2.]
239+
},
240+
{
241+
input_vectorization: 4,
242+
out_vectorization: 4,
243+
lhs: as_type![F: 0., 0., -5., -0.3],
244+
rhs: as_type![F: -1., -5., 0., -0.4],
245+
expected: as_type![F: 1., 0.2, 0.2, 2.]
246+
}
247+
]
248+
);
249+
187250
#[cube(launch_unchecked)]
188251
fn test_powi_kernel<F: Float>(
189252
lhs: &Array<Line<F>>,
@@ -354,6 +417,8 @@ macro_rules! testgen_binary {
354417

355418
add_test!(test_dot);
356419
add_test!(test_powf);
420+
add_test!(test_hypot);
421+
add_test!(test_rhypot);
357422
add_test!(test_powi);
358423
add_test!(test_atan2);
359424
}

crates/cubecl-core/src/runtime_tests/cmma.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,8 @@ pub fn test_simple_1_expected() -> Vec<f32> {
482482
// let lhs: Vec<f16> = (0..64).map(|i| f16::from_f32(i as f32)).collect();
483483
// let rhs: Vec<f16> = (0..64).map(|i| f16::from_f32((i % 8) as f32)).collect();
484484

485-
// let lhs = client.create(f16::as_bytes(&lhs));
486-
// let rhs = client.create(f16::as_bytes(&rhs));
485+
// let lhs = client.create_from_slice(f16::as_bytes(&lhs));
486+
// let rhs = client.create_from_slice(f16::as_bytes(&rhs));
487487
// let out = client.empty(core::mem::size_of::<f16>() * 64);
488488

489489
// unsafe {

crates/cubecl-core/src/runtime_tests/unary.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use std::f32::consts::PI;
1+
use core::f32;
2+
use core::f32::consts::PI;
23
use std::fmt::Display;
34

45
use crate::{self as cubecl, as_type};

crates/cubecl-cpp/src/shared/base.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,12 @@ impl<D: Dialect> CppCompiler<D> {
11761176
gpu::Arithmetic::Powi(op) => {
11771177
instructions.push(Instruction::Powi(self.compile_binary(op, out)))
11781178
}
1179+
gpu::Arithmetic::Hypot(op) => {
1180+
instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
1181+
}
1182+
gpu::Arithmetic::Rhypot(op) => {
1183+
instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
1184+
}
11791185
gpu::Arithmetic::Sqrt(op) => {
11801186
let op = self.compile_unary(op, out);
11811187
instructions.push(self.select_fast_float(

crates/cubecl-cpp/src/shared/binary.rs

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,6 @@ impl<D: Dialect> Binary<D> for Powi {
324324
f.write_str("};\n")
325325
}
326326
}
327-
328327
pub struct ArcTan2;
329328

330329
impl<D: Dialect> Binary<D> for ArcTan2 {
@@ -370,6 +369,112 @@ impl<D: Dialect> Binary<D> for ArcTan2 {
370369
}
371370
}
372371

372+
pub struct Hypot;
373+
374+
impl<D: Dialect> Binary<D> for Hypot {
375+
// Hypot doesn't support half and no half equivalent exists
376+
fn format_scalar<Lhs, Rhs>(
377+
f: &mut Formatter<'_>,
378+
lhs: Lhs,
379+
rhs: Rhs,
380+
item: Item<D>,
381+
) -> std::fmt::Result
382+
where
383+
Lhs: Component<D>,
384+
Rhs: Component<D>,
385+
{
386+
let elem = item.elem;
387+
let lhs = lhs.to_string();
388+
let rhs = rhs.to_string();
389+
match elem {
390+
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
391+
let lhs = format!("float({lhs})");
392+
let rhs = format!("float({rhs})");
393+
write!(f, "{elem}(")?;
394+
D::compile_instruction_hypot(f, &lhs, &rhs, Elem::F32)?;
395+
write!(f, ")")
396+
}
397+
_ => D::compile_instruction_hypot(f, &lhs, &rhs, elem),
398+
}
399+
}
400+
401+
// Hypot doesn't support half and no half equivalent exists
402+
fn unroll_vec(
403+
f: &mut Formatter<'_>,
404+
lhs: &Variable<D>,
405+
rhs: &Variable<D>,
406+
out: &Variable<D>,
407+
) -> core::fmt::Result {
408+
let item_out = out.item();
409+
let index = out.item().vectorization;
410+
411+
let out = out.fmt_left();
412+
writeln!(f, "{out} = {item_out}{{")?;
413+
for i in 0..index {
414+
let lhsi = lhs.index(i);
415+
let rhsi = rhs.index(i);
416+
417+
Self::format_scalar(f, lhsi, rhsi, item_out)?;
418+
f.write_str(", ")?;
419+
}
420+
421+
f.write_str("};\n")
422+
}
423+
}
424+
425+
pub struct Rhypot;
426+
427+
impl<D: Dialect> Binary<D> for Rhypot {
428+
// Rhypot doesn't support half and no half equivalent exists
429+
fn format_scalar<Lhs, Rhs>(
430+
f: &mut Formatter<'_>,
431+
lhs: Lhs,
432+
rhs: Rhs,
433+
item: Item<D>,
434+
) -> std::fmt::Result
435+
where
436+
Lhs: Component<D>,
437+
Rhs: Component<D>,
438+
{
439+
let elem = item.elem;
440+
let lhs = lhs.to_string();
441+
let rhs = rhs.to_string();
442+
match elem {
443+
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
444+
let lhs = format!("float({lhs})");
445+
let rhs = format!("float({rhs})");
446+
write!(f, "{elem}(")?;
447+
D::compile_instruction_rhypot(f, &lhs, &rhs, Elem::F32)?;
448+
write!(f, ")")
449+
}
450+
_ => D::compile_instruction_rhypot(f, &lhs, &rhs, elem),
451+
}
452+
}
453+
454+
// Rhypot doesn't support half and no half equivalent exists
455+
fn unroll_vec(
456+
f: &mut Formatter<'_>,
457+
lhs: &Variable<D>,
458+
rhs: &Variable<D>,
459+
out: &Variable<D>,
460+
) -> core::fmt::Result {
461+
let item_out = out.item();
462+
let index = out.item().vectorization;
463+
464+
let out = out.fmt_left();
465+
writeln!(f, "{out} = {item_out}{{")?;
466+
for i in 0..index {
467+
let lhsi = lhs.index(i);
468+
let rhsi = rhs.index(i);
469+
470+
Self::format_scalar(f, lhsi, rhsi, item_out)?;
471+
f.write_str(", ")?;
472+
}
473+
474+
f.write_str("};\n")
475+
}
476+
}
477+
373478
pub struct Max;
374479

375480
impl<D: Dialect> Binary<D> for Max {

crates/cubecl-cpp/src/shared/dialect.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,32 @@ pub trait DialectInstructions<D: Dialect> {
633633
}
634634
}
635635

636+
fn compile_instruction_hypot(
637+
f: &mut std::fmt::Formatter<'_>,
638+
lhs: &str,
639+
rhs: &str,
640+
elem: Elem<D>,
641+
) -> std::fmt::Result {
642+
match elem {
643+
Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"),
644+
Elem::F64 => write!(f, "hypot({lhs}, {rhs})"),
645+
_ => panic!("Unsupported type for hypot"),
646+
}
647+
}
648+
649+
fn compile_instruction_rhypot(
650+
f: &mut std::fmt::Formatter<'_>,
651+
lhs: &str,
652+
rhs: &str,
653+
elem: Elem<D>,
654+
) -> std::fmt::Result {
655+
match elem {
656+
Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"),
657+
Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"),
658+
_ => panic!("Unsupported type for hypot"),
659+
}
660+
}
661+
636662
fn compile_instruction_half_function_name_prefix() -> &'static str {
637663
"h"
638664
}

crates/cubecl-cpp/src/shared/instruction.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ pub enum Instruction<D: Dialect> {
187187
Powf(BinaryInstruction<D>),
188188
FastPowf(BinaryInstruction<D>),
189189
Powi(BinaryInstruction<D>),
190+
Hypot(BinaryInstruction<D>),
191+
Rhypot(BinaryInstruction<D>),
190192
Sqrt(UnaryInstruction<D>),
191193
FastSqrt(UnaryInstruction<D>),
192194
InverseSqrt(UnaryInstruction<D>),
@@ -564,6 +566,8 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
564566
Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
565567
Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
566568
Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
569+
Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out),
570+
Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out),
567571
Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
568572
Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
569573
Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),

0 commit comments

Comments
 (0)