@@ -326,13 +326,28 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
326
326
///
327
327
/// Fails if the base state is not initialized.
328
328
pub fn unpack ( input : & ' data mut [ u8 ] ) -> Result < Self , ProgramError > {
329
+ Self :: _unpack ( input, false )
330
+ }
331
+ /// Unpack base state, leaving the extension data as a mutable slice
332
+ /// Checks the account_type, and initializes it if Uninitialized
333
+ ///
334
+ /// Fails if the base state is not initialized.
335
+ pub fn unpack_after_realloc ( input : & ' data mut [ u8 ] ) -> Result < Self , ProgramError > {
336
+ Self :: _unpack ( input, true )
337
+ }
338
+
339
+ fn _unpack ( input : & ' data mut [ u8 ] , init_account_type : bool ) -> Result < Self , ProgramError > {
329
340
check_min_len_and_not_multisig ( input, S :: LEN ) ?;
330
341
let ( base_data, rest) = input. split_at_mut ( S :: LEN ) ;
331
342
let base = S :: unpack ( base_data) ?;
332
343
if let Some ( ( account_type_index, tlv_start_index) ) = type_and_tlv_indices :: < S > ( rest) ? {
333
344
// type_and_tlv_indices() checks that returned indexes are within range
334
- let account_type = AccountType :: try_from ( rest[ account_type_index] )
345
+ let mut account_type = AccountType :: try_from ( rest[ account_type_index] )
335
346
. map_err ( |_| ProgramError :: InvalidAccountData ) ?;
347
+ if init_account_type && account_type == AccountType :: Uninitialized {
348
+ rest[ account_type_index] = S :: ACCOUNT_TYPE . into ( ) ;
349
+ account_type = S :: ACCOUNT_TYPE ;
350
+ }
336
351
check_account_type :: < S > ( account_type) ?;
337
352
let ( account_type, tlv_data) = rest. split_at_mut ( tlv_start_index) ;
338
353
Ok ( Self {
@@ -1270,6 +1285,143 @@ mod test {
1270
1285
assert_eq ! ( expect, buffer) ;
1271
1286
}
1272
1287
1288
+ #[ test]
1289
+ fn test_unpack_after_realloc ( ) {
1290
+ // account
1291
+ let mut buffer = TEST_ACCOUNT_SLICE . to_vec ( ) ;
1292
+ let state = StateWithExtensionsMut :: < Account > :: unpack ( & mut buffer) . unwrap ( ) ;
1293
+ let realloc = state
1294
+ . realloc_needed ( ExtensionType :: ImmutableOwner )
1295
+ . unwrap ( )
1296
+ . unwrap ( ) ;
1297
+ drop ( state) ;
1298
+ buffer. append ( & mut vec ! [ 0 ; realloc] ) ;
1299
+ let mut state =
1300
+ StateWithExtensionsMut :: < Account > :: unpack_after_realloc ( & mut buffer) . unwrap ( ) ;
1301
+ assert_eq ! ( state. base, TEST_ACCOUNT ) ;
1302
+ assert_eq ! ( state. account_type[ 0 ] , AccountType :: Account as u8 ) ;
1303
+ state. init_extension :: < ImmutableOwner > ( ) . unwrap ( ) ;
1304
+
1305
+ // account with AccountType
1306
+ let mut buffer = TEST_ACCOUNT_SLICE . to_vec ( ) ;
1307
+ buffer. append ( & mut vec ! [ 2 , 0 ] ) ;
1308
+ let state = StateWithExtensionsMut :: < Account > :: unpack ( & mut buffer) . unwrap ( ) ;
1309
+ assert_eq ! ( state. base, TEST_ACCOUNT ) ;
1310
+ assert_eq ! ( state. account_type[ 0 ] , AccountType :: Account as u8 ) ;
1311
+ let realloc = state
1312
+ . realloc_needed ( ExtensionType :: ImmutableOwner )
1313
+ . unwrap ( )
1314
+ . unwrap ( ) ;
1315
+ drop ( state) ;
1316
+ buffer. append ( & mut vec ! [ 0 ; realloc] ) ;
1317
+ let mut state =
1318
+ StateWithExtensionsMut :: < Account > :: unpack_after_realloc ( & mut buffer) . unwrap ( ) ;
1319
+ assert_eq ! ( state. base, TEST_ACCOUNT ) ;
1320
+ assert_eq ! ( state. account_type[ 0 ] , AccountType :: Account as u8 ) ;
1321
+ state. init_extension :: < ImmutableOwner > ( ) . unwrap ( ) ;
1322
+
1323
+ // account with wrong AccountType
1324
+ let mut buffer = TEST_ACCOUNT_SLICE . to_vec ( ) ;
1325
+ buffer. append ( & mut vec ! [ 1 , 0 ] ) ;
1326
+ let err = StateWithExtensionsMut :: < Account > :: unpack_after_realloc ( & mut buffer) . unwrap_err ( ) ;
1327
+ assert_eq ! ( err, ProgramError :: InvalidAccountData ) ;
1328
+
1329
+ // account with pre-existing extension
1330
+ let account_size =
1331
+ ExtensionType :: get_account_len :: < Account > ( & [ ExtensionType :: ImmutableOwner ] ) ;
1332
+ let mut buffer = vec ! [ 0 ; account_size] ;
1333
+ let mut state =
1334
+ StateWithExtensionsMut :: < Account > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1335
+ state. base = TEST_ACCOUNT ;
1336
+ state. pack_base ( ) ;
1337
+ state. init_extension :: < ImmutableOwner > ( ) . unwrap ( ) ;
1338
+ state. init_account_type ( ) . unwrap ( ) ;
1339
+ let realloc = state
1340
+ . realloc_needed ( ExtensionType :: TransferFeeAmount )
1341
+ . unwrap ( )
1342
+ . unwrap ( ) ;
1343
+ drop ( state) ;
1344
+ buffer. append ( & mut vec ! [ 0 ; realloc] ) ;
1345
+ let mut state =
1346
+ StateWithExtensionsMut :: < Account > :: unpack_after_realloc ( & mut buffer) . unwrap ( ) ;
1347
+ assert_eq ! ( state. base, TEST_ACCOUNT ) ;
1348
+ assert_eq ! ( state. account_type[ 0 ] , AccountType :: Account as u8 ) ;
1349
+ state. init_extension :: < TransferFeeAmount > ( ) . unwrap ( ) ;
1350
+ assert_eq ! (
1351
+ state. get_extension_types( ) . unwrap( ) ,
1352
+ vec![
1353
+ ExtensionType :: ImmutableOwner ,
1354
+ ExtensionType :: TransferFeeAmount
1355
+ ]
1356
+ ) ;
1357
+
1358
+ // mint
1359
+ let mut buffer = TEST_MINT_SLICE . to_vec ( ) ;
1360
+ let state = StateWithExtensionsMut :: < Mint > :: unpack ( & mut buffer) . unwrap ( ) ;
1361
+ let realloc = state
1362
+ . realloc_needed ( ExtensionType :: MintCloseAuthority )
1363
+ . unwrap ( )
1364
+ . unwrap ( ) ;
1365
+ drop ( state) ;
1366
+ buffer. append ( & mut vec ! [ 0 ; realloc] ) ;
1367
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_after_realloc ( & mut buffer) . unwrap ( ) ;
1368
+ assert_eq ! ( state. base, TEST_MINT ) ;
1369
+ assert_eq ! ( state. account_type[ 0 ] , AccountType :: Mint as u8 ) ;
1370
+ state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1371
+
1372
+ // mint with AccountType
1373
+ let mut buffer = TEST_MINT_SLICE . to_vec ( ) ;
1374
+ buffer. append ( & mut vec ! [ 0 ; Account :: LEN - Mint :: LEN ] ) ;
1375
+ buffer. append ( & mut vec ! [ 1 , 0 ] ) ;
1376
+ let state = StateWithExtensionsMut :: < Mint > :: unpack ( & mut buffer) . unwrap ( ) ;
1377
+ assert_eq ! ( state. base, TEST_MINT ) ;
1378
+ assert_eq ! ( state. account_type[ 0 ] , AccountType :: Mint as u8 ) ;
1379
+ let realloc = state
1380
+ . realloc_needed ( ExtensionType :: MintCloseAuthority )
1381
+ . unwrap ( )
1382
+ . unwrap ( ) ;
1383
+ drop ( state) ;
1384
+ buffer. append ( & mut vec ! [ 0 ; realloc] ) ;
1385
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_after_realloc ( & mut buffer) . unwrap ( ) ;
1386
+ assert_eq ! ( state. base, TEST_MINT ) ;
1387
+ assert_eq ! ( state. account_type[ 0 ] , AccountType :: Mint as u8 ) ;
1388
+ state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1389
+
1390
+ // mint with wrong AccountType
1391
+ let mut buffer = TEST_MINT_SLICE . to_vec ( ) ;
1392
+ buffer. append ( & mut vec ! [ 0 ; Account :: LEN - Mint :: LEN ] ) ;
1393
+ buffer. append ( & mut vec ! [ 2 , 0 ] ) ;
1394
+ let err = StateWithExtensionsMut :: < Mint > :: unpack_after_realloc ( & mut buffer) . unwrap_err ( ) ;
1395
+ assert_eq ! ( err, ProgramError :: InvalidAccountData ) ;
1396
+
1397
+ // mint with pre-existing extension
1398
+ let mut buffer = MINT_WITH_EXTENSION . to_vec ( ) ;
1399
+ let state = StateWithExtensionsMut :: < Mint > :: unpack ( & mut buffer) . unwrap ( ) ;
1400
+ assert_eq ! ( state. base, TEST_MINT ) ;
1401
+ assert_eq ! ( state. account_type[ 0 ] , AccountType :: Mint as u8 ) ;
1402
+ assert_eq ! (
1403
+ state. get_extension_types( ) . unwrap( ) ,
1404
+ vec![ ExtensionType :: MintCloseAuthority ]
1405
+ ) ;
1406
+ let realloc = state
1407
+ . realloc_needed ( ExtensionType :: TransferFeeConfig )
1408
+ . unwrap ( )
1409
+ . unwrap ( ) ;
1410
+ drop ( state) ;
1411
+ buffer. append ( & mut vec ! [ 0 ; realloc] ) ;
1412
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_after_realloc ( & mut buffer) . unwrap ( ) ;
1413
+ assert_eq ! ( state. base, TEST_MINT ) ;
1414
+ assert_eq ! ( state. account_type[ 0 ] , AccountType :: Mint as u8 ) ;
1415
+ state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1416
+ assert_eq ! (
1417
+ state. get_extension_types( ) . unwrap( ) ,
1418
+ vec![
1419
+ ExtensionType :: MintCloseAuthority ,
1420
+ ExtensionType :: TransferFeeConfig
1421
+ ]
1422
+ ) ;
1423
+ }
1424
+
1273
1425
#[ test]
1274
1426
fn test_get_required_init_account_extensions ( ) {
1275
1427
// Some mint extensions with no required account extensions
0 commit comments