Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit 27b0df1

Browse files
token-2022: add StateWithExtensionsMut::unpack_after_realloc (#2859)
* Add unpack_after_realloc * Dedupe with internal fn
1 parent 3863914 commit 27b0df1

File tree

1 file changed

+153
-1
lines changed
  • token/program-2022/src/extension

1 file changed

+153
-1
lines changed

token/program-2022/src/extension/mod.rs

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,13 +326,28 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
326326
///
327327
/// Fails if the base state is not initialized.
328328
pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
329+
Self::_unpack(input, false)
330+
}
331+
/// Unpack base state, leaving the extension data as a mutable slice
332+
/// Checks the account_type, and initializes it if Uninitialized
333+
///
334+
/// Fails if the base state is not initialized.
335+
pub fn unpack_after_realloc(input: &'data mut [u8]) -> Result<Self, ProgramError> {
336+
Self::_unpack(input, true)
337+
}
338+
339+
fn _unpack(input: &'data mut [u8], init_account_type: bool) -> Result<Self, ProgramError> {
329340
check_min_len_and_not_multisig(input, S::LEN)?;
330341
let (base_data, rest) = input.split_at_mut(S::LEN);
331342
let base = S::unpack(base_data)?;
332343
if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
333344
// type_and_tlv_indices() checks that returned indexes are within range
334-
let account_type = AccountType::try_from(rest[account_type_index])
345+
let mut account_type = AccountType::try_from(rest[account_type_index])
335346
.map_err(|_| ProgramError::InvalidAccountData)?;
347+
if init_account_type && account_type == AccountType::Uninitialized {
348+
rest[account_type_index] = S::ACCOUNT_TYPE.into();
349+
account_type = S::ACCOUNT_TYPE;
350+
}
336351
check_account_type::<S>(account_type)?;
337352
let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
338353
Ok(Self {
@@ -1270,6 +1285,143 @@ mod test {
12701285
assert_eq!(expect, buffer);
12711286
}
12721287

1288+
#[test]
1289+
fn test_unpack_after_realloc() {
1290+
// account
1291+
let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
1292+
let state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
1293+
let realloc = state
1294+
.realloc_needed(ExtensionType::ImmutableOwner)
1295+
.unwrap()
1296+
.unwrap();
1297+
drop(state);
1298+
buffer.append(&mut vec![0; realloc]);
1299+
let mut state =
1300+
StateWithExtensionsMut::<Account>::unpack_after_realloc(&mut buffer).unwrap();
1301+
assert_eq!(state.base, TEST_ACCOUNT);
1302+
assert_eq!(state.account_type[0], AccountType::Account as u8);
1303+
state.init_extension::<ImmutableOwner>().unwrap();
1304+
1305+
// account with AccountType
1306+
let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
1307+
buffer.append(&mut vec![2, 0]);
1308+
let state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
1309+
assert_eq!(state.base, TEST_ACCOUNT);
1310+
assert_eq!(state.account_type[0], AccountType::Account as u8);
1311+
let realloc = state
1312+
.realloc_needed(ExtensionType::ImmutableOwner)
1313+
.unwrap()
1314+
.unwrap();
1315+
drop(state);
1316+
buffer.append(&mut vec![0; realloc]);
1317+
let mut state =
1318+
StateWithExtensionsMut::<Account>::unpack_after_realloc(&mut buffer).unwrap();
1319+
assert_eq!(state.base, TEST_ACCOUNT);
1320+
assert_eq!(state.account_type[0], AccountType::Account as u8);
1321+
state.init_extension::<ImmutableOwner>().unwrap();
1322+
1323+
// account with wrong AccountType
1324+
let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
1325+
buffer.append(&mut vec![1, 0]);
1326+
let err = StateWithExtensionsMut::<Account>::unpack_after_realloc(&mut buffer).unwrap_err();
1327+
assert_eq!(err, ProgramError::InvalidAccountData);
1328+
1329+
// account with pre-existing extension
1330+
let account_size =
1331+
ExtensionType::get_account_len::<Account>(&[ExtensionType::ImmutableOwner]);
1332+
let mut buffer = vec![0; account_size];
1333+
let mut state =
1334+
StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
1335+
state.base = TEST_ACCOUNT;
1336+
state.pack_base();
1337+
state.init_extension::<ImmutableOwner>().unwrap();
1338+
state.init_account_type().unwrap();
1339+
let realloc = state
1340+
.realloc_needed(ExtensionType::TransferFeeAmount)
1341+
.unwrap()
1342+
.unwrap();
1343+
drop(state);
1344+
buffer.append(&mut vec![0; realloc]);
1345+
let mut state =
1346+
StateWithExtensionsMut::<Account>::unpack_after_realloc(&mut buffer).unwrap();
1347+
assert_eq!(state.base, TEST_ACCOUNT);
1348+
assert_eq!(state.account_type[0], AccountType::Account as u8);
1349+
state.init_extension::<TransferFeeAmount>().unwrap();
1350+
assert_eq!(
1351+
state.get_extension_types().unwrap(),
1352+
vec![
1353+
ExtensionType::ImmutableOwner,
1354+
ExtensionType::TransferFeeAmount
1355+
]
1356+
);
1357+
1358+
// mint
1359+
let mut buffer = TEST_MINT_SLICE.to_vec();
1360+
let state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1361+
let realloc = state
1362+
.realloc_needed(ExtensionType::MintCloseAuthority)
1363+
.unwrap()
1364+
.unwrap();
1365+
drop(state);
1366+
buffer.append(&mut vec![0; realloc]);
1367+
let mut state = StateWithExtensionsMut::<Mint>::unpack_after_realloc(&mut buffer).unwrap();
1368+
assert_eq!(state.base, TEST_MINT);
1369+
assert_eq!(state.account_type[0], AccountType::Mint as u8);
1370+
state.init_extension::<MintCloseAuthority>().unwrap();
1371+
1372+
// mint with AccountType
1373+
let mut buffer = TEST_MINT_SLICE.to_vec();
1374+
buffer.append(&mut vec![0; Account::LEN - Mint::LEN]);
1375+
buffer.append(&mut vec![1, 0]);
1376+
let state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1377+
assert_eq!(state.base, TEST_MINT);
1378+
assert_eq!(state.account_type[0], AccountType::Mint as u8);
1379+
let realloc = state
1380+
.realloc_needed(ExtensionType::MintCloseAuthority)
1381+
.unwrap()
1382+
.unwrap();
1383+
drop(state);
1384+
buffer.append(&mut vec![0; realloc]);
1385+
let mut state = StateWithExtensionsMut::<Mint>::unpack_after_realloc(&mut buffer).unwrap();
1386+
assert_eq!(state.base, TEST_MINT);
1387+
assert_eq!(state.account_type[0], AccountType::Mint as u8);
1388+
state.init_extension::<MintCloseAuthority>().unwrap();
1389+
1390+
// mint with wrong AccountType
1391+
let mut buffer = TEST_MINT_SLICE.to_vec();
1392+
buffer.append(&mut vec![0; Account::LEN - Mint::LEN]);
1393+
buffer.append(&mut vec![2, 0]);
1394+
let err = StateWithExtensionsMut::<Mint>::unpack_after_realloc(&mut buffer).unwrap_err();
1395+
assert_eq!(err, ProgramError::InvalidAccountData);
1396+
1397+
// mint with pre-existing extension
1398+
let mut buffer = MINT_WITH_EXTENSION.to_vec();
1399+
let state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1400+
assert_eq!(state.base, TEST_MINT);
1401+
assert_eq!(state.account_type[0], AccountType::Mint as u8);
1402+
assert_eq!(
1403+
state.get_extension_types().unwrap(),
1404+
vec![ExtensionType::MintCloseAuthority]
1405+
);
1406+
let realloc = state
1407+
.realloc_needed(ExtensionType::TransferFeeConfig)
1408+
.unwrap()
1409+
.unwrap();
1410+
drop(state);
1411+
buffer.append(&mut vec![0; realloc]);
1412+
let mut state = StateWithExtensionsMut::<Mint>::unpack_after_realloc(&mut buffer).unwrap();
1413+
assert_eq!(state.base, TEST_MINT);
1414+
assert_eq!(state.account_type[0], AccountType::Mint as u8);
1415+
state.init_extension::<TransferFeeConfig>().unwrap();
1416+
assert_eq!(
1417+
state.get_extension_types().unwrap(),
1418+
vec![
1419+
ExtensionType::MintCloseAuthority,
1420+
ExtensionType::TransferFeeConfig
1421+
]
1422+
);
1423+
}
1424+
12731425
#[test]
12741426
fn test_get_required_init_account_extensions() {
12751427
// Some mint extensions with no required account extensions

0 commit comments

Comments
 (0)