@@ -37,8 +37,9 @@ use solana_sdk::{
37
37
use spl_associated_token_account:: get_associated_token_address_with_program_id;
38
38
use spl_token_2022:: {
39
39
extension:: {
40
- interest_bearing_mint:: InterestBearingConfig , memo_transfer:: MemoTransfer ,
41
- mint_close_authority:: MintCloseAuthority , ExtensionType , StateWithExtensionsOwned ,
40
+ cpi_guard:: CpiGuard , interest_bearing_mint:: InterestBearingConfig ,
41
+ memo_transfer:: MemoTransfer , mint_close_authority:: MintCloseAuthority , ExtensionType ,
42
+ StateWithExtensionsOwned ,
42
43
} ,
43
44
instruction:: * ,
44
45
state:: { Account , Mint } ,
@@ -131,6 +132,8 @@ pub enum CommandName {
131
132
SyncNative ,
132
133
EnableRequiredTransferMemos ,
133
134
DisableRequiredTransferMemos ,
135
+ EnableCpiGuard ,
136
+ DisableCpiGuard ,
134
137
}
135
138
impl fmt:: Display for CommandName {
136
139
fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
@@ -1842,8 +1845,7 @@ async fn command_sync_native(config: &Config<'_>, native_account_address: Pubkey
1842
1845
} )
1843
1846
}
1844
1847
1845
- // Both enable_required_transfer_mesos and disable_required_transfer_mesos
1846
- // Switches with enable_memos bool
1848
+ // both enables and disables required transfer memos, via enable_memos bool
1847
1849
async fn command_required_transfer_memos (
1848
1850
config : & Config < ' _ > ,
1849
1851
token_account_address : Pubkey ,
@@ -1852,7 +1854,7 @@ async fn command_required_transfer_memos(
1852
1854
enable_memos : bool ,
1853
1855
) -> CommandResult {
1854
1856
if config. sign_only {
1855
- panic ! ( "Config can not be sign only for enabling/disabling required transfer memos." ) ;
1857
+ panic ! ( "Config can not be sign- only for enabling/disabling required transfer memos." ) ;
1856
1858
}
1857
1859
1858
1860
let account = config. get_account_checked ( & token_account_address) . await ?;
@@ -1864,14 +1866,15 @@ async fn command_required_transfer_memos(
1864
1866
// Reallocation (if needed)
1865
1867
let mut existing_extensions: Vec < ExtensionType > = state_with_extension. get_extension_types ( ) ?;
1866
1868
if existing_extensions. contains ( & ExtensionType :: MemoTransfer ) {
1867
- let extension_data : bool = state_with_extension
1869
+ let extension_state = state_with_extension
1868
1870
. get_extension :: < MemoTransfer > ( ) ?
1869
1871
. require_incoming_transfer_memos
1870
1872
. into ( ) ;
1871
- if extension_data == enable_memos {
1873
+
1874
+ if extension_state == enable_memos {
1872
1875
return Ok ( format ! (
1873
- "Required memo transfer was already {}" ,
1874
- if extension_data {
1876
+ "Required transfer memos were already {}" ,
1877
+ if extension_state {
1875
1878
"enabled"
1876
1879
} else {
1877
1880
"disabled"
@@ -1914,6 +1917,78 @@ async fn command_required_transfer_memos(
1914
1917
} )
1915
1918
}
1916
1919
1920
+ // both enables and disables cpi guard, via enable_guard bool
1921
+ async fn command_cpi_guard (
1922
+ config : & Config < ' _ > ,
1923
+ token_account_address : Pubkey ,
1924
+ owner : Pubkey ,
1925
+ bulk_signers : BulkSigners ,
1926
+ enable_guard : bool ,
1927
+ ) -> CommandResult {
1928
+ if config. sign_only {
1929
+ panic ! ( "Config can not be sign-only for enabling/disabling required transfer memos." ) ;
1930
+ }
1931
+
1932
+ let account = config. get_account_checked ( & token_account_address) . await ?;
1933
+ let current_account_len = account. data . len ( ) ;
1934
+
1935
+ let state_with_extension = StateWithExtensionsOwned :: < Account > :: unpack ( account. data ) ?;
1936
+ let token = token_client_from_config ( config, & state_with_extension. base . mint , None ) ?;
1937
+
1938
+ // reallocation (if needed)
1939
+ let mut existing_extensions: Vec < ExtensionType > = state_with_extension. get_extension_types ( ) ?;
1940
+ if existing_extensions. contains ( & ExtensionType :: CpiGuard ) {
1941
+ let extension_state = state_with_extension
1942
+ . get_extension :: < CpiGuard > ( ) ?
1943
+ . lock_cpi
1944
+ . into ( ) ;
1945
+
1946
+ if extension_state == enable_guard {
1947
+ return Ok ( format ! (
1948
+ "CPI Guard was already {}" ,
1949
+ if extension_state {
1950
+ "enabled"
1951
+ } else {
1952
+ "disabled"
1953
+ }
1954
+ ) ) ;
1955
+ }
1956
+ } else {
1957
+ existing_extensions. push ( ExtensionType :: CpiGuard ) ;
1958
+ let required_account_len = ExtensionType :: get_account_len :: < Account > ( & existing_extensions) ;
1959
+ if required_account_len > current_account_len {
1960
+ token
1961
+ . reallocate (
1962
+ & token_account_address,
1963
+ & owner,
1964
+ & [ ExtensionType :: CpiGuard ] ,
1965
+ & bulk_signers,
1966
+ )
1967
+ . await ?;
1968
+ }
1969
+ }
1970
+
1971
+ let res = if enable_guard {
1972
+ token
1973
+ . enable_cpi_guard ( & token_account_address, & owner, & bulk_signers)
1974
+ . await
1975
+ } else {
1976
+ token
1977
+ . disable_cpi_guard ( & token_account_address, & owner, & bulk_signers)
1978
+ . await
1979
+ } ?;
1980
+
1981
+ let tx_return = finish_tx ( config, & res, false ) . await ?;
1982
+ Ok ( match tx_return {
1983
+ TransactionReturnData :: CliSignature ( signature) => {
1984
+ config. output_format . formatted_string ( & signature)
1985
+ }
1986
+ TransactionReturnData :: CliSignOnlyData ( sign_only_data) => {
1987
+ config. output_format . formatted_string ( & sign_only_data)
1988
+ }
1989
+ } )
1990
+ }
1991
+
1917
1992
struct SignOnlyNeedsFullMintSpec { }
1918
1993
impl offline:: ArgsConfig for SignOnlyNeedsFullMintSpec {
1919
1994
fn sign_only_arg < ' a , ' b > ( & self , arg : Arg < ' a , ' b > ) -> Arg < ' a , ' b > {
@@ -2902,7 +2977,7 @@ fn app<'a, 'b>(
2902
2977
. takes_value ( true )
2903
2978
. index ( 1 )
2904
2979
. required ( true )
2905
- . help ( "The address of the token account to enable required transfer memos" )
2980
+ . help ( "The address of the token account to require transfer memos for " )
2906
2981
)
2907
2982
. arg (
2908
2983
owner_address_arg ( )
@@ -2920,7 +2995,43 @@ fn app<'a, 'b>(
2920
2995
. takes_value ( true )
2921
2996
. index ( 1 )
2922
2997
. required ( true )
2923
- . help ( "The address of the token account to disable required transfer memos" ) ,
2998
+ . help ( "The address of the token account to stop requiring transfer memos for" ) ,
2999
+ )
3000
+ . arg (
3001
+ owner_address_arg ( )
3002
+ )
3003
+ . arg ( multisig_signer_arg ( ) )
3004
+ . nonce_args ( true )
3005
+ )
3006
+ . subcommand (
3007
+ SubCommand :: with_name ( CommandName :: EnableCpiGuard . into ( ) )
3008
+ . about ( "Enable CPI Guard for token account" )
3009
+ . arg (
3010
+ Arg :: with_name ( "account" )
3011
+ . validator ( is_valid_pubkey)
3012
+ . value_name ( "TOKEN_ACCOUNT_ADDRESS" )
3013
+ . takes_value ( true )
3014
+ . index ( 1 )
3015
+ . required ( true )
3016
+ . help ( "The address of the token account to enable CPI Guard for" )
3017
+ )
3018
+ . arg (
3019
+ owner_address_arg ( )
3020
+ )
3021
+ . arg ( multisig_signer_arg ( ) )
3022
+ . nonce_args ( true )
3023
+ )
3024
+ . subcommand (
3025
+ SubCommand :: with_name ( CommandName :: DisableCpiGuard . into ( ) )
3026
+ . about ( "Disable CPI Guard for token account" )
3027
+ . arg (
3028
+ Arg :: with_name ( "account" )
3029
+ . validator ( is_valid_pubkey)
3030
+ . value_name ( "TOKEN_ACCOUNT_ADDRESS" )
3031
+ . takes_value ( true )
3032
+ . index ( 1 )
3033
+ . required ( true )
3034
+ . help ( "The address of the token account to disable CPI Guard for" ) ,
2924
3035
)
2925
3036
. arg (
2926
3037
owner_address_arg ( )
@@ -3509,6 +3620,28 @@ async fn process_command<'a>(
3509
3620
config. pubkey_or_default ( arg_matches, "account" , & mut wallet_manager) ?;
3510
3621
command_required_transfer_memos ( config, token_account, owner, bulk_signers, false ) . await
3511
3622
}
3623
+ ( CommandName :: EnableCpiGuard , arg_matches) => {
3624
+ let ( owner_signer, owner) =
3625
+ config. signer_or_default ( arg_matches, "owner" , & mut wallet_manager) ;
3626
+ if !bulk_signers. contains ( & owner_signer) {
3627
+ bulk_signers. push ( owner_signer) ;
3628
+ }
3629
+ // Since account is required argument it will always be present
3630
+ let token_account =
3631
+ config. pubkey_or_default ( arg_matches, "account" , & mut wallet_manager) ?;
3632
+ command_cpi_guard ( config, token_account, owner, bulk_signers, true ) . await
3633
+ }
3634
+ ( CommandName :: DisableCpiGuard , arg_matches) => {
3635
+ let ( owner_signer, owner) =
3636
+ config. signer_or_default ( arg_matches, "owner" , & mut wallet_manager) ;
3637
+ if !bulk_signers. contains ( & owner_signer) {
3638
+ bulk_signers. push ( owner_signer) ;
3639
+ }
3640
+ // Since account is required argument it will always be present
3641
+ let token_account =
3642
+ config. pubkey_or_default ( arg_matches, "account" , & mut wallet_manager) ?;
3643
+ command_cpi_guard ( config, token_account, owner, bulk_signers, false ) . await
3644
+ }
3512
3645
}
3513
3646
}
3514
3647
@@ -5094,6 +5227,66 @@ mod tests {
5094
5227
assert ! ( !enabled) ;
5095
5228
}
5096
5229
5230
+ #[ tokio:: test]
5231
+ #[ serial]
5232
+ async fn cpi_guard ( ) {
5233
+ let ( test_validator, payer) = new_validator_for_test ( ) . await ;
5234
+ let program_id = spl_token_2022:: id ( ) ;
5235
+ let config = test_config_with_default_signer ( & test_validator, & payer, & program_id) ;
5236
+ let token = create_token ( & config, & payer) . await ;
5237
+ let token_account = create_associated_account ( & config, & payer, token) . await ;
5238
+
5239
+ // enable works
5240
+ process_test_command (
5241
+ & config,
5242
+ & payer,
5243
+ & [
5244
+ "spl-token" ,
5245
+ CommandName :: EnableCpiGuard . into ( ) ,
5246
+ & token_account. to_string ( ) ,
5247
+ ] ,
5248
+ )
5249
+ . await
5250
+ . unwrap ( ) ;
5251
+ let extensions = StateWithExtensionsOwned :: < Account > :: unpack (
5252
+ config
5253
+ . rpc_client
5254
+ . get_account ( & token_account)
5255
+ . await
5256
+ . unwrap ( )
5257
+ . data ,
5258
+ )
5259
+ . unwrap ( ) ;
5260
+ let cpi_guard = extensions. get_extension :: < CpiGuard > ( ) . unwrap ( ) ;
5261
+ let enabled: bool = cpi_guard. lock_cpi . into ( ) ;
5262
+ assert ! ( enabled) ;
5263
+
5264
+ // disable works
5265
+ process_test_command (
5266
+ & config,
5267
+ & payer,
5268
+ & [
5269
+ "spl-token" ,
5270
+ CommandName :: DisableCpiGuard . into ( ) ,
5271
+ & token_account. to_string ( ) ,
5272
+ ] ,
5273
+ )
5274
+ . await
5275
+ . unwrap ( ) ;
5276
+ let extensions = StateWithExtensionsOwned :: < Account > :: unpack (
5277
+ config
5278
+ . rpc_client
5279
+ . get_account ( & token_account)
5280
+ . await
5281
+ . unwrap ( )
5282
+ . data ,
5283
+ )
5284
+ . unwrap ( ) ;
5285
+ let cpi_guard = extensions. get_extension :: < CpiGuard > ( ) . unwrap ( ) ;
5286
+ let enabled: bool = cpi_guard. lock_cpi . into ( ) ;
5287
+ assert ! ( !enabled) ;
5288
+ }
5289
+
5097
5290
#[ tokio:: test]
5098
5291
#[ serial]
5099
5292
async fn immutable_accounts ( ) {
0 commit comments