Skip to content

Commit c2f36c5

Browse files
committed
optimize: pow when base is a power of two
if base == 2 ** k, then (2 ** k) ** n == 2 ** (k * n) == 1 << (k * n)
1 parent 759e81a commit c2f36c5

File tree

5 files changed

+137
-48
lines changed

5 files changed

+137
-48
lines changed

library/core/src/num/uint_macros.rs

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,12 +2072,20 @@ macro_rules! uint_impl {
20722072
without modifying the original"]
20732073
#[inline]
20742074
pub const fn checked_pow(self, mut exp: u32) -> Option<Self> {
2075-
if exp == 0 {
2076-
return Some(1);
2077-
}
20782075
let mut base = self;
20792076
let mut acc: Self = 1;
20802077

2078+
if intrinsics::is_val_statically_known(base) && base.is_power_of_two() {
2079+
// change of base:
2080+
// if base == 2 ** k, then
2081+
// (2 ** k) ** n
2082+
// == 2 ** (k * n)
2083+
// == 1 << (k * n)
2084+
let k = base.ilog2();
2085+
let shift = try_opt!(k.checked_mul(exp));
2086+
return (1 as Self).checked_shl(shift);
2087+
}
2088+
20812089
if intrinsics::is_val_statically_known(exp) {
20822090
while exp > 1 {
20832091
if (exp & 1) == 1 {
@@ -2094,6 +2102,10 @@ macro_rules! uint_impl {
20942102
return acc.checked_mul(base);
20952103
}
20962104

2105+
if exp == 0 {
2106+
return Some(1);
2107+
}
2108+
20972109
loop {
20982110
if (exp & 1) == 1 {
20992111
acc = try_opt!(acc.checked_mul(base));
@@ -3238,14 +3250,24 @@ macro_rules! uint_impl {
32383250
without modifying the original"]
32393251
#[inline]
32403252
pub const fn overflowing_pow(self, mut exp: u32) -> (Self, bool) {
3241-
if exp == 0 {
3242-
return (1, false);
3243-
}
32443253
let mut base = self;
32453254
let mut acc: Self = 1;
32463255
let mut overflow = false;
32473256
let mut tmp_overflow;
32483257

3258+
if intrinsics::is_val_statically_known(base) && base.is_power_of_two() {
3259+
// change of base:
3260+
// if base == 2 ** k, then
3261+
// (2 ** k) ** n
3262+
// == 2 ** (k * n)
3263+
// == 1 << (k * n)
3264+
let k = base.ilog2();
3265+
let Some(shift) = k.checked_mul(exp) else {
3266+
return (0, true)
3267+
};
3268+
return ((1 as Self).unbounded_shl(shift), shift >= Self::BITS)
3269+
}
3270+
32493271
if intrinsics::is_val_statically_known(exp) {
32503272
while exp > 1 {
32513273
if (exp & 1) == 1 {
@@ -3266,6 +3288,10 @@ macro_rules! uint_impl {
32663288
return (acc, overflow);
32673289
}
32683290

3291+
if exp == 0 {
3292+
return (1, false);
3293+
}
3294+
32693295
loop {
32703296
if (exp & 1) == 1 {
32713297
(acc, tmp_overflow) = acc.overflowing_mul(base);
@@ -3295,12 +3321,25 @@ macro_rules! uint_impl {
32953321
#[inline]
32963322
#[rustc_inherit_overflow_checks]
32973323
pub const fn pow(self, mut exp: u32) -> Self {
3298-
if exp == 0 {
3299-
return 1;
3300-
}
33013324
let mut base = self;
33023325
let mut acc = 1;
33033326

3327+
if intrinsics::is_val_statically_known(base) && base.is_power_of_two() {
3328+
// change of base:
3329+
// if base == 2 ** k, then
3330+
// (2 ** k) ** n
3331+
// == 2 ** (k * n)
3332+
// == 1 << (k * n)
3333+
let k = base.ilog2();
3334+
// Panic on overflow if `-C overflow-checks` is enabled.
3335+
// Otherwise will be optimized out
3336+
let _overflow_check = (1 as Self) << (k * exp);
3337+
let Some(shift) = k.checked_mul(exp) else {
3338+
return 0
3339+
};
3340+
return (1 as Self).unbounded_shl(shift)
3341+
}
3342+
33043343
if intrinsics::is_val_statically_known(exp) {
33053344
while exp > 1 {
33063345
if (exp & 1) == 1 {
@@ -3317,6 +3356,10 @@ macro_rules! uint_impl {
33173356
return acc * base;
33183357
}
33193358

3359+
if exp == 0 {
3360+
return 1;
3361+
}
3362+
33203363
// This is faster than the above when the exponent is not known
33213364
// at compile time. We can't use the same code for the constant
33223365
// exponent case because LLVM is currently unable to unroll

library/coretests/tests/num/uint_macros.rs

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,19 @@ macro_rules! uint_module {
271271
assert_eq!(from_str::<$T>("x"), None);
272272
}
273273

274-
// Not const because `catch_unwind` is not const yet.
274+
// Not const because overflow always panics during const evaluation, and
275+
// `catch_unwind` is not const yet.
275276
#[test]
276-
fn strict_pow() {
277+
fn panicking_pow() {
277278
use std::panic::catch_unwind;
278279

279280
{
280281
const R: $T = 0;
282+
assert_eq!(R.pow(0), 1 as $T);
283+
assert_eq!(R.pow(1), 0 as $T);
284+
assert_eq!(R.pow(2), 0 as $T);
285+
assert_eq!(R.pow(128), 0 as $T);
286+
281287
assert_eq!(R.strict_pow(0), 1 as $T);
282288
assert_eq!(R.strict_pow(1), 0 as $T);
283289
assert_eq!(R.strict_pow(2), 0 as $T);
@@ -286,6 +292,11 @@ macro_rules! uint_module {
286292

287293
{
288294
const R: $T = 1;
295+
assert_eq!(R.pow(0), 1 as $T);
296+
assert_eq!(R.pow(1), 1 as $T);
297+
assert_eq!(R.pow(2), 1 as $T);
298+
assert_eq!(R.pow(128), 1 as $T);
299+
289300
assert_eq!(R.strict_pow(0), 1 as $T);
290301
assert_eq!(R.strict_pow(1), 1 as $T);
291302
assert_eq!(R.strict_pow(2), 1 as $T);
@@ -294,48 +305,24 @@ macro_rules! uint_module {
294305

295306
{
296307
const R: $T = 2;
308+
assert_eq!(R.pow(0), 1 as $T);
309+
assert_eq!(R.pow(1), 2 as $T);
310+
assert_eq!(R.pow(2), 4 as $T);
311+
assert_eq!(R.pow(128), 0 as $T);
312+
assert_eq!(R.pow(129), 0 as $T);
313+
297314
assert_eq!(R.strict_pow(0), 1 as $T);
298315
assert_eq!(R.strict_pow(1), 2 as $T);
299316
assert_eq!(R.strict_pow(2), 4 as $T);
300317
assert!(catch_unwind(|| R.strict_pow(128)).is_err());
301318
}
302319

320+
// If `k * exp` wraps around into a valid shift amount, `pow` should still return 0/panic
303321
{
304-
const R: $T = $T::MAX;
305-
assert_eq!(R.strict_pow(0), 1 as $T);
306-
assert_eq!(R.strict_pow(1), R as $T);
307-
assert!(catch_unwind(|| R.strict_pow(2)).is_err());
308-
assert!(catch_unwind(|| R.strict_pow(128)).is_err());
309-
}
310-
}
311-
312-
// Not const because overflow always panics during const evaluation, and
313-
// `catch_unwind` is not const yet.
314-
#[test]
315-
fn pow() {
316-
{
317-
const R: $T = 0;
318-
assert_eq!(R.pow(0), 1 as $T);
319-
assert_eq!(R.pow(1), 0 as $T);
320-
assert_eq!(R.pow(2), 0 as $T);
321-
assert_eq!(R.pow(128), 0 as $T);
322-
}
323-
324-
{
325-
const R: $T = 1;
326-
assert_eq!(R.pow(0), 1 as $T);
327-
assert_eq!(R.pow(1), 1 as $T);
328-
assert_eq!(R.pow(2), 1 as $T);
329-
assert_eq!(R.pow(128), 1 as $T);
330-
}
331-
332-
{
333-
const R: $T = 2;
334-
assert_eq!(R.pow(0), 1 as $T);
335-
assert_eq!(R.pow(1), 2 as $T);
336-
assert_eq!(R.pow(2), 4 as $T);
337-
assert_eq!(R.pow(128), 0 as $T);
338-
assert_eq!(R.pow(129), 0 as $T);
322+
const R: $T = 4;
323+
const HALF: u32 = u32::MAX / 2 + 1;
324+
assert_eq!(R.pow(HALF), 0 as $T);
325+
assert!(catch_unwind(|| R.strict_pow(HALF)).is_err());
339326
}
340327

341328
{
@@ -345,6 +332,11 @@ macro_rules! uint_module {
345332
assert_eq!(R.pow(2), 1 as $T);
346333
assert_eq!(R.pow(128), 1 as $T);
347334
assert_eq!(R.pow(129), R as $T);
335+
336+
assert_eq!(R.strict_pow(0), 1 as $T);
337+
assert_eq!(R.strict_pow(1), R as $T);
338+
assert!(catch_unwind(|| R.strict_pow(2)).is_err());
339+
assert!(catch_unwind(|| R.strict_pow(128)).is_err());
348340
}
349341
}
350342

@@ -443,6 +435,17 @@ macro_rules! uint_module {
443435
assert_eq_const_safe!($T: R.saturating_pow(129), $T::MAX);
444436
}
445437

438+
// overflow in the shift caclculation should result in the final
439+
// result being 0 rather than accidentally succeeding due to a
440+
// shift within the word size
441+
// ie `4 ** 0x8000_0000` should give 0 rather than 1 << 0
442+
{
443+
const R: $T = 4;
444+
const HALF: u32 = u32::MAX / 2 + 1;
445+
assert_eq_const_safe!($T: R.wrapping_pow(HALF), 0 as $T);
446+
assert_eq_const_safe!(($T, bool): R.overflowing_pow(HALF), (0 as $T, true));
447+
}
448+
446449
{
447450
const R: $T = $T::MAX;
448451
assert_eq_const_safe!($T: R.wrapping_pow(0), 1 as $T);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//@ compile-flags: -Copt-level=3
2+
// Test that `pow` can use a faster implementation when `base` is a
3+
// known power of two
4+
5+
#![crate_type = "lib"]
6+
7+
// CHECK-LABEL: @pow2
8+
#[no_mangle]
9+
pub fn pow2(exp: u32) -> u32 {
10+
// CHECK: %[[OVERFLOW:.+]] = icmp ugt i32 %exp, 31
11+
// CHECK: %[[POW:.+]] = shl nuw i32 1, %exp
12+
// CHECK: %[[RET:.+]] = select i1 %[[OVERFLOW]], i32 0, i32 %[[POW]]
13+
// CHECK: ret i32 %[[RET]]
14+
2u32.pow(exp)
15+
}
16+
17+
// 4 ** n == 2 ** (2 * n) == 1 << (2 * n)
18+
// CHECK-LABEL: @pow4
19+
#[no_mangle]
20+
pub fn pow4(exp: u32) -> u32 {
21+
// CHECK: %[[SHIFT_AMOUNT:.+]] = shl i32 %exp, 1
22+
// CHECK: %[[ICMP1:.+]] = icmp slt i32 %exp, 0
23+
// CHECK: %[[ICMP2:.+]] = icmp ugt i32 %[[SHIFT_AMOUNT]], 31
24+
// CHECK: %[[OVERFLOW:.+]] = or i1 %[[ICMP1]], %[[ICMP2]]
25+
// CHECK: %[[POW:.+]] = shl nuw i32 1, %[[SHIFT_AMOUNT]]
26+
// CHECK: %[[RET:.+]] = select i1 %[[OVERFLOW]], i32 0, i32 %[[POW]]
27+
// CHECK: ret i32 %[[RET]]
28+
4u32.pow(exp)
29+
}
30+
31+
// 16 ** n == 2 ** (4 * n) == 1 << (4 * n)
32+
// CHECK-LABEL: @pow16
33+
#[no_mangle]
34+
pub fn pow16(exp: u32) -> u32 {
35+
// CHECK: %[[SHIFT_AMOUNT:.+]] = shl i32 %exp, 2
36+
// CHECK: %[[ICMP1:.+]] = icmp ugt i32 %exp, 1073741823
37+
// CHECK: %[[ICMP2:.+]] = icmp ugt i32 %[[SHIFT_AMOUNT]], 31
38+
// CHECK: %[[OVERFLOW:.+]] = or i1 %[[ICMP1]], %[[ICMP2]]
39+
// CHECK: %[[POW:.+]] = shl nuw i32 1, %[[SHIFT_AMOUNT]]
40+
// CHECK: %[[RET:.+]] = select i1 %[[OVERFLOW]], i32 0, i32 %[[POW]]
41+
// CHECK: ret i32 %[[RET]]
42+
16u32.pow(exp)
43+
}

tests/ui/numbers-arithmetic/overflowing-pow-signed.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//@ run-fail
22
//@ regex-error-pattern: thread 'main'.*panicked
3-
//@ error-pattern: attempt to multiply with overflow
3+
//@ regex-error-pattern: attempt to (multiply|shift left) with overflow
44
//@ needs-subprocess
55
//@ compile-flags: -C debug-assertions
66

tests/ui/numbers-arithmetic/overflowing-pow-unsigned.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//@ run-fail
22
//@ regex-error-pattern: thread 'main'.*panicked
3-
//@ error-pattern: attempt to multiply with overflow
3+
//@ regex-error-pattern: attempt to (multiply|shift left) with overflow
44
//@ needs-subprocess
55
//@ compile-flags: -C debug-assertions
66

0 commit comments

Comments
 (0)