Skip to content

Commit 2d49684

Browse files
committed
Improving hypot/rhypot.
Removing conditional branches when possible.
1 parent ce158a7 commit 2d49684

File tree

4 files changed

+71
-110
lines changed

4 files changed

+71
-110
lines changed

crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs

Lines changed: 31 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -450,93 +450,76 @@ impl<'a> Visitor<'a> {
450450
let (a, b) = self.get_binary_op_variable(hypot.lhs, hypot.rhs);
451451
let abs_a = self.get_absolute_val(hypot.lhs.ty, a);
452452
let abs_b = self.get_absolute_val(hypot.rhs.ty, b);
453-
454-
let max =
455-
self.append_operation_with_result(arith::maxnumf(abs_a, abs_b, self.location));
456453
let zero = self.create_float_constant_from_item(hypot.lhs.ty, 0.0);
457454
let one = self.create_float_constant_from_item(hypot.lhs.ty, 1.0);
458-
let is_zero = self.append_operation_with_result(arith::cmpf(
455+
let max =
456+
self.append_operation_with_result(arith::maxnumf(abs_a, abs_b, self.location));
457+
let is_max_zero = self.append_operation_with_result(arith::cmpf(
459458
self.context,
460459
arith::CmpfPredicate::Oeq,
461460
max,
462461
zero,
463462
self.location,
464463
));
465-
let scale = self.append_operation_with_result(arith::select(
466-
is_zero,
464+
let max_safe = self.append_operation_with_result(arith::select(
465+
is_max_zero,
467466
one,
468467
max,
469468
self.location,
470469
));
471-
let a_scale =
472-
self.append_operation_with_result(arith::divf(abs_a, scale, self.location));
473-
let b_scale =
474-
self.append_operation_with_result(arith::divf(abs_b, scale, self.location));
475-
let a_scale_squared =
476-
self.append_operation_with_result(arith::mulf(a_scale, a_scale, self.location));
477-
let b_scale_squared =
478-
self.append_operation_with_result(arith::mulf(b_scale, b_scale, self.location));
479-
let sum = self.append_operation_with_result(arith::addf(
480-
a_scale_squared,
481-
b_scale_squared,
482-
self.location,
483-
));
470+
let min =
471+
self.append_operation_with_result(arith::minimumf(abs_a, abs_b, self.location));
472+
let t =
473+
self.append_operation_with_result(arith::divf(min, max_safe, self.location));
474+
let t_square = self.append_operation_with_result(arith::mulf(t, t, self.location));
475+
let t_square_plus_one =
476+
self.append_operation_with_result(arith::addf(t_square, one, self.location));
484477
let square_root = self.append_operation_with_result(llvm_ods::intr_sqrt(
485478
self.context,
486-
sum,
487-
self.location,
488-
));
489-
let result = self.append_operation_with_result(arith::mulf(
490-
square_root,
491-
scale,
479+
t_square_plus_one,
492480
self.location,
493481
));
482+
let result =
483+
self.append_operation_with_result(arith::mulf(max, square_root, self.location));
494484

495485
self.insert_variable(out, result);
496486
}
497487
Arithmetic::Rhypot(hypot) => {
498488
let (a, b) = self.get_binary_op_variable(hypot.lhs, hypot.rhs);
499489
let abs_a = self.get_absolute_val(hypot.lhs.ty, a);
500490
let abs_b = self.get_absolute_val(hypot.rhs.ty, b);
501-
502-
let max =
503-
self.append_operation_with_result(arith::maxnumf(abs_a, abs_b, self.location));
504491
let zero = self.create_float_constant_from_item(hypot.lhs.ty, 0.0);
505492
let one = self.create_float_constant_from_item(hypot.lhs.ty, 1.0);
506-
let is_zero = self.append_operation_with_result(arith::cmpf(
493+
let max =
494+
self.append_operation_with_result(arith::maxnumf(abs_a, abs_b, self.location));
495+
let is_max_zero = self.append_operation_with_result(arith::cmpf(
507496
self.context,
508497
arith::CmpfPredicate::Oeq,
509498
max,
510499
zero,
511500
self.location,
512501
));
513-
let scale = self.append_operation_with_result(arith::select(
514-
is_zero,
502+
let max_safe = self.append_operation_with_result(arith::select(
503+
is_max_zero,
515504
one,
516505
max,
517506
self.location,
518507
));
519-
let a_scale =
520-
self.append_operation_with_result(arith::divf(abs_a, scale, self.location));
521-
let b_scale =
522-
self.append_operation_with_result(arith::divf(abs_b, scale, self.location));
523-
let a_scale_squared =
524-
self.append_operation_with_result(arith::mulf(a_scale, a_scale, self.location));
525-
let b_scale_squared =
526-
self.append_operation_with_result(arith::mulf(b_scale, b_scale, self.location));
527-
let sum = self.append_operation_with_result(arith::addf(
528-
a_scale_squared,
529-
b_scale_squared,
530-
self.location,
531-
));
532-
let rsquare_root = self.append_operation_with_result(math_ods::rsqrt(
508+
let min =
509+
self.append_operation_with_result(arith::minimumf(abs_a, abs_b, self.location));
510+
let t =
511+
self.append_operation_with_result(arith::divf(min, max_safe, self.location));
512+
let t_square = self.append_operation_with_result(arith::mulf(t, t, self.location));
513+
let t_square_plus_one =
514+
self.append_operation_with_result(arith::addf(t_square, one, self.location));
515+
let inverse_square_root = self.append_operation_with_result(math_ods::rsqrt(
533516
self.context,
534-
sum,
517+
t_square_plus_one,
535518
self.location,
536519
));
537520
let result = self.append_operation_with_result(arith::divf(
538-
rsquare_root,
539-
scale,
521+
inverse_square_root,
522+
max,
540523
self.location,
541524
));
542525

crates/cubecl-cuda/src/compute/server.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,8 @@ impl ComputeServer for CudaServer {
368368
pixels_per_column: _,
369369
} => {
370370
return Err(LaunchError::Unknown {
371-
context: "CUDA version 12.8 required for tensor map format Im2colWide".into,
371+
context: "CUDA version 12.8 required for tensor map format Im2colWide"
372+
.into(),
372373
});
373374
}
374375
};

crates/cubecl-spirv/src/arithmetic.rs

Lines changed: 32 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -538,44 +538,32 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
538538
Arithmetic::Hypot(op) => {
539539
self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
540540
let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
541+
let zero = b.static_cast(ConstVal::Bit32(0), &Elem::Int(32, false), &out_ty);
542+
let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
541543
let abs_a = b.id();
542544
T::f_abs(b, ty, lhs, abs_a);
543545
let abs_b = b.id();
544546
T::f_abs(b, ty, rhs, abs_b);
545547
let max = b.id();
546548
T::f_max(b, ty, abs_a, abs_b, max);
547-
let zero = b.static_cast(ConstVal::Bit32(0), &Elem::Int(32, false), &out_ty);
548-
let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
549+
let min = b.id();
550+
T::f_min(b, ty, abs_a, abs_b, min);
549551
let bool = Elem::Bool.id(b);
550-
let is_zero = b.f_ord_equal(bool, None, max, zero).unwrap();
551-
let scale = b.id();
552-
b.select(ty, Some(scale), is_zero, one, max).unwrap();
553-
let a_scaled = b.id();
554-
b.f_div(ty, Some(a_scaled), abs_a, scale).unwrap();
555-
let b_scaled = b.id();
556-
b.f_div(ty, Some(b_scaled), abs_b, scale).unwrap();
557-
let a_scale_squared = b.id();
558-
b.f_mul(ty, Some(a_scale_squared), a_scaled, a_scaled)
559-
.unwrap();
560-
let b_scale_squared = b.id();
561-
b.f_mul(ty, Some(b_scale_squared), b_scaled, b_scaled)
562-
.unwrap();
563-
let sum = b.id();
564-
b.f_add(ty, Some(sum), a_scale_squared, b_scale_squared)
565-
.unwrap();
552+
let is_max_zero = b.f_ord_equal(bool, None, max, zero).unwrap();
553+
let max_safe = b.id();
554+
b.select(ty, Some(max_safe), is_max_zero, one, max).unwrap();
555+
let t = b.id();
556+
b.f_div(ty, Some(t), min, max_safe).unwrap();
557+
let t_fma = b.gl_fma(ty, t, t, one).unwrap();
566558
let square_root = b.id();
567-
T::sqrt(b, ty, sum, square_root);
559+
T::sqrt(b, ty, t_fma, square_root);
568560
let ids = [
569561
abs_a,
570562
abs_b,
571563
max,
572-
is_zero,
573-
scale,
574-
a_scaled,
575-
b_scaled,
576-
a_scale_squared,
577-
b_scale_squared,
578-
sum,
564+
is_max_zero,
565+
max_safe,
566+
t_fma,
579567
square_root,
580568
out,
581569
];
@@ -585,51 +573,39 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
585573
b.decorate(id, Decoration::RelaxedPrecision, []);
586574
}
587575
}
588-
b.f_mul(ty, Some(out), square_root, scale).unwrap();
576+
b.f_mul(ty, Some(out), square_root, max).unwrap();
589577
})
590578
}
591579
Arithmetic::Rhypot(op) => {
592580
self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
593581
let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
582+
let zero = b.static_cast(ConstVal::Bit32(0), &Elem::Int(32, false), &out_ty);
583+
let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
594584
let abs_a = b.id();
595585
T::f_abs(b, ty, lhs, abs_a);
596586
let abs_b = b.id();
597587
T::f_abs(b, ty, rhs, abs_b);
598588
let max = b.id();
599589
T::f_max(b, ty, abs_a, abs_b, max);
600-
let zero = b.static_cast(ConstVal::Bit32(0), &Elem::Int(32, false), &out_ty);
601-
let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
590+
let min = b.id();
591+
T::f_min(b, ty, abs_a, abs_b, min);
602592
let bool = Elem::Bool.id(b);
603-
let is_zero = b.f_ord_equal(bool, None, max, zero).unwrap();
604-
let scale = b.id();
605-
b.select(ty, Some(scale), is_zero, one, max).unwrap();
606-
let a_scaled = b.id();
607-
b.f_div(ty, Some(a_scaled), abs_a, scale).unwrap();
608-
let b_scaled = b.id();
609-
b.f_div(ty, Some(b_scaled), abs_b, scale).unwrap();
610-
let a_scale_squared = b.id();
611-
b.f_mul(ty, Some(a_scale_squared), a_scaled, a_scaled)
612-
.unwrap();
613-
let b_scale_squared = b.id();
614-
b.f_mul(ty, Some(b_scale_squared), b_scaled, b_scaled)
615-
.unwrap();
616-
let sum = b.id();
617-
b.f_add(ty, Some(sum), a_scale_squared, b_scale_squared)
618-
.unwrap();
619-
let rsquare_root = b.id();
620-
T::inverse_sqrt(b, ty, sum, rsquare_root);
593+
let is_max_zero = b.f_ord_equal(bool, None, max, zero).unwrap();
594+
let max_safe = b.id();
595+
b.select(ty, Some(max_safe), is_max_zero, one, max).unwrap();
596+
let t = b.id();
597+
b.f_div(ty, Some(t), min, max_safe).unwrap();
598+
let t_fma = b.gl_fma(ty, t, t, one).unwrap();
599+
let inverse_square_root = b.id();
600+
T::inverse_sqrt(b, ty, t_fma, inverse_square_root);
621601
let ids = [
622602
abs_a,
623603
abs_b,
624604
max,
625-
is_zero,
626-
scale,
627-
a_scaled,
628-
b_scaled,
629-
a_scale_squared,
630-
b_scale_squared,
631-
sum,
632-
rsquare_root,
605+
is_max_zero,
606+
max_safe,
607+
t_fma,
608+
inverse_square_root,
633609
out,
634610
];
635611
for id in ids {
@@ -638,7 +614,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
638614
b.decorate(id, Decoration::RelaxedPrecision, []);
639615
}
640616
}
641-
b.f_div(ty, Some(out), rsquare_root, scale).unwrap();
617+
b.f_div(ty, Some(out), inverse_square_root, max).unwrap();
642618
})
643619
}
644620
Arithmetic::Sqrt(op) => {

crates/cubecl-wgpu/src/compiler/wgsl/extension.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,13 +352,13 @@ fn format_hypot_primitive(
352352
f,
353353
"
354354
fn {function_name}(lhs: {elem}, rhs: {elem}) -> {elem} {{
355-
if (lhs == 0.0) {{ return abs(rhs); }}
356-
if (rhs == 0.0) {{ return abs(lhs); }}
357355
let a = abs(lhs);
358356
let b = abs(rhs);
359357
let max_val = max(a, b);
358+
var max_val_safe = max_val;
359+
if (max_val == 0.0) {{ max_val_safe = 1.0; }}
360360
let min_val = min(a, b);
361-
let t = min_val / max_val;
361+
let t = min_val / max_val_safe;
362362
363363
return max_val * sqrt(fma(t, t, 1.0));
364364
}}
@@ -378,10 +378,11 @@ fn format_rhypot_primitive(
378378
fn {function_name}(lhs: {elem}, rhs: {elem}) -> {elem} {{
379379
let a = abs(lhs);
380380
let b = abs(rhs);
381-
if (a == 0.0 && b == 0.0) {{ return bitcast<f32>(0x7F800000u); }}
382381
let max_val = max(a, b);
382+
var max_val_safe = max_val;
383+
if (max_val == 0.0) {{ max_val_safe = 1.0; }}
383384
let min_val = min(a, b);
384-
let t = min_val / max_val;
385+
let t = min_val / max_val_safe;
385386
386387
return inverseSqrt(fma(t, t, 1.0)) / max_val;
387388
}}

0 commit comments

Comments
 (0)