Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit 8f325dc

Browse files
authored
math: Improve sqrt using bit-wise operations (#1562)
* math: Improve sqrt guess using bit-wise operations * Run fmt and bump up instruction for failed test * Bump up compute cost from CI failure * Update CI version of toolchain * Address feedback
1 parent 0e2b080 commit 8f325dc

File tree

3 files changed

+33
-27
lines changed

3 files changed

+33
-27
lines changed

ci/solana-version.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
if [[ -n $SOLANA_VERSION ]]; then
1515
solana_version="$SOLANA_VERSION"
1616
else
17-
solana_version=v1.5.15
17+
solana_version=v1.6.2
1818
fi
1919

2020
export solana_version="$solana_version"

libraries/math/src/approximations.rs

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,42 @@
11
//! Approximation calculations
22
33
use {
4-
num_traits::{CheckedAdd, CheckedDiv, One, Zero},
5-
std::cmp::Eq,
4+
num_traits::{CheckedShl, CheckedShr, PrimInt},
5+
std::cmp::Ordering,
66
};
77

8-
const SQRT_ITERATIONS: u8 = 50;
9-
10-
/// Perform square root
11-
pub fn sqrt<T: CheckedAdd + CheckedDiv + One + Zero + Eq + Copy>(radicand: T) -> Option<T> {
12-
if radicand == T::zero() {
13-
return Some(T::zero());
8+
/// Calculate square root of the given number
9+
///
10+
/// Code lovingly adapted from the excellent work at:
11+
/// https://github.com/derekdreery/integer-sqrt-rs
12+
///
13+
/// The algorithm is based on the implementation in:
14+
/// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)
15+
pub fn sqrt<T: PrimInt + CheckedShl + CheckedShr>(radicand: T) -> Option<T> {
16+
match radicand.cmp(&T::zero()) {
17+
Ordering::Less => return None, // fail for less than 0
18+
Ordering::Equal => return Some(T::zero()), // do nothing for 0
19+
_ => {}
1420
}
15-
// A good initial guess is the average of the interval that contains the
16-
// input number. For all numbers, that will be between 1 and the given number.
17-
let one = T::one();
18-
let two = one.checked_add(&one)?;
19-
let mut guess = radicand.checked_div(&two)?.checked_add(&one)?;
20-
let mut last_guess = guess;
21-
for _ in 0..SQRT_ITERATIONS {
22-
// x_k+1 = (x_k + radicand / x_k) / 2
23-
guess = last_guess
24-
.checked_add(&radicand.checked_div(&last_guess)?)?
25-
.checked_div(&two)?;
26-
if last_guess == guess {
27-
break;
21+
22+
// Compute bit, the largest power of 4 <= n
23+
let max_shift: u32 = T::zero().leading_zeros() - 1;
24+
let shift: u32 = (max_shift - radicand.leading_zeros()) & !1;
25+
let mut bit = T::one().checked_shl(shift)?;
26+
27+
let mut n = radicand;
28+
let mut result = T::zero();
29+
while bit != T::zero() {
30+
let result_with_bit = result.checked_add(&bit)?;
31+
if n >= result_with_bit {
32+
n = n.checked_sub(&result_with_bit)?;
33+
result = result.checked_shr(1)?.checked_add(&bit)?;
2834
} else {
29-
last_guess = guess;
35+
result = result.checked_shr(1)?;
3036
}
37+
bit = bit.checked_shr(2)?;
3138
}
32-
Some(guess)
39+
Some(result)
3340
}
3441

3542
#[cfg(test)]

libraries/math/tests/instruction_count.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async fn test_sqrt_u128() {
6262
let mut pc = ProgramTest::new("spl_math", id(), processor!(process_instruction));
6363

6464
// Dial down the BPF compute budget to detect if the operation gets bloated in the future
65-
pc.set_bpf_compute_max_units(5_500);
65+
pc.set_bpf_compute_max_units(4_000);
6666

6767
let (mut banks_client, payer, recent_blockhash) = pc.start().await;
6868

@@ -78,8 +78,7 @@ async fn test_sqrt_u128() {
7878
async fn test_sqrt_u128_max() {
7979
let mut pc = ProgramTest::new("spl_math", id(), processor!(process_instruction));
8080

81-
// This is pretty big too!
82-
pc.set_bpf_compute_max_units(90_000);
81+
pc.set_bpf_compute_max_units(6_000);
8382

8483
let (mut banks_client, payer, recent_blockhash) = pc.start().await;
8584

0 commit comments

Comments
 (0)