Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit c2a3ecd

Browse files
token-2022: prevent an already configured confidential account to be configured again (#3216)
* token-2022: prevent an already configured confidential account to be configured again * token-2022: add overwrite flag to init extension * token-2022: clippy * token-2022: update initialize mint for interest bearing mint * token-2022: confidential transfer mint init allow overwrite
1 parent 08d0592 commit c2a3ecd

File tree

8 files changed

+112
-62
lines changed

8 files changed

+112
-62
lines changed

token/program-2022-test/tests/confidential_transfer.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use {
1010
signer::keypair::Keypair, transaction::TransactionError, transport::TransportError,
1111
},
1212
spl_token_2022::{
13+
error::TokenError,
1314
extension::{
1415
confidential_transfer::{
1516
ConfidentialTransferAccount, ConfidentialTransferMint, EncryptedWithheldAmount,
@@ -331,6 +332,27 @@ async fn ct_configure_token_account() {
331332
.get_extension::<ConfidentialTransferAccount>()
332333
.unwrap();
333334
assert!(bool::from(&extension.approved));
335+
336+
// Configuring an already initialized account should produce an error
337+
let err = token
338+
.confidential_transfer_configure_token_account(
339+
&alice_meta.token_account,
340+
&alice,
341+
alice_meta.elgamal_keypair.public,
342+
alice_meta.ae_key.encrypt(0_u64),
343+
)
344+
.await
345+
.unwrap_err();
346+
347+
assert_eq!(
348+
err,
349+
TokenClientError::Client(Box::new(TransportError::TransactionError(
350+
TransactionError::InstructionError(
351+
0,
352+
InstructionError::Custom(TokenError::ExtensionAlreadyInitialized as u32),
353+
)
354+
)))
355+
);
334356
}
335357

336358
#[tokio::test]

token/program-2022/src/extension/confidential_transfer/processor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ fn process_initialize_mint(
5454
check_program_account(mint_info.owner)?;
5555
let mint_data = &mut mint_info.data.borrow_mut();
5656
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(mint_data)?;
57-
*mint.init_extension::<ConfidentialTransferMint>()? = *confidential_transfer_mint;
57+
*mint.init_extension::<ConfidentialTransferMint>(true)? = *confidential_transfer_mint;
5858

5959
Ok(())
6060
}
@@ -125,7 +125,7 @@ fn process_configure_account(
125125
// Note: The caller is expected to use the `Reallocate` instruction to ensure there is
126126
// sufficient room in their token account for the new `ConfidentialTransferAccount` extension
127127
let mut confidential_transfer_account =
128-
token_account.init_extension::<ConfidentialTransferAccount>()?;
128+
token_account.init_extension::<ConfidentialTransferAccount>(false)?;
129129
confidential_transfer_account.approved = confidential_transfer_mint.auto_approve_new_accounts;
130130
confidential_transfer_account.encryption_pubkey = *encryption_pubkey;
131131

token/program-2022/src/extension/default_account_state/processor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ fn process_initialize_default_account_state(
3636
let mint_account_info = next_account_info(account_info_iter)?;
3737
let mut mint_data = mint_account_info.data.borrow_mut();
3838
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut mint_data)?;
39-
let extension = mint.init_extension::<DefaultAccountState>()?;
39+
let extension = mint.init_extension::<DefaultAccountState>(true)?;
4040
extension.state = state.into();
4141
Ok(())
4242
}

token/program-2022/src/extension/interest_bearing_mint/processor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ fn process_initialize(
3636
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut mint_data)?;
3737

3838
let clock = Clock::get()?;
39-
let extension = mint.init_extension::<InterestBearingConfig>()?;
39+
let extension = mint.init_extension::<InterestBearingConfig>(true)?;
4040
extension.rate_authority = *rate_authority;
4141
extension.initialization_timestamp = clock.unix_timestamp.into();
4242
extension.last_update_timestamp = clock.unix_timestamp.into();

token/program-2022/src/extension/memo_transfer/processor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ fn process_enable_required_memo_transfers(
4040
let extension = if let Ok(extension) = account.get_extension_mut::<MemoTransfer>() {
4141
extension
4242
} else {
43-
account.init_extension::<MemoTransfer>()?
43+
account.init_extension::<MemoTransfer>(true)?
4444
};
4545
extension.require_incoming_transfer_memos = true.into();
4646
Ok(())
@@ -69,7 +69,7 @@ fn process_diasble_required_memo_transfers(
6969
let extension = if let Ok(extension) = account.get_extension_mut::<MemoTransfer>() {
7070
extension
7171
} else {
72-
account.init_extension::<MemoTransfer>()?
72+
account.init_extension::<MemoTransfer>(true)?
7373
};
7474
extension.require_incoming_transfer_memos = false.into();
7575
Ok(())

token/program-2022/src/extension/mod.rs

Lines changed: 78 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -413,46 +413,23 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
413413
}
414414
}
415415

416-
fn init_or_get_extension<V: Extension>(&mut self, init: bool) -> Result<&mut V, ProgramError> {
416+
/// Unpack a portion of the TLV data as the desired type that allows modifying the type
417+
pub fn get_extension_mut<V: Extension>(&mut self) -> Result<&mut V, ProgramError> {
417418
if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
418419
return Err(ProgramError::InvalidAccountData);
419420
}
420421
let TlvIndices {
421422
type_start,
422423
length_start,
423424
value_start,
424-
} = get_extension_indices::<V>(self.tlv_data, init)?;
425+
} = get_extension_indices::<V>(self.tlv_data, false)?;
425426

426427
if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
427428
return Err(ProgramError::InvalidAccountData);
428429
}
429-
if init {
430-
// write extension type
431-
let extension_type_array: [u8; 2] = V::TYPE.into();
432-
let extension_type_ref = &mut self.tlv_data[type_start..length_start];
433-
extension_type_ref.copy_from_slice(&extension_type_array);
434-
// write length
435-
let length_ref =
436-
pod_from_bytes_mut::<Length>(&mut self.tlv_data[length_start..value_start])?;
437-
// maybe this becomes smarter later for dynamically sized extensions
438-
let length = pod_get_packed_len::<V>();
439-
*length_ref = Length::try_from(length).unwrap();
440-
441-
let value_end = value_start.saturating_add(length);
442-
let extension_ref =
443-
pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])?;
444-
*extension_ref = V::default();
445-
Ok(extension_ref)
446-
} else {
447-
let length = pod_from_bytes::<Length>(&self.tlv_data[length_start..value_start])?;
448-
let value_end = value_start.saturating_add(usize::from(*length));
449-
pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])
450-
}
451-
}
452-
453-
/// Unpack a portion of the TLV data as the desired type that allows modifying the type
454-
pub fn get_extension_mut<V: Extension>(&mut self) -> Result<&mut V, ProgramError> {
455-
self.init_or_get_extension(false)
430+
let length = pod_from_bytes::<Length>(&self.tlv_data[length_start..value_start])?;
431+
let value_end = value_start.saturating_add(usize::from(*length));
432+
pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])
456433
}
457434

