Skip to content

Commit 97d25c8

Browse files
committed
refactor: Deduplicate pow implementations
`strict_pow` can be implemented in terms of `checked_pow`, and `wrapping_pow` can be implemented in terms of `overflowing_pow`.
1 parent 20f67bc commit 97d25c8

File tree

2 files changed

+37
-82
lines changed

2 files changed

+37
-82
lines changed

library/core/src/num/overflow_panic.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ pub(super) const fn rem() -> ! {
3232
panic!("attempt to calculate the remainder with overflow")
3333
}
3434

35+
#[cold]
36+
#[track_caller]
37+
pub(super) const fn pow() -> ! {
38+
panic!("attempt to calculate the power with overflow")
39+
}
40+
3541
#[cold]
3642
#[track_caller]
3743
pub(super) const fn neg() -> ! {

library/core/src/num/uint_macros.rs

Lines changed: 31 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,23 +2117,10 @@ macro_rules! uint_impl {
21172117
without modifying the original"]
21182118
#[inline]
21192119
#[track_caller]
2120-
pub const fn strict_pow(self, mut exp: u32) -> Self {
2121-
if exp == 0 {
2122-
return 1;
2123-
}
2124-
let mut base = self;
2125-
let mut acc: Self = 1;
2126-
2127-
loop {
2128-
if (exp & 1) == 1 {
2129-
acc = acc.strict_mul(base);
2130-
// since exp!=0, finally the exp must be 1.
2131-
if exp == 1 {
2132-
return acc;
2133-
}
2134-
}
2135-
exp /= 2;
2136-
base = base.strict_mul(base);
2120+
pub const fn strict_pow(self, exp: u32) -> Self {
2121+
match self.checked_pow(exp) {
2122+
None => overflow_panic::pow(),
2123+
Some(a) => a,
21372124
}
21382125
}
21392126

@@ -2593,43 +2580,9 @@ macro_rules! uint_impl {
25932580
#[must_use = "this returns the result of the operation, \
25942581
without modifying the original"]
25952582
#[inline]
2596-
pub const fn wrapping_pow(self, mut exp: u32) -> Self {
2597-
if exp == 0 {
2598-
return 1;
2599-
}
2600-
let mut base = self;
2601-
let mut acc: Self = 1;
2602-
2603-
if intrinsics::is_val_statically_known(exp) {
2604-
while exp > 1 {
2605-
if (exp & 1) == 1 {
2606-
acc = acc.wrapping_mul(base);
2607-
}
2608-
exp /= 2;
2609-
base = base.wrapping_mul(base);
2610-
}
2611-
2612-
// since exp!=0, finally the exp must be 1.
2613-
// Deal with the final bit of the exponent separately, since
2614-
// squaring the base afterwards is not necessary.
2615-
acc.wrapping_mul(base)
2616-
} else {
2617-
// This is faster than the above when the exponent is not known
2618-
// at compile time. We can't use the same code for the constant
2619-
// exponent case because LLVM is currently unable to unroll
2620-
// this loop.
2621-
loop {
2622-
if (exp & 1) == 1 {
2623-
acc = acc.wrapping_mul(base);
2624-
// since exp!=0, finally the exp must be 1.
2625-
if exp == 1 {
2626-
return acc;
2627-
}
2628-
}
2629-
exp /= 2;
2630-
base = base.wrapping_mul(base);
2631-
}
2632-
}
2583+
pub const fn wrapping_pow(self, exp: u32) -> Self {
2584+
let (a, _) = self.overflowing_pow(exp);
2585+
a
26332586
}
26342587

26352588
/// Calculates `self` + `rhs`.
@@ -3269,30 +3222,26 @@ macro_rules! uint_impl {
32693222
without modifying the original"]
32703223
#[inline]
32713224
pub const fn overflowing_pow(self, mut exp: u32) -> (Self, bool) {
3272-
if exp == 0{
3273-
return (1,false);
3225+
if exp == 0 {
3226+
return (1, false);
32743227
}
32753228
let mut base = self;
32763229
let mut acc: Self = 1;
3277-
let mut overflown = false;
3278-
// Scratch space for storing results of overflowing_mul.
3279-
let mut r;
3230+
let mut overflow = false;
3231+
let mut tmp_overflow;
32803232

32813233
loop {
32823234
if (exp & 1) == 1 {
3283-
r = acc.overflowing_mul(base);
3235+
(acc, tmp_overflow) = acc.overflowing_mul(base);
3236+
overflow |= tmp_overflow;
32843237
// since exp!=0, finally the exp must be 1.
32853238
if exp == 1 {
3286-
r.1 |= overflown;
3287-
return r;
3239+
return (acc, overflow);
32883240
}
3289-
acc = r.0;
3290-
overflown |= r.1;
32913241
}
32923242
exp /= 2;
3293-
r = base.overflowing_mul(base);
3294-
base = r.0;
3295-
overflown |= r.1;
3243+
(base, tmp_overflow) = base.overflowing_mul(base);
3244+
overflow |= tmp_overflow;
32963245
}
32973246
}
32983247

@@ -3329,23 +3278,23 @@ macro_rules! uint_impl {
33293278
// Deal with the final bit of the exponent separately, since
33303279
// squaring the base afterwards is not necessary and may cause a
33313280
// needless overflow.
3332-
acc * base
3333-
} else {
3334-
// This is faster than the above when the exponent is not known
3335-
// at compile time. We can't use the same code for the constant
3336-
// exponent case because LLVM is currently unable to unroll
3337-
// this loop.
3338-
loop {
3339-
if (exp & 1) == 1 {
3340-
acc = acc * base;
3341-
// since exp!=0, finally the exp must be 1.
3342-
if exp == 1 {
3343-
return acc;
3344-
}
3281+
return acc * base;
3282+
}
3283+
3284+
// This is faster than the above when the exponent is not known
3285+
// at compile time. We can't use the same code for the constant
3286+
// exponent case because LLVM is currently unable to unroll
3287+
// this loop.
3288+
loop {
3289+
if (exp & 1) == 1 {
3290+
acc = acc * base;
3291+
// since exp!=0, finally the exp must be 1.
3292+
if exp == 1 {
3293+
return acc;
33453294
}
3346-
exp /= 2;
3347-
base = base * base;
33483295
}
3296+
exp /= 2;
3297+
base = base * base;
33493298
}
33503299
}
33513300

0 commit comments

Comments
 (0)