Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit 448e75f

Browse files
token-2022: fixup realloc_needed (#2856)
* Add failing test cases and fix buggy method * Add extension initializations to ensure buffers are long enough * Add extra checks * Move multisig check lower * Add missing test case and fix
1 parent 115c3c4 commit 448e75f

File tree

1 file changed

+116
-20
lines changed
  • token/program-2022/src/extension

1 file changed

+116
-20
lines changed

token/program-2022/src/extension/mod.rs

Lines changed: 116 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -505,13 +505,21 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
505505
&self,
506506
new_extension: ExtensionType,
507507
) -> Result<Option<usize>, ProgramError> {
508-
let current_extensions = self.get_extension_types()?;
509-
let needed_tlv_len = ExtensionType::get_total_tlv_len(&current_extensions);
510-
let new_needed_tlv_len = needed_tlv_len.saturating_add(new_extension.get_type_len());
508+
let mut extensions = self.get_extension_types()?;
509+
if !extensions.contains(&new_extension) {
510+
extensions.push(new_extension);
511+
}
512+
let new_needed_tlv_len = ExtensionType::get_total_tlv_len(&extensions);
511513
if self.tlv_data.len() >= new_needed_tlv_len {
512514
Ok(None)
513515
} else {
514-
Ok(Some(new_needed_tlv_len - self.tlv_data.len())) // arithmetic safe because of if clause
516+
let mut diff = new_needed_tlv_len - self.tlv_data.len(); // arithmetic safe because of if clause
517+
if self.account_type.is_empty() {
518+
diff = diff
519+
.saturating_add(size_of::<AccountType>())
520+
.saturating_add(BASE_ACCOUNT_LENGTH.saturating_sub(S::LEN));
521+
}
522+
Ok(Some(diff))
515523
}
516524
}
517525
}
@@ -611,7 +619,16 @@ impl ExtensionType {
611619

612620
/// Get the TLV length for a set of ExtensionTypes
613621
fn get_total_tlv_len(extension_types: &[Self]) -> usize {
614-
extension_types.iter().map(|e| e.get_tlv_len()).sum()
622+
let tlv_len: usize = extension_types.iter().map(|e| e.get_tlv_len()).sum();
623+
if tlv_len
624+
== Multisig::LEN
625+
.saturating_sub(BASE_ACCOUNT_LENGTH)
626+
.saturating_sub(size_of::<AccountType>())
627+
{
628+
tlv_len.saturating_add(size_of::<ExtensionType>())
629+
} else {
630+
tlv_len
631+
}
615632
}
616633

617634
/// Get the required account data length for the given ExtensionTypes
@@ -620,14 +637,9 @@ impl ExtensionType {
620637
S::LEN
621638
} else {
622639
let extension_size = Self::get_total_tlv_len(extension_types);
623-
let account_size = extension_size
640+
extension_size
624641
.saturating_add(BASE_ACCOUNT_LENGTH)
625-
.saturating_add(size_of::<AccountType>());
626-
if account_size == Multisig::LEN {
627-
account_size.saturating_add(size_of::<ExtensionType>())
628-
} else {
629-
account_size
630-
}
642+
.saturating_add(size_of::<AccountType>())
631643
}
632644
}
633645

