Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit 4825298

Browse files
F32 Normal CDF Instruction and Implementation (#3090)
* F32 Normal CDF Instruction and Implementation * Linting * Prop Test and Option Return Removal * Reverting Change to Cargo.lock
1 parent 9123a80 commit 4825298

File tree

6 files changed

+118
-1
lines changed

6 files changed

+118
-1
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libraries/math/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ uint = "0.9"
2424
proptest = "1.0.0"
2525
solana-program-test = "1.10.8"
2626
solana-sdk = "1.10.8"
27+
libm = "0.2.2"
2728

2829
[lib]
2930
crate-type = ["cdylib", "lib"]
3031

32+
3133
[package.metadata.docs.rs]
3234
targets = ["x86_64-unknown-linux-gnu"]

libraries/math/src/approximations.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,36 @@ pub fn sqrt<T: PrimInt + CheckedShl + CheckedShr>(radicand: T) -> Option<T> {
4141
Some(result)
4242
}
4343

44+
/// Calculate the normal cdf of the given number
45+
///
46+
/// The approximation is accurate to 3 digits
47+
///
48+
/// Code lovingly adapted from the excellent work at:
49+
///
50+
/// <https://www.hrpub.org/download/20140305/MS7-13401470.pdf>
51+
///
52+
/// The algorithm is based on the implementation in the paper above.
53+
#[inline(never)]
54+
pub fn f32_normal_cdf(argument: f32) -> f32 {
55+
const PI: f32 = std::f32::consts::PI;
56+
57+
let mod_argument = if argument < 0.0 {
58+
-1.0 * argument
59+
} else {
60+
argument
61+
};
62+
let tabulation_numerator: f32 =
63+
(1.0 / (1.0 * (2.0 * PI).sqrt())) * (-1.0 * (mod_argument * mod_argument) / 2.0).exp();
64+
let tabulation_denominator: f32 =
65+
0.226 + 0.64 * mod_argument + 0.33 * (mod_argument * mod_argument + 3.0).sqrt();
66+
let y: f32 = 1.0 - tabulation_numerator / tabulation_denominator;
67+
if argument < 0.0 {
68+
1.0 - y
69+
} else {
70+
y
71+
}
72+
}
73+
4474
#[cfg(test)]
4575
mod tests {
4676
use {super::*, proptest::prelude::*};
@@ -67,4 +97,27 @@ mod tests {
6797
check_square_root(a as u128);
6898
}
6999
}
100+
101+
fn check_normal_cdf_f32(argument: f32) {
102+
let result = f32_normal_cdf(argument);
103+
let check_result = 0.5 * (1.0 + libm::erff(argument / std::f32::consts::SQRT_2));
104+
let abs_difference: f32 = (result - check_result).abs();
105+
assert!(abs_difference <= 0.000_2);
106+
}
107+
108+
#[test]
109+
fn test_normal_cdf_f32_min_max() {
110+
let test_arguments: [f32; 2] = [f32::MIN, f32::MAX];
111+
for i in test_arguments.iter() {
112+
check_normal_cdf_f32(*i as f32)
113+
}
114+
}
115+
116+
proptest! {
117+
#[test]
118+
fn test_normal_cdf(a in -1000..1000) {
119+
120+
check_normal_cdf_f32((a as f32)*0.005);
121+
}
122+
}
70123
}

libraries/math/src/instruction.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ pub enum MathInstruction {
8989
argument: f32,
9090
},
9191

92+
/// The Normal CDF of a float
93+
///
94+
/// No accounts required for this instruction
95+
F32NormalCDF {
96+
/// The argument
97+
argument: f32,
98+
},
99+
92100
/// Don't do anything for comparison
93101
///
94102
/// No accounts required for this instruction
@@ -200,6 +208,17 @@ pub fn f32_natural_log(argument: f32) -> Instruction {
200208
}
201209
}
202210

211+
/// Create F32 Normal CDF instruction
212+
pub fn f32_normal_cdf(argument: f32) -> Instruction {
213+
Instruction {
214+
program_id: id(),
215+
accounts: vec![],
216+
data: MathInstruction::F32NormalCDF { argument }
217+
.try_to_vec()
218+
.unwrap(),
219+
}
220+
}
221+
203222
/// Create Noop instruction
204223
pub fn noop() -> Instruction {
205224
Instruction {
@@ -347,6 +366,19 @@ mod tests {
347366
assert_eq!(instruction.program_id, crate::id())
348367
}
349368

369+
#[test]
370+
fn test_f32_normal_cdf() {
371+
let instruction = f32_normal_cdf(f32::MAX);
372+
assert_eq!(0, instruction.accounts.len());
373+
assert_eq!(
374+
instruction.data,
375+
MathInstruction::F32NormalCDF { argument: f32::MAX }
376+
.try_to_vec()
377+
.unwrap()
378+
);
379+
assert_eq!(instruction.program_id, crate::id())
380+
}
381+
350382
#[test]
351383
fn test_noop() {
352384
let instruction = noop();

libraries/math/src/processor.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
//! Program state processor
22
33
use {
4-
crate::{approximations::sqrt, instruction::MathInstruction, precise_number::PreciseNumber},
4+
crate::{
5+
approximations::{f32_normal_cdf, sqrt},
6+
instruction::MathInstruction,
7+
precise_number::PreciseNumber,
8+
},
59
borsh::BorshDeserialize,
610
solana_program::{
711
account_info::AccountInfo, entrypoint::ProgramResult, log::sol_log_compute_units, msg,
@@ -33,11 +37,13 @@ fn f32_divide(dividend: f32, divisor: f32) -> f32 {
3337
dividend / divisor
3438
}
3539

40+
/// f32_exponentiate
3641
#[inline(never)]
3742
fn f32_exponentiate(base: f32, exponent: f32) -> f32 {
3843
base.powf(exponent)
3944
}
4045

46+
/// f32_natural_log
4147
#[inline(never)]
4248
fn f32_natural_log(argument: f32) -> f32 {
4349
argument.ln()
@@ -130,6 +136,14 @@ pub fn process_instruction(
130136
msg!("{}", result as u64);
131137
Ok(())
132138
}
139+
MathInstruction::F32NormalCDF { argument } => {
140+
msg!("Calculating f32 Normal CDF");
141+
sol_log_compute_units();
142+
let result = f32_normal_cdf(argument);
143+
sol_log_compute_units();
144+
msg!("{}", result as u64);
145+
Ok(())
146+
}
133147
MathInstruction::Noop => {
134148
msg!("Do nothing");
135149
msg!("{}", 0_u64);

libraries/math/tests/instruction_count.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,21 @@ async fn test_f32_natural_log() {
179179
banks_client.process_transaction(transaction).await.unwrap();
180180
}
181181

182+
#[tokio::test]
183+
async fn test_f32_normal_cdf() {
184+
let mut pc = ProgramTest::new("spl_math", id(), processor!(process_instruction));
185+
186+
// Dial down the BPF compute budget to detect if the operation gets bloated in the future
187+
pc.set_compute_max_units(3_100);
188+
189+
let (mut banks_client, payer, recent_blockhash) = pc.start().await;
190+
191+
let mut transaction =
192+
Transaction::new_with_payer(&[instruction::f32_normal_cdf(0_f32)], Some(&payer.pubkey()));
193+
transaction.sign(&[&payer], recent_blockhash);
194+
banks_client.process_transaction(transaction).await.unwrap();
195+
}
196+
182197
#[tokio::test]
183198
async fn test_noop() {
184199
let mut pc = ProgramTest::new("spl_math", id(), processor!(process_instruction));

0 commit comments

Comments
 (0)