@@ -533,8 +533,12 @@ impl Processor {
533
533
let token_a = Self :: unpack_token_account ( token_a_info, & token_swap. token_program_id ( ) ) ?;
534
534
let token_b = Self :: unpack_token_account ( token_b_info, & token_swap. token_program_id ( ) ) ?;
535
535
let pool_mint = Self :: unpack_mint ( pool_mint_info, & token_swap. token_program_id ( ) ) ?;
536
- let pool_token_amount = to_u128 ( pool_token_amount) ?;
537
- let pool_mint_supply = to_u128 ( pool_mint. supply ) ?;
536
+ let current_pool_mint_supply = to_u128 ( pool_mint. supply ) ?;
537
+ let ( pool_token_amount, pool_mint_supply) = if current_pool_mint_supply > 0 {
538
+ ( to_u128 ( pool_token_amount) ?, current_pool_mint_supply)
539
+ } else {
540
+ ( calculator. new_pool_supply ( ) , calculator. new_pool_supply ( ) )
541
+ } ;
538
542
539
543
let results = calculator
540
544
. pool_tokens_to_trading_tokens (
@@ -658,13 +662,15 @@ impl Processor {
658
662
)
659
663
. ok_or ( SwapError :: ZeroTradingTokens ) ?;
660
664
let token_a_amount = to_u64 ( results. token_a_amount ) ?;
665
+ let token_a_amount = std:: cmp:: min ( token_a. amount , token_a_amount) ;
661
666
if token_a_amount < minimum_token_a_amount {
662
667
return Err ( SwapError :: ExceededSlippage . into ( ) ) ;
663
668
}
664
669
if token_a_amount == 0 && token_a. amount != 0 {
665
670
return Err ( SwapError :: ZeroTradingTokens . into ( ) ) ;
666
671
}
667
672
let token_b_amount = to_u64 ( results. token_b_amount ) ?;
673
+ let token_b_amount = std:: cmp:: min ( token_b. amount , token_b_amount) ;
668
674
if token_b_amount < minimum_token_b_amount {
669
675
return Err ( SwapError :: ExceededSlippage . into ( ) ) ;
670
676
}
@@ -693,7 +699,6 @@ impl Processor {
693
699
to_u64 ( pool_token_amount) ?,
694
700
) ?;
695
701
696
- let token_a_amount = std:: cmp:: min ( token_a. amount , token_a_amount) ;
697
702
if token_a_amount > 0 {
698
703
Self :: token_transfer (
699
704
swap_info. key ,
@@ -705,7 +710,6 @@ impl Processor {
705
710
token_a_amount,
706
711
) ?;
707
712
}
708
- let token_b_amount = std:: cmp:: min ( token_b. amount , token_b_amount) ;
709
713
if token_b_amount > 0 {
710
714
Self :: token_transfer (
711
715
swap_info. key ,
@@ -775,19 +779,22 @@ impl Processor {
775
779
776
780
let pool_mint = Self :: unpack_mint ( pool_mint_info, & token_swap. token_program_id ( ) ) ?;
777
781
let pool_mint_supply = to_u128 ( pool_mint. supply ) ?;
778
-
779
- let pool_token_amount = token_swap
780
- . swap_curve ( )
781
- . trading_tokens_to_pool_tokens (
782
- to_u128 ( source_token_amount) ?,
783
- to_u128 ( swap_token_a. amount ) ?,
784
- to_u128 ( swap_token_b. amount ) ?,
785
- pool_mint_supply,
786
- trade_direction,
787
- RoundDirection :: Floor ,
788
- token_swap. fees ( ) ,
789
- )
790
- . ok_or ( SwapError :: ZeroTradingTokens ) ?;
782
+ let pool_token_amount = if pool_mint_supply > 0 {
783
+ token_swap
784
+ . swap_curve ( )
785
+ . trading_tokens_to_pool_tokens (
786
+ to_u128 ( source_token_amount) ?,
787
+ to_u128 ( swap_token_a. amount ) ?,
788
+ to_u128 ( swap_token_b. amount ) ?,
789
+ pool_mint_supply,
790
+ trade_direction,
791
+ RoundDirection :: Floor ,
792
+ token_swap. fees ( ) ,
793
+ )
794
+ . ok_or ( SwapError :: ZeroTradingTokens ) ?
795
+ } else {
796
+ token_swap. swap_curve ( ) . calculator . new_pool_supply ( )
797
+ } ;
791
798
792
799
let pool_token_amount = to_u64 ( pool_token_amount) ?;
793
800
if pool_token_amount < minimum_pool_token_amount {
@@ -6730,4 +6737,176 @@ mod tests {
6730
6737
spl_token:: state:: Account :: unpack ( & accounts. token_b_account . data ) . unwrap ( ) ;
6731
6738
assert_eq ! ( swap_token_b. amount, 0 ) ;
6732
6739
}
6740
+
6741
+ #[ test]
6742
+ fn test_withdraw_all_constant_price_curve ( ) {
6743
+ let trade_fee_numerator = 1 ;
6744
+ let trade_fee_denominator = 10 ;
6745
+ let owner_trade_fee_numerator = 1 ;
6746
+ let owner_trade_fee_denominator = 30 ;
6747
+ let owner_withdraw_fee_numerator = 0 ;
6748
+ let owner_withdraw_fee_denominator = 30 ;
6749
+ let host_fee_numerator = 10 ;
6750
+ let host_fee_denominator = 100 ;
6751
+
6752
+ // initialize "unbalanced", so that withdrawing all will have some issues
6753
+ // A: 1_000_000_000
6754
+ // B: 2_000_000_000 (1_000 * 2_000_000)
6755
+ let swap_token_a_amount = 1_000_000_000 ;
6756
+ let swap_token_b_amount = 1_000 ;
6757
+ let token_b_price = 2_000_000 ;
6758
+ let fees = Fees {
6759
+ trade_fee_numerator,
6760
+ trade_fee_denominator,
6761
+ owner_trade_fee_numerator,
6762
+ owner_trade_fee_denominator,
6763
+ owner_withdraw_fee_numerator,
6764
+ owner_withdraw_fee_denominator,
6765
+ host_fee_numerator,
6766
+ host_fee_denominator,
6767
+ } ;
6768
+
6769
+ let swap_curve = SwapCurve {
6770
+ curve_type : CurveType :: ConstantPrice ,
6771
+ calculator : Box :: new ( ConstantPriceCurve { token_b_price } ) ,
6772
+ } ;
6773
+ let total_pool = swap_curve. calculator . new_pool_supply ( ) ;
6774
+ let user_key = Pubkey :: new_unique ( ) ;
6775
+ let withdrawer_key = Pubkey :: new_unique ( ) ;
6776
+
6777
+ let mut accounts = SwapAccountInfo :: new (
6778
+ & user_key,
6779
+ fees,
6780
+ swap_curve,
6781
+ swap_token_a_amount,
6782
+ swap_token_b_amount,
6783
+ ) ;
6784
+
6785
+ accounts. initialize_swap ( ) . unwrap ( ) ;
6786
+
6787
+ let (
6788
+ token_a_key,
6789
+ mut token_a_account,
6790
+ token_b_key,
6791
+ mut token_b_account,
6792
+ _pool_key,
6793
+ _pool_account,
6794
+ ) = accounts. setup_token_accounts ( & user_key, & withdrawer_key, 0 , 0 , 0 ) ;
6795
+
6796
+ let pool_key = accounts. pool_token_key ;
6797
+ let mut pool_account = accounts. pool_token_account . clone ( ) ;
6798
+
6799
+ // WithdrawAllTokenTypes will not take all token A and B, since their
6800
+ // ratio is unbalanced. It will try to take 1_500_000_000 worth of
6801
+ // each token, which means 1_500_000_000 token A, and 750 token B.
6802
+ // With no slippage, this will leave 250 token B in the pool.
6803
+ assert_eq ! (
6804
+ Err ( SwapError :: ExceededSlippage . into( ) ) ,
6805
+ accounts. withdraw_all_token_types(
6806
+ & user_key,
6807
+ & pool_key,
6808
+ & mut pool_account,
6809
+ & token_a_key,
6810
+ & mut token_a_account,
6811
+ & token_b_key,
6812
+ & mut token_b_account,
6813
+ total_pool. try_into( ) . unwrap( ) ,
6814
+ swap_token_a_amount,
6815
+ swap_token_b_amount,
6816
+ )
6817
+ ) ;
6818
+
6819
+ accounts
6820
+ . withdraw_all_token_types (
6821
+ & user_key,
6822
+ & pool_key,
6823
+ & mut pool_account,
6824
+ & token_a_key,
6825
+ & mut token_a_account,
6826
+ & token_b_key,
6827
+ & mut token_b_account,
6828
+ total_pool. try_into ( ) . unwrap ( ) ,
6829
+ 0 ,
6830
+ 0 ,
6831
+ )
6832
+ . unwrap ( ) ;
6833
+
6834
+ let token_a = spl_token:: state:: Account :: unpack ( & token_a_account. data ) . unwrap ( ) ;
6835
+ assert_eq ! ( token_a. amount, swap_token_a_amount) ;
6836
+ let token_b = spl_token:: state:: Account :: unpack ( & token_b_account. data ) . unwrap ( ) ;
6837
+ assert_eq ! ( token_b. amount, 750 ) ;
6838
+ let swap_token_a =
6839
+ spl_token:: state:: Account :: unpack ( & accounts. token_a_account . data ) . unwrap ( ) ;
6840
+ assert_eq ! ( swap_token_a. amount, 0 ) ;
6841
+ let swap_token_b =
6842
+ spl_token:: state:: Account :: unpack ( & accounts. token_b_account . data ) . unwrap ( ) ;
6843
+ assert_eq ! ( swap_token_b. amount, 250 ) ;
6844
+
6845
+ // deposit now, not enough to cover the tokens already in there
6846
+ let token_b_amount = 10 ;
6847
+ let token_a_amount = token_b_amount * token_b_price;
6848
+ let (
6849
+ token_a_key,
6850
+ mut token_a_account,
6851
+ token_b_key,
6852
+ mut token_b_account,
6853
+ pool_key,
6854
+ mut pool_account,
6855
+ ) = accounts. setup_token_accounts (
6856
+ & user_key,
6857
+ & withdrawer_key,
6858
+ token_a_amount,
6859
+ token_b_amount,
6860
+ 0 ,
6861
+ ) ;
6862
+
6863
+ assert_eq ! (
6864
+ Err ( SwapError :: ExceededSlippage . into( ) ) ,
6865
+ accounts. deposit_all_token_types(
6866
+ & withdrawer_key,
6867
+ & token_a_key,
6868
+ & mut token_a_account,
6869
+ & token_b_key,
6870
+ & mut token_b_account,
6871
+ & pool_key,
6872
+ & mut pool_account,
6873
+ 1 , // doesn't matter
6874
+ token_a_amount,
6875
+ token_b_amount,
6876
+ )
6877
+ ) ;
6878
+
6879
+ // deposit enough tokens, success!
6880
+ let token_b_amount = 125 ;
6881
+ let token_a_amount = token_b_amount * token_b_price;
6882
+ let (
6883
+ token_a_key,
6884
+ mut token_a_account,
6885
+ token_b_key,
6886
+ mut token_b_account,
6887
+ pool_key,
6888
+ mut pool_account,
6889
+ ) = accounts. setup_token_accounts (
6890
+ & user_key,
6891
+ & withdrawer_key,
6892
+ token_a_amount,
6893
+ token_b_amount,
6894
+ 0 ,
6895
+ ) ;
6896
+
6897
+ accounts
6898
+ . deposit_all_token_types (
6899
+ & withdrawer_key,
6900
+ & token_a_key,
6901
+ & mut token_a_account,
6902
+ & token_b_key,
6903
+ & mut token_b_account,
6904
+ & pool_key,
6905
+ & mut pool_account,
6906
+ 1 , // doesn't matter
6907
+ token_a_amount,
6908
+ token_b_amount,
6909
+ )
6910
+ . unwrap ( ) ;
6911
+ }
6733
6912
}
0 commit comments