458435
/// Unpack a portion of the TLV data as the desired type
@@ -480,9 +457,48 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
480457
}
481458

482459
/// Packs the default extension data into an open slot if not already found in the
483-
/// data buffer, otherwise overwrites the existing extension with the default state
484-
pub fn init_extension<V: Extension>(&mut self) -> Result<&mut V, ProgramError> {
485-
self.init_or_get_extension(true)
460+
/// data buffer. If extension is already found in the buffer, it overwrites the existing
461+
/// extension with the default state if `overwrite` is set. If extension found, but
462+
/// `overwrite` is not set, it returns error.
463+
pub fn init_extension<V: Extension>(
464+
&mut self,
465+
overwrite: bool,
466+
) -> Result<&mut V, ProgramError> {
467+
if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
468+
return Err(ProgramError::InvalidAccountData);
469+
}
470+
let TlvIndices {
471+
type_start,
472+
length_start,
473+
value_start,
474+
} = get_extension_indices::<V>(self.tlv_data, true)?;
475+
476+
if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
477+
return Err(ProgramError::InvalidAccountData);
478+
}
479+
let extension_type = ExtensionType::try_from(&self.tlv_data[type_start..length_start])?;
480+
481+
if extension_type == ExtensionType::Uninitialized || overwrite {
482+
// write extension type
483+
let extension_type_array: [u8; 2] = V::TYPE.into();
484+
let extension_type_ref = &mut self.tlv_data[type_start..length_start];
485+
extension_type_ref.copy_from_slice(&extension_type_array);
486+
// write length
487+
let length_ref =
488+
pod_from_bytes_mut::<Length>(&mut self.tlv_data[length_start..value_start])?;
489+
// maybe this becomes smarter later for dynamically sized extensions
490+
let length = pod_get_packed_len::<V>();
491+
*length_ref = Length::try_from(length).unwrap();
492+
493+
let value_end = value_start.saturating_add(length);
494+
let extension_ref =
495+
pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])?;
496+
*extension_ref = V::default();
497+
Ok(extension_ref)
498+
} else {
499+
// extension is already initialized, but no overwrite permission
500+
Err(TokenError::ExtensionAlreadyInitialized.into())
501+
}
486502
}
487503