@@ -1373,6 +1385,60 @@ mod test {
13731385

13741386
#[test]
13751387
fn test_realloc_needed() {
1388+
// buffer exact size of base-state account
1389+
let account_size = ExtensionType::get_account_len::<Account>(&[]);
1390+
let mut buffer = vec![0; account_size];
1391+
let mut state =
1392+
StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
1393+
state.base = TEST_ACCOUNT;
1394+
state.pack_base();
1395+
state.init_account_type().unwrap();
1396+
let realloc = state.realloc_needed(ExtensionType::ImmutableOwner).unwrap();
1397+
assert_eq!(
1398+
realloc,
1399+
Some(ExtensionType::ImmutableOwner.get_tlv_len() + size_of::<AccountType>())
1400+
);
1401+
assert_eq!(
1402+
account_size + realloc.unwrap(),
1403+
ExtensionType::get_account_len::<Account>(&[ExtensionType::ImmutableOwner])
1404+
);
1405+
let mut buffer = vec![0; account_size + realloc.unwrap()];
1406+
let mut state =
1407+
StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
1408+
state.base = TEST_ACCOUNT;
1409+
state.pack_base();
1410+
state.init_account_type().unwrap();
1411+
state.init_extension::<ImmutableOwner>().unwrap();
1412+
1413+
// buffer exact size of base-state mint
1414+
let account_size = ExtensionType::get_account_len::<Mint>(&[]);
1415+
let mut buffer = vec![0; account_size];
1416+
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1417+
state.base = TEST_MINT;
1418+
state.pack_base();
1419+
state.init_account_type().unwrap();
1420+
let realloc = state
1421+
.realloc_needed(ExtensionType::MintCloseAuthority)
1422+
.unwrap();
1423+
assert_eq!(
1424+
realloc,
1425+
Some(
1426+
ExtensionType::MintCloseAuthority.get_tlv_len()
1427+
+ size_of::<AccountType>()
1428+
+ (Account::LEN - Mint::LEN)
1429+
)
1430+
);
1431+
assert_eq!(
1432+
account_size + realloc.unwrap(),
1433+
ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintCloseAuthority])
1434+
);
1435+
let mut buffer = vec![0; account_size + realloc.unwrap()];
1436+
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1437+
state.base = TEST_MINT;
1438+
state.pack_base();
1439+
state.init_account_type().unwrap();
1440+
state.init_extension::<MintCloseAuthority>().unwrap();
1441+
13761442
// buffer exact size of existing extension
13771443
let mint_size = ExtensionType::get_account_len::<Mint>(&[ExtensionType::TransferFeeConfig]);
13781444
let mut buffer = vec![0; mint_size];
@@ -1387,12 +1453,27 @@ mod test {
13871453
None
13881454
);
13891455
state.init_extension::<TransferFeeConfig>().unwrap();
1456+
let realloc = state
1457+
.realloc_needed(ExtensionType::MintCloseAuthority)
1458+
.unwrap();
13901459
assert_eq!(
1391-
state
1392-
.realloc_needed(ExtensionType::MintCloseAuthority)
1393-
.unwrap(),
1394-
Some(ExtensionType::MintCloseAuthority.get_type_len())
1460+
realloc,
1461+
Some(ExtensionType::MintCloseAuthority.get_tlv_len())
13951462
);
1463+
assert_eq!(
1464+
mint_size + realloc.unwrap(),
1465+
ExtensionType::get_account_len::<Account>(&[
1466+
ExtensionType::TransferFeeConfig,
1467+
ExtensionType::MintCloseAuthority
1468+
])
1469+
);
1470+
let mut buffer = vec![0; mint_size + realloc.unwrap()];
1471+
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1472+
state.base = TEST_MINT;
1473+
state.pack_base();
1474+
state.init_account_type().unwrap();
1475+
state.init_extension::<TransferFeeConfig>().unwrap();
1476+
state.init_extension::<MintCloseAuthority>().unwrap();
13961477

13971478
// buffer with multisig len
13981479
let mint_size = ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintPaddingTest]);
@@ -1408,12 +1489,27 @@ mod test {
14081489
None
14091490
);
14101491
state.init_extension::<MintPaddingTest>().unwrap();
1492+
let realloc = state
1493+
.realloc_needed(ExtensionType::MintCloseAuthority)
1494+
.unwrap();
14111495
assert_eq!(
1412-
state
1413-
.realloc_needed(ExtensionType::MintCloseAuthority)
1414-
.unwrap(),
1415-
Some(ExtensionType::MintCloseAuthority.get_type_len() - size_of::<ExtensionType>())
1496+
realloc,
1497+
Some(ExtensionType::MintCloseAuthority.get_tlv_len() - size_of::<ExtensionType>())
1498+
);
1499+
assert_eq!(
1500+
mint_size + realloc.unwrap(),
1501+
ExtensionType::get_account_len::<Account>(&[
1502+
ExtensionType::MintPaddingTest,
1503+
ExtensionType::MintCloseAuthority
1504+
])
14161505
);
1506+
let mut buffer = vec![0; mint_size + realloc.unwrap()];
1507+
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1508+
state.base = TEST_MINT;
1509+
state.pack_base();
1510+
state.init_account_type().unwrap();
1511+
state.init_extension::<MintPaddingTest>().unwrap();
1512+
state.init_extension::<MintCloseAuthority>().unwrap();
14171513

14181514
// huge buffer
14191515
let mut buffer = vec![0; u16::MAX.into()];

0 commit comments

Comments
 (0)