Skip to content

Commit e8a603e

Browse files
committed
optimize: pow when base is a power of two
if base == 2 ** k, then (2 ** k) ** n == 2 ** (k * n) == 1 << (k * n)
1 parent 759e81a commit e8a603e

File tree

5 files changed

+129
-42
lines changed

5 files changed

+129
-42
lines changed

library/core/src/num/uint_macros.rs

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2078,6 +2078,17 @@ macro_rules! uint_impl {
20782078
let mut base = self;
20792079
let mut acc: Self = 1;
20802080

2081+
if intrinsics::is_val_statically_known(base) && base.is_power_of_two() {
2082+
// change of base:
2083+
// if base == 2 ** k, then
2084+
// (2 ** k) ** n
2085+
// == 2 ** (k * n)
2086+
// == 1 << (k * n)
2087+
let k = base.ilog2();
2088+
let shift = try_opt!(k.checked_mul(exp));
2089+
return (1 as Self).checked_shl(shift);
2090+
}
2091+
20812092
if intrinsics::is_val_statically_known(exp) {
20822093
while exp > 1 {
20832094
if (exp & 1) == 1 {
@@ -3246,6 +3257,19 @@ macro_rules! uint_impl {
32463257
let mut overflow = false;
32473258
let mut tmp_overflow;
32483259

3260+
if intrinsics::is_val_statically_known(base) && base.is_power_of_two() {
3261+
// change of base:
3262+
// if base == 2 ** k, then
3263+
// (2 ** k) ** n
3264+
// == 2 ** (k * n)
3265+
// == 1 << (k * n)
3266+
let k = base.ilog2();
3267+
let Some(shift) = k.checked_mul(exp) else {
3268+
return (0, true)
3269+
};
3270+
return ((1 as Self).unbounded_shl(shift), shift >= Self::BITS)
3271+
}
3272+
32493273
if intrinsics::is_val_statically_known(exp) {
32503274
while exp > 1 {
32513275
if (exp & 1) == 1 {
@@ -3295,12 +3319,25 @@ macro_rules! uint_impl {
32953319
#[inline]
32963320
#[rustc_inherit_overflow_checks]
32973321
pub const fn pow(self, mut exp: u32) -> Self {
3298-
if exp == 0 {
3299-
return 1;
3300-
}
33013322
let mut base = self;
33023323
let mut acc = 1;
33033324

3325+
if intrinsics::is_val_statically_known(base) && base.is_power_of_two() {
3326+
// change of base:
3327+
// if base == 2 ** k, then
3328+
// (2 ** k) ** n
3329+
// == 2 ** (k * n)
3330+
// == 1 << (k * n)
3331+
let k = base.ilog2();
3332+
// Panic on overflow if `-C overflow-checks` is enabled.
3333+
// Otherwise will be optimized out
3334+
let _overflow_check = (1 as Self) << (k * exp);
3335+
let Some(shift) = k.checked_mul(exp) else {
3336+
return 0
3337+
};
3338+
return (1 as Self).unbounded_shl(shift)
3339+
}
3340+
33043341
if intrinsics::is_val_statically_known(exp) {
33053342
while exp > 1 {
33063343
if (exp & 1) == 1 {
@@ -3317,6 +3354,10 @@ macro_rules! uint_impl {
33173354
return acc * base;
33183355
}
33193356

3357+
if exp == 0 {
3358+
return 1;
3359+
}
3360+
33203361
// This is faster than the above when the exponent is not known
33213362
// at compile time. We can't use the same code for the constant
33223363
// exponent case because LLVM is currently unable to unroll

library/coretests/tests/num/uint_macros.rs

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,19 @@ macro_rules! uint_module {
271271
assert_eq!(from_str::<$T>("x"), None);
272272
}
273273

274-
// Not const because `catch_unwind` is not const yet.
274+
// Not const because overflow always panics during const evaluation, and
275+
// `catch_unwind` is not const yet.
275276
#[test]
276-
fn strict_pow() {
277+
fn panicking_pow() {
277278
use std::panic::catch_unwind;
278279

279280
{
280281
const R: $T = 0;
282+
assert_eq!(R.pow(0), 1 as $T);
283+
assert_eq!(R.pow(1), 0 as $T);
284+
assert_eq!(R.pow(2), 0 as $T);
285+
assert_eq!(R.pow(128), 0 as $T);
286+
281287
assert_eq!(R.strict_pow(0), 1 as $T);
282288
assert_eq!(R.strict_pow(1), 0 as $T);
283289
assert_eq!(R.strict_pow(2), 0 as $T);
@@ -286,6 +292,11 @@ macro_rules! uint_module {
286292

287293
{
288294
const R: $T = 1;
295+
assert_eq!(R.pow(0), 1 as $T);
296+
assert_eq!(R.pow(1), 1 as $T);
297+
assert_eq!(R.pow(2), 1 as $T);
298+
assert_eq!(R.pow(128), 1 as $T);
299+
289300
assert_eq!(R.strict_pow(0), 1 as $T);
290301
assert_eq!(R.strict_pow(1), 1 as $T);
291302
assert_eq!(R.strict_pow(2), 1 as $T);
@@ -294,48 +305,24 @@ macro_rules! uint_module {
294305

295306
{
296307
const R: $T = 2;
308+
assert_eq!(R.pow(0), 1 as $T);
309+
assert_eq!(R.pow(1), 2 as $T);
310+
assert_eq!(R.pow(2), 4 as $T);
311+
assert_eq!(R.pow(128), 0 as $T);
312+
assert_eq!(R.pow(129), 0 as $T);
313+
297314
assert_eq!(R.strict_pow(0), 1 as $T);
298315
assert_eq!(R.strict_pow(1), 2 as $T);
299316
assert_eq!(R.strict_pow(2), 4 as $T);
300317
assert!(catch_unwind(|| R.strict_pow(128)).is_err());
301318
}
302319

320+
// If `k * exp` wraps around into a valid shift amount, `pow` should still return 0 and `strict_pow` should still panic
303321
{
304-
const R: $T = $T::MAX;
305-
assert_eq!(R.strict_pow(0), 1 as $T);
306-
assert_eq!(R.strict_pow(1), R as $T);
307-
assert!(catch_unwind(|| R.strict_pow(2)).is_err());
308-
assert!(catch_unwind(|| R.strict_pow(128)).is_err());
309-
}
310-
}
311-
312-
// Not const because overflow always panics during const evaluation, and
313-
// `catch_unwind` is not const yet.
314-
#[test]
315-
fn pow() {
316-
{
317-
const R: $T = 0;
318-
assert_eq!(R.pow(0), 1 as $T);
319-
assert_eq!(R.pow(1), 0 as $T);
320-
assert_eq!(R.pow(2), 0 as $T);
321-
assert_eq!(R.pow(128), 0 as $T);
322-
}
323-
324-
{
325-
const R: $T = 1;
326-
assert_eq!(R.pow(0), 1 as $T);
327-
assert_eq!(R.pow(1), 1 as $T);
328-
assert_eq!(R.pow(2), 1 as $T);
329-
assert_eq!(R.pow(128), 1 as $T);
330-
}
331-
332-
{
333-
const R: $T = 2;
334-
assert_eq!(R.pow(0), 1 as $T);
335-
assert_eq!(R.pow(1), 2 as $T);
336-
assert_eq!(R.pow(2), 4 as $T);
337-
assert_eq!(R.pow(128), 0 as $T);
338-
assert_eq!(R.pow(129), 0 as $T);
322+
const R: $T = 4;
323+
const HALF: u32 = u32::MAX / 2 + 1;
324+
assert_eq!(R.pow(HALF), 0 as $T);
325+
assert!(catch_unwind(|| R.strict_pow(HALF)).is_err());
339326
}
340327

341328
{
@@ -345,6 +332,11 @@ macro_rules! uint_module {
345332
assert_eq!(R.pow(2), 1 as $T);
346333
assert_eq!(R.pow(128), 1 as $T);
347334
assert_eq!(R.pow(129), R as $T);
335+
336+
assert_eq!(R.strict_pow(0), 1 as $T);
337+
assert_eq!(R.strict_pow(1), R as $T);
338+
assert!(catch_unwind(|| R.strict_pow(2)).is_err());
339+
assert!(catch_unwind(|| R.strict_pow(128)).is_err());
348340
}
349341
}
350342

@@ -443,6 +435,17 @@ macro_rules! uint_module {
443435
assert_eq_const_safe!($T: R.saturating_pow(129), $T::MAX);
444436
}
445437

438+
// overflow in the shift calculation should result in the final
439+
// result being 0 rather than accidentally succeeding due to a
440+
// shift within the word size
441+
// i.e. `4 ** 0x8000_0000` should give 0 rather than 1 << 0
442+
{
443+
const R: $T = 4;
444+
const HALF: u32 = u32::MAX / 2 + 1;
445+
assert_eq_const_safe!($T: R.wrapping_pow(HALF), 0 as $T);
446+
assert_eq_const_safe!(($T, bool): R.overflowing_pow(HALF), (0 as $T, true));
447+
}
448+
446449
{
447450
const R: $T = $T::MAX;
448451
assert_eq_const_safe!($T: R.wrapping_pow(0), 1 as $T);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//@ compile-flags: -Copt-level=3
2+
// Test that `pow` can use a faster implementation when `base` is a
3+
// known power of two
4+
5+
#![crate_type = "lib"]
6+
7+
// CHECK-LABEL: @pow2
8+
#[no_mangle]
9+
pub fn pow2(exp: u32) -> u32 {
10+
// CHECK: %[[OVERFLOW:.+]] = icmp ugt i32 %exp, 31
11+
// CHECK: %[[POW:.+]] = shl nuw i32 1, %exp
12+
// CHECK: %[[RET:.+]] = select i1 %[[OVERFLOW]], i32 0, i32 %[[POW]]
13+
// CHECK: ret i32 %[[RET]]
14+
2u32.pow(exp)
15+
}
16+
17+
// 4 ** n == 2 ** (2 * n) == 1 << (2 * n)
18+
// CHECK-LABEL: @pow4
19+
#[no_mangle]
20+
pub fn pow4(exp: u32) -> u32 {
21+
// CHECK: %[[SHIFT_AMOUNT:.+]] = shl i32 %exp, 1
22+
// CHECK: %[[ICMP1:.+]] = icmp slt i32 %exp, 0
23+
// CHECK: %[[ICMP2:.+]] = icmp ugt i32 %[[SHIFT_AMOUNT]], 31
24+
// CHECK: %[[OVERFLOW:.+]] = or i1 %[[ICMP1]], %[[ICMP2]]
25+
// CHECK: %[[POW:.+]] = shl nuw i32 1, %[[SHIFT_AMOUNT]]
26+
// CHECK: %[[RET:.+]] = select i1 %[[OVERFLOW]], i32 0, i32 %[[POW]]
27+
// CHECK: ret i32 %[[RET]]
28+
4u32.pow(exp)
29+
}
30+
31+
// 16 ** n == 2 ** (4 * n) == 1 << (4 * n)
32+
// CHECK-LABEL: @pow16
33+
#[no_mangle]
34+
pub fn pow16(exp: u32) -> u32 {
35+
// CHECK: %[[SHIFT_AMOUNT:.+]] = shl i32 %exp, 2
36+
// CHECK: %[[ICMP1:.+]] = icmp ugt i32 %exp, 1073741823
37+
// CHECK: %[[ICMP2:.+]] = icmp ugt i32 %[[SHIFT_AMOUNT]], 31
38+
// CHECK: %[[OVERFLOW:.+]] = or i1 %[[ICMP1]], %[[ICMP2]]
39+
// CHECK: %[[POW:.+]] = shl nuw i32 1, %[[SHIFT_AMOUNT]]
40+
// CHECK: %[[RET:.+]] = select i1 %[[OVERFLOW]], i32 0, i32 %[[POW]]
41+
// CHECK: ret i32 %[[RET]]
42+
16u32.pow(exp)
43+
}

tests/ui/numbers-arithmetic/overflowing-pow-signed.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//@ run-fail
22
//@ regex-error-pattern: thread 'main'.*panicked
3-
//@ error-pattern: attempt to multiply with overflow
3+
//@ regex-error-pattern: attempt to (multiply|shift left) with overflow
44
//@ needs-subprocess
55
//@ compile-flags: -C debug-assertions
66

tests/ui/numbers-arithmetic/overflowing-pow-unsigned.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//@ run-fail
22
//@ regex-error-pattern: thread 'main'.*panicked
3-
//@ error-pattern: attempt to multiply with overflow
3+
//@ regex-error-pattern: attempt to (multiply|shift left) with overflow
44
//@ needs-subprocess
55
//@ compile-flags: -C debug-assertions
66

0 commit comments

Comments
 (0)