Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit f9e6f66

Browse files
Add extension realloc helper (#2821)
1 parent eaaed0d commit f9e6f66

File tree

1 file changed

+83
-0
lines changed
  • token/program-2022/src/extension

1 file changed

+83
-0
lines changed

token/program-2022/src/extension/mod.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,24 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
491491
fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
492492
get_first_extension_type(self.tlv_data)
493493
}
494+
495+
/// Compares the length of an extension with the currently used TLV buffer to determine if
496+
/// reallocation is needed. If so, returns Some(v), where v is the difference between current
497+
/// space and needed.
498+
#[allow(dead_code)]
499+
pub(crate) fn realloc_needed(
500+
&self,
501+
new_extension: ExtensionType,
502+
) -> Result<Option<usize>, ProgramError> {
503+
let current_extensions = self.get_extension_types()?;
504+
let needed_tlv_len = ExtensionType::get_total_tlv_len(&current_extensions);
505+
let new_needed_tlv_len = needed_tlv_len.saturating_add(new_extension.get_type_len());
506+
if self.tlv_data.len() >= new_needed_tlv_len {
507+
Ok(None)
508+
} else {
509+
Ok(Some(new_needed_tlv_len - self.tlv_data.len())) // arithmetic safe because of if clause
510+
}
511+
}
494512
}
495513

496514
/// Different kinds of accounts. Note that `Mint`, `Account`, and `Multisig` types
@@ -1344,4 +1362,69 @@ mod test {
13441362

13451363
assert_eq!(state.get_extension_types().unwrap(), vec![]);
13461364
}
1365+
1366+
#[test]
1367+
fn test_realloc_needed() {
1368+
// buffer exact size of existing extension
1369+
let mint_size = ExtensionType::get_account_len::<Mint>(&[ExtensionType::TransferFeeConfig]);
1370+
let mut buffer = vec![0; mint_size];
1371+
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1372+
state.base = TEST_MINT;
1373+
state.pack_base();
1374+
state.init_account_type().unwrap();
1375+
assert_eq!(
1376+
state
1377+
.realloc_needed(ExtensionType::TransferFeeConfig)
1378+
.unwrap(),
1379+
None
1380+
);
1381+
state.init_extension::<TransferFeeConfig>().unwrap();
1382+
assert_eq!(
1383+
state
1384+
.realloc_needed(ExtensionType::MintCloseAuthority)
1385+
.unwrap(),
1386+
Some(ExtensionType::MintCloseAuthority.get_type_len())
1387+
);
1388+
1389+
// buffer with multisig len
1390+
let mint_size = ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintPaddingTest]);
1391+
let mut buffer = vec![0; mint_size];
1392+
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1393+
state.base = TEST_MINT;
1394+
state.pack_base();
1395+
state.init_account_type().unwrap();
1396+
assert_eq!(
1397+
state
1398+
.realloc_needed(ExtensionType::MintPaddingTest)
1399+
.unwrap(),
1400+
None
1401+
);
1402+
state.init_extension::<MintPaddingTest>().unwrap();
1403+
assert_eq!(
1404+
state
1405+
.realloc_needed(ExtensionType::MintCloseAuthority)
1406+
.unwrap(),
1407+
Some(ExtensionType::MintCloseAuthority.get_type_len() - size_of::<ExtensionType>())
1408+
);
1409+
1410+
// huge buffer
1411+
let mut buffer = vec![0; u16::MAX.into()];
1412+
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1413+
state.base = TEST_MINT;
1414+
state.pack_base();
1415+
state.init_account_type().unwrap();
1416+
assert_eq!(
1417+
state
1418+
.realloc_needed(ExtensionType::TransferFeeConfig)
1419+
.unwrap(),
1420+
None
1421+
);
1422+
state.init_extension::<TransferFeeConfig>().unwrap();
1423+
assert_eq!(
1424+
state
1425+
.realloc_needed(ExtensionType::MintCloseAuthority)
1426+
.unwrap(),
1427+
None
1428+
);
1429+
}
13471430
}

0 commit comments

Comments
 (0)