12
12
program_error:: ProgramError ,
13
13
program_pack:: { IsInitialized , Pack , Sealed } ,
14
14
} ,
15
- spl_math:: { precise_number:: PreciseNumber , uint:: U256 } ,
15
+ spl_math:: { checked_ceil_div :: CheckedCeilDiv , precise_number:: PreciseNumber , uint:: U256 } ,
16
16
std:: convert:: TryFrom ,
17
17
} ;
18
18
@@ -133,14 +133,14 @@ fn compute_new_destination_amount(
133
133
let b = new_source_amount. checked_add ( d_val. checked_div ( leverage) ?) ?;
134
134
135
135
// Solve for y by approximating: y**2 + b*y = c
136
- let mut y_prev: U256 ;
137
136
let mut y = d_val;
138
137
for _ in 0 ..ITERATIONS {
139
- y_prev = y;
140
- y = ( checked_u8_power ( & y, 2 ) ?. checked_add ( c) ?)
141
- . checked_div ( checked_u8_mul ( & y, 2 ) ?. checked_add ( b) ?. checked_sub ( d_val) ?) ?;
142
- if y == y_prev {
138
+ let ( y_new, _) = ( checked_u8_power ( & y, 2 ) ?. checked_add ( c) ?)
139
+ . checked_ceil_div ( checked_u8_mul ( & y, 2 ) ?. checked_add ( b) ?. checked_sub ( d_val) ?) ?;
140
+ if y_new == y {
143
141
break ;
142
+ } else {
143
+ y = y_new;
144
144
}
145
145
}
146
146
u128:: try_from ( y) . ok ( )
@@ -155,6 +155,12 @@ impl CurveCalculator for StableCurve {
155
155
swap_destination_amount : u128 ,
156
156
_trade_direction : TradeDirection ,
157
157
) -> Option < SwapWithoutFeesResult > {
158
+ if source_amount == 0 {
159
+ return Some ( SwapWithoutFeesResult {
160
+ source_amount_swapped : 0 ,
161
+ destination_amount_swapped : 0 ,
162
+ } ) ;
163
+ }
158
164
let leverage = compute_a ( self . amp ) ?;
159
165
160
166
let new_source_amount = swap_source_amount. checked_add ( source_amount) ?;
@@ -409,9 +415,19 @@ mod tests {
409
415
check_pool_token_rate ( 5 , 501 , 2 , 10 , 1 , 101 ) ;
410
416
}
411
417
418
+ #[ test]
419
+ fn swap_zero ( ) {
420
+ let curve = StableCurve { amp : 100 } ;
421
+ let result = curve. swap_without_fees ( 0 , 100 , 1_000_000_000_000_000 , TradeDirection :: AtoB ) ;
422
+
423
+ let result = result. unwrap ( ) ;
424
+ assert_eq ! ( result. source_amount_swapped, 0 ) ;
425
+ assert_eq ! ( result. destination_amount_swapped, 0 ) ;
426
+ }
427
+
412
428
proptest ! {
413
429
#[ test]
414
- fn constant_product_swap_no_fee (
430
+ fn swap_no_fee (
415
431
swap_source_amount in 100 ..1_000_000_000_000_000_000u128 ,
416
432
swap_destination_amount in 100 ..1_000_000_000_000_000_000u128 ,
417
433
source_amount in 100 ..100_000_000_000u128 ,
@@ -440,7 +456,8 @@ mod tests {
440
456
let diff =
441
457
( sim_result as i128 - result. destination_amount_swapped as i128 ) . abs( ) ;
442
458
443
- let tolerance = std:: cmp:: max( 1 , sim_result as i128 / 1_000_000_000 ) ;
459
+ // tolerate a difference of 2 because of the ceiling during calculation
460
+ let tolerance = std:: cmp:: max( 2 , sim_result as i128 / 1_000_000_000 ) ;
444
461
445
462
assert!(
446
463
diff <= tolerance,
0 commit comments