488504
/// If `extension_type` is an Account-associated ExtensionType that requires initialization on
@@ -498,14 +514,14 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
498514
}
499515
match extension_type {
500516
ExtensionType::TransferFeeAmount => {
501-
self.init_extension::<TransferFeeAmount>().map(|_| ())
517+
self.init_extension::<TransferFeeAmount>(true).map(|_| ())
502518
}
503519
// ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
504520
// on InitializeAccount
505521
ExtensionType::ConfidentialTransferAccount => Ok(()),
506522
#[cfg(test)]
507523
ExtensionType::AccountPaddingTest => {
508-
self.init_extension::<AccountPaddingTest>().map(|_| ())
524+
self.init_extension::<AccountPaddingTest>(true).map(|_| ())
509525
}
510526
_ => unreachable!(),
511527
}
@@ -932,19 +948,27 @@ mod test {
932948
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
933949
// fail init account extension
934950
assert_eq!(
935-
state.init_extension::<TransferFeeAmount>(),
951+
state.init_extension::<TransferFeeAmount>(true),
936952
Err(ProgramError::InvalidAccountData),
937953
);
938954

939955
// success write extension
940956
let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
941-
let extension = state.init_extension::<MintCloseAuthority>().unwrap();
957+
let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
942958
extension.close_authority = close_authority;
943959
assert_eq!(
944960
&state.get_extension_types().unwrap(),
945961
&[ExtensionType::MintCloseAuthority]
946962
);
947963

964+
// fail init extension when already initialized
965+
assert_eq!(
966+
state.init_extension::<MintCloseAuthority>(false),
967+
Err(ProgramError::Custom(
968+
TokenError::ExtensionAlreadyInitialized as u32
969+
))
970+
);
971+
948972
// fail unpack as account, a mint extension was written
949973
assert_eq!(
950974
StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer),
@@ -1030,7 +1054,7 @@ mod test {
10301054
let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
10311055
// init one more extension
10321056
let mint_transfer_fee = test_transfer_fee_config();
1033-
let new_extension = state.init_extension::<TransferFeeConfig>().unwrap();
1057+
let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
10341058
new_extension.transfer_fee_config_authority =
10351059
mint_transfer_fee.transfer_fee_config_authority;
10361060
new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
@@ -1063,7 +1087,7 @@ mod test {
10631087
// fail to init one more extension that does not fit
10641088
let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
10651089
assert_eq!(
1066-
state.init_extension::<MintPaddingTest>(),
1090+
state.init_extension::<MintPaddingTest>(true),
10671091
Err(ProgramError::InvalidAccountData),
10681092
);
10691093
}
@@ -1079,11 +1103,11 @@ mod test {
10791103
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
10801104
// write extensions
10811105
let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
1082-
let extension = state.init_extension::<MintCloseAuthority>().unwrap();
1106+
let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
10831107
extension.close_authority = close_authority;
10841108

10851109
let mint_transfer_fee = test_transfer_fee_config();
1086-
let extension = state.init_extension::<TransferFeeConfig>().unwrap();
1110+
let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
10871111
extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
10881112
extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
10891113
extension.withheld_amount = mint_transfer_fee.withheld_amount;
@@ -1115,15 +1139,15 @@ mod test {
11151139

11161140
// write extensions in a different order
11171141
let mint_transfer_fee = test_transfer_fee_config();
1118-
let extension = state.init_extension::<TransferFeeConfig>().unwrap();
1142+
let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
11191143
extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
11201144
extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
11211145
extension.withheld_amount = mint_transfer_fee.withheld_amount;
11221146
extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
11231147
extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
11241148

11251149
let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
1126-
let extension = state.init_extension::<MintCloseAuthority>().unwrap();
1150+
let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
11271151
extension.close_authority = close_authority;
11281152

11291153
assert_eq!(
@@ -1169,7 +1193,7 @@ mod test {
11691193
state.init_account_type().unwrap();
11701194

11711195
// write padding
1172-
let extension = state.init_extension::<MintPaddingTest>().unwrap();
1196+
let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
11731197
extension.padding1 = [1; 128];
11741198
extension.padding2 = [1; 48];
11751199
extension.padding3 = [1; 9];
@@ -1206,12 +1230,12 @@ mod test {
12061230
StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
12071231
// fail init mint extension
12081232
assert_eq!(
1209-
state.init_extension::<TransferFeeConfig>(),
1233+
state.init_extension::<TransferFeeConfig>(true),
12101234
Err(ProgramError::InvalidAccountData),
12111235
);
12121236
// success write extension
12131237
let withheld_amount = PodU64::from(u64::MAX);
1214-
let extension = state.init_extension::<TransferFeeAmount>().unwrap();
1238+
let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
12151239
extension.withheld_amount = withheld_amount;
12161240

12171241
assert_eq!(
@@ -1305,7 +1329,7 @@ mod test {
13051329
state.init_account_type().unwrap();
13061330

13071331
// write padding
1308-
let extension = state.init_extension::<AccountPaddingTest>().unwrap();
1332+
let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
13091333
extension.0.padding1 = [2; 128];
13101334
extension.0.padding2 = [2; 48];
13111335
extension.0.padding3 = [2; 9];
@@ -1341,7 +1365,7 @@ mod test {
13411365
let mut state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
13421366
assert_eq!(state.base, TEST_ACCOUNT);
13431367
assert_eq!(state.account_type[0], AccountType::Account as u8);
1344-
state.init_extension::<ImmutableOwner>().unwrap(); // just confirming initialization works
1368+
state.init_extension::<ImmutableOwner>(true).unwrap(); // just confirming initialization works
13451369

13461370
// account with buffer big enough for AccountType only
13471371
let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
@@ -1384,7 +1408,7 @@ mod test {
13841408
let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
13851409
assert_eq!(state.base, TEST_MINT);
13861410
assert_eq!(state.account_type[0], AccountType::Mint as u8);
1387-
state.init_extension::<MintCloseAuthority>().unwrap();
1411+
state.init_extension::<MintCloseAuthority>(true).unwrap();
13881412

13891413
// mint with buffer big enough for AccountType only
13901414
let mut buffer = TEST_MINT_SLICE.to_vec();
@@ -1499,7 +1523,7 @@ mod test {
14991523

15001524
// fail init extension
15011525
assert_eq!(
1502-
state.init_extension::<TransferFeeConfig>(),
1526+
state.init_extension::<TransferFeeConfig>(true),
15031527
Err(ProgramError::InvalidAccountData),
15041528
);
15051529

@@ -1514,7 +1538,7 @@ mod test {
15141538
state.base = TEST_MINT;
15151539
state.pack_base();
15161540
state.init_account_type().unwrap();
1517-
let extension = state.init_extension::<MintPaddingTest>().unwrap();
1541+
let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
15181542
assert_eq!(extension.padding1, [1; 128]);
15191543
assert_eq!(extension.padding2, [2; 48]);
15201544
assert_eq!(extension.padding3, [3; 9]);
@@ -1526,7 +1550,9 @@ mod test {
15261550
ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintCloseAuthority]);
15271551
let mut buffer = vec![0; mint_size - 1];
15281552
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1529-
let err = state.init_extension::<MintCloseAuthority>().unwrap_err();
1553+
let err = state
1554+
.init_extension::<MintCloseAuthority>(true)
1555+
.unwrap_err();
15301556
assert_eq!(err, ProgramError::InvalidAccountData);
15311557

15321558
state.tlv_data[0] = 3;
@@ -1556,7 +1582,7 @@ mod test {
15561582
state.base = TEST_ACCOUNT;
15571583
state.pack_base();
15581584
state.init_account_type().unwrap();
1559-
state.init_extension::<ImmutableOwner>().unwrap();
1585+
state.init_extension::<ImmutableOwner>(true).unwrap();
15601586

15611587
assert_eq!(
15621588
get_first_extension_type(state.tlv_data).unwrap(),

token/program-2022/src/extension/transfer_fee/processor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ fn process_initialize_transfer_fee_config(
3636

3737
let mut mint_data = mint_account_info.data.borrow_mut();
3838
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut mint_data)?;
39-
let extension = mint.init_extension::<TransferFeeConfig>()?;
39+
let extension = mint.init_extension::<TransferFeeConfig>(true)?;
4040
extension.transfer_fee_config_authority = transfer_fee_config_authority.try_into()?;
4141
extension.withdraw_withheld_authority = withdraw_withheld_authority.try_into()?;
4242
extension.withheld_amount = 0u64.into();

0 commit comments

Comments
 (0)