Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit 84182ce

Browse files
authored
stake-pool: Truncate on withdrawal calculation (#3804)
1 parent c5fcbef commit 84182ce

File tree

1 file changed

+77
-5
lines changed

1 file changed

+77
-5
lines changed

stake-pool/program/src/state.rs

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ use {
1818
pubkey::{Pubkey, PUBKEY_BYTES},
1919
stake::state::Lockup,
2020
},
21-
spl_math::checked_ceil_div::CheckedCeilDiv,
2221
spl_token::state::{Account, AccountState},
2322
std::{borrow::Borrow, convert::TryFrom, fmt, matches},
2423
};
@@ -172,16 +171,15 @@ impl StakePool {
172171
/// calculate lamports amount on withdrawal
173172
#[inline]
174173
pub fn calc_lamports_withdraw_amount(&self, pool_tokens: u64) -> Option<u64> {
175-
// `checked_ceil_div` returns `None` for a 0 quotient result, but in this
174+
// `checked_div` returns `None` for a 0 quotient result, but in this
176175
// case, a return of 0 is valid for small amounts of pool tokens. So
177176
// we check for that separately
178177
let numerator = (pool_tokens as u128).checked_mul(self.total_lamports as u128)?;
179178
let denominator = self.pool_token_supply as u128;
180179
if numerator < denominator || denominator == 0 {
181180
Some(0)
182181
} else {
183-
let (quotient, _) = numerator.checked_ceil_div(denominator)?;
184-
u64::try_from(quotient).ok()
182+
u64::try_from(numerator.checked_div(denominator)?).ok()
185183
}
186184
}
187185

@@ -1033,7 +1031,7 @@ mod test {
10331031
let fee_lamports = stake_pool
10341032
.calc_lamports_withdraw_amount(pool_token_fee)
10351033
.unwrap();
1036-
assert_eq!(fee_lamports, LAMPORTS_PER_SOL);
1034+
assert_eq!(fee_lamports, LAMPORTS_PER_SOL - 1); // off-by-one due to truncation
10371035
}
10381036

10391037
#[test]
@@ -1148,6 +1146,80 @@ mod test {
11481146
stake_pool.pool_token_supply += deposit_result;
11491147
let withdraw_result = stake_pool.calc_lamports_withdraw_amount(deposit_result).unwrap();
11501148
assert!(withdraw_result <= deposit_stake);
1149+
1150+
// also test splitting the withdrawal in two operations
1151+
if deposit_result >= 2 {
1152+
let first_half_deposit = deposit_result / 2;
1153+
let first_withdraw_result = stake_pool.calc_lamports_withdraw_amount(first_half_deposit).unwrap();
1154+
stake_pool.total_lamports -= first_withdraw_result;
1155+
stake_pool.pool_token_supply -= first_half_deposit;
1156+
let second_half_deposit = deposit_result - first_half_deposit; // do the whole thing
1157+
let second_withdraw_result = stake_pool.calc_lamports_withdraw_amount(second_half_deposit).unwrap();
1158+
assert!(first_withdraw_result + second_withdraw_result <= deposit_stake);
1159+
}
11511160
}
11521161
}
1162+
1163+
#[test]
1164+
fn specific_split_withdrawal() {
1165+
let total_lamports = 1_100_000_000_000;
1166+
let pool_token_supply = 1_000_000_000_000;
1167+
let deposit_stake = 3;
1168+
let mut stake_pool = StakePool {
1169+
total_lamports,
1170+
pool_token_supply,
1171+
..StakePool::default()
1172+
};
1173+
let deposit_result = stake_pool
1174+
.calc_pool_tokens_for_deposit(deposit_stake)
1175+
.unwrap();
1176+
assert!(deposit_result > 0);
1177+
stake_pool.total_lamports += deposit_stake;
1178+
stake_pool.pool_token_supply += deposit_result;
1179+
let withdraw_result = stake_pool
1180+
.calc_lamports_withdraw_amount(deposit_result / 2)
1181+
.unwrap();
1182+
assert!(withdraw_result * 2 <= deposit_stake);
1183+
}
1184+
1185+
#[test]
1186+
fn withdraw_all() {
1187+
let total_lamports = 1_100_000_000_000;
1188+
let pool_token_supply = 1_000_000_000_000;
1189+
let mut stake_pool = StakePool {
1190+
total_lamports,
1191+
pool_token_supply,
1192+
..StakePool::default()
1193+
};
1194+
// take everything out at once
1195+
let withdraw_result = stake_pool
1196+
.calc_lamports_withdraw_amount(pool_token_supply)
1197+
.unwrap();
1198+
assert_eq!(stake_pool.total_lamports, withdraw_result);
1199+
1200+
// take out 1, then the rest
1201+
let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
1202+
stake_pool.total_lamports -= withdraw_result;
1203+
stake_pool.pool_token_supply -= 1;
1204+
let withdraw_result = stake_pool
1205+
.calc_lamports_withdraw_amount(stake_pool.pool_token_supply)
1206+
.unwrap();
1207+
assert_eq!(stake_pool.total_lamports, withdraw_result);
1208+
1209+
// take out all except 1, then the rest
1210+
let mut stake_pool = StakePool {
1211+
total_lamports,
1212+
pool_token_supply,
1213+
..StakePool::default()
1214+
};
1215+
let withdraw_result = stake_pool
1216+
.calc_lamports_withdraw_amount(pool_token_supply - 1)
1217+
.unwrap();
1218+
stake_pool.total_lamports -= withdraw_result;
1219+
stake_pool.pool_token_supply = 1;
1220+
assert_ne!(stake_pool.total_lamports, 0);
1221+
1222+
let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
1223+
assert_eq!(stake_pool.total_lamports, withdraw_result);
1224+
}
11531225
}

0 commit comments

Comments
 (0)