Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit d557474

Browse files
authored
token-swap: Ceiling stable curve division (#2942)
1 parent 1316124 commit d557474

File tree

1 file changed

+25
-8
lines changed

1 file changed

+25
-8
lines changed

token-swap/program/src/curve/stable.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use {
1212
program_error::ProgramError,
1313
program_pack::{IsInitialized, Pack, Sealed},
1414
},
15-
spl_math::{precise_number::PreciseNumber, uint::U256},
15+
spl_math::{checked_ceil_div::CheckedCeilDiv, precise_number::PreciseNumber, uint::U256},
1616
std::convert::TryFrom,
1717
};
1818

@@ -133,14 +133,14 @@ fn compute_new_destination_amount(
133133
let b = new_source_amount.checked_add(d_val.checked_div(leverage)?)?;
134134

135135
// Solve for y by approximating: y**2 + b*y = c
136-
let mut y_prev: U256;
137136
let mut y = d_val;
138137
for _ in 0..ITERATIONS {
139-
y_prev = y;
140-
y = (checked_u8_power(&y, 2)?.checked_add(c)?)
141-
.checked_div(checked_u8_mul(&y, 2)?.checked_add(b)?.checked_sub(d_val)?)?;
142-
if y == y_prev {
138+
let (y_new, _) = (checked_u8_power(&y, 2)?.checked_add(c)?)
139+
.checked_ceil_div(checked_u8_mul(&y, 2)?.checked_add(b)?.checked_sub(d_val)?)?;
140+
if y_new == y {
143141
break;
142+
} else {
143+
y = y_new;
144144
}
145145
}
146146
u128::try_from(y).ok()
@@ -155,6 +155,12 @@ impl CurveCalculator for StableCurve {
155155
swap_destination_amount: u128,
156156
_trade_direction: TradeDirection,
157157
) -> Option<SwapWithoutFeesResult> {
158+
if source_amount == 0 {
159+
return Some(SwapWithoutFeesResult {
160+
source_amount_swapped: 0,
161+
destination_amount_swapped: 0,
162+
});
163+
}
158164
let leverage = compute_a(self.amp)?;
159165

160166
let new_source_amount = swap_source_amount.checked_add(source_amount)?;
@@ -409,9 +415,19 @@ mod tests {
409415
check_pool_token_rate(5, 501, 2, 10, 1, 101);
410416
}
411417

418+
#[test]
419+
fn swap_zero() {
420+
let curve = StableCurve { amp: 100 };
421+
let result = curve.swap_without_fees(0, 100, 1_000_000_000_000_000, TradeDirection::AtoB);
422+
423+
let result = result.unwrap();
424+
assert_eq!(result.source_amount_swapped, 0);
425+
assert_eq!(result.destination_amount_swapped, 0);
426+
}
427+
412428
proptest! {
413429
#[test]
414-
fn constant_product_swap_no_fee(
430+
fn swap_no_fee(
415431
swap_source_amount in 100..1_000_000_000_000_000_000u128,
416432
swap_destination_amount in 100..1_000_000_000_000_000_000u128,
417433
source_amount in 100..100_000_000_000u128,
@@ -440,7 +456,8 @@ mod tests {
440456
let diff =
441457
(sim_result as i128 - result.destination_amount_swapped as i128).abs();
442458

443-
let tolerance = std::cmp::max(1, sim_result as i128 / 1_000_000_000);
459+
// tolerate a difference of 2 because of the ceiling during calculation
460+
let tolerance = std::cmp::max(2, sim_result as i128 / 1_000_000_000);
444461

445462
assert!(
446463
diff <= tolerance,

0 commit comments

Comments
 (0)