@@ -505,13 +505,21 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
505
505
& self ,
506
506
new_extension : ExtensionType ,
507
507
) -> Result < Option < usize > , ProgramError > {
508
- let current_extensions = self . get_extension_types ( ) ?;
509
- let needed_tlv_len = ExtensionType :: get_total_tlv_len ( & current_extensions) ;
510
- let new_needed_tlv_len = needed_tlv_len. saturating_add ( new_extension. get_type_len ( ) ) ;
508
+ let mut extensions = self . get_extension_types ( ) ?;
509
+ if !extensions. contains ( & new_extension) {
510
+ extensions. push ( new_extension) ;
511
+ }
512
+ let new_needed_tlv_len = ExtensionType :: get_total_tlv_len ( & extensions) ;
511
513
if self . tlv_data . len ( ) >= new_needed_tlv_len {
512
514
Ok ( None )
513
515
} else {
514
- Ok ( Some ( new_needed_tlv_len - self . tlv_data . len ( ) ) ) // arithmetic safe because of if clause
516
+ let mut diff = new_needed_tlv_len - self . tlv_data . len ( ) ; // arithmetic safe because of if clause
517
+ if self . account_type . is_empty ( ) {
518
+ diff = diff
519
+ . saturating_add ( size_of :: < AccountType > ( ) )
520
+ . saturating_add ( BASE_ACCOUNT_LENGTH . saturating_sub ( S :: LEN ) ) ;
521
+ }
522
+ Ok ( Some ( diff) )
515
523
}
516
524
}
517
525
}
@@ -611,7 +619,16 @@ impl ExtensionType {
611
619
612
620
/// Get the TLV length for a set of ExtensionTypes
613
621
fn get_total_tlv_len ( extension_types : & [ Self ] ) -> usize {
614
- extension_types. iter ( ) . map ( |e| e. get_tlv_len ( ) ) . sum ( )
622
+ let tlv_len: usize = extension_types. iter ( ) . map ( |e| e. get_tlv_len ( ) ) . sum ( ) ;
623
+ if tlv_len
624
+ == Multisig :: LEN
625
+ . saturating_sub ( BASE_ACCOUNT_LENGTH )
626
+ . saturating_sub ( size_of :: < AccountType > ( ) )
627
+ {
628
+ tlv_len. saturating_add ( size_of :: < ExtensionType > ( ) )
629
+ } else {
630
+ tlv_len
631
+ }
615
632
}
616
633
617
634
/// Get the required account data length for the given ExtensionTypes
@@ -620,14 +637,9 @@ impl ExtensionType {
620
637
S :: LEN
621
638
} else {
622
639
let extension_size = Self :: get_total_tlv_len ( extension_types) ;
623
- let account_size = extension_size
640
+ extension_size
624
641
. saturating_add ( BASE_ACCOUNT_LENGTH )
625
- . saturating_add ( size_of :: < AccountType > ( ) ) ;
626
- if account_size == Multisig :: LEN {
627
- account_size. saturating_add ( size_of :: < ExtensionType > ( ) )
628
- } else {
629
- account_size
630
- }
642
+ . saturating_add ( size_of :: < AccountType > ( ) )
631
643
}
632
644
}
633
645
@@ -1373,6 +1385,60 @@ mod test {
1373
1385
1374
1386
#[ test]
1375
1387
fn test_realloc_needed ( ) {
1388
+ // buffer exact size of base-state account
1389
+ let account_size = ExtensionType :: get_account_len :: < Account > ( & [ ] ) ;
1390
+ let mut buffer = vec ! [ 0 ; account_size] ;
1391
+ let mut state =
1392
+ StateWithExtensionsMut :: < Account > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1393
+ state. base = TEST_ACCOUNT ;
1394
+ state. pack_base ( ) ;
1395
+ state. init_account_type ( ) . unwrap ( ) ;
1396
+ let realloc = state. realloc_needed ( ExtensionType :: ImmutableOwner ) . unwrap ( ) ;
1397
+ assert_eq ! (
1398
+ realloc,
1399
+ Some ( ExtensionType :: ImmutableOwner . get_tlv_len( ) + size_of:: <AccountType >( ) )
1400
+ ) ;
1401
+ assert_eq ! (
1402
+ account_size + realloc. unwrap( ) ,
1403
+ ExtensionType :: get_account_len:: <Account >( & [ ExtensionType :: ImmutableOwner ] )
1404
+ ) ;
1405
+ let mut buffer = vec ! [ 0 ; account_size + realloc. unwrap( ) ] ;
1406
+ let mut state =
1407
+ StateWithExtensionsMut :: < Account > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1408
+ state. base = TEST_ACCOUNT ;
1409
+ state. pack_base ( ) ;
1410
+ state. init_account_type ( ) . unwrap ( ) ;
1411
+ state. init_extension :: < ImmutableOwner > ( ) . unwrap ( ) ;
1412
+
1413
+ // buffer exact size of base-state mint
1414
+ let account_size = ExtensionType :: get_account_len :: < Mint > ( & [ ] ) ;
1415
+ let mut buffer = vec ! [ 0 ; account_size] ;
1416
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1417
+ state. base = TEST_MINT ;
1418
+ state. pack_base ( ) ;
1419
+ state. init_account_type ( ) . unwrap ( ) ;
1420
+ let realloc = state
1421
+ . realloc_needed ( ExtensionType :: MintCloseAuthority )
1422
+ . unwrap ( ) ;
1423
+ assert_eq ! (
1424
+ realloc,
1425
+ Some (
1426
+ ExtensionType :: MintCloseAuthority . get_tlv_len( )
1427
+ + size_of:: <AccountType >( )
1428
+ + ( Account :: LEN - Mint :: LEN )
1429
+ )
1430
+ ) ;
1431
+ assert_eq ! (
1432
+ account_size + realloc. unwrap( ) ,
1433
+ ExtensionType :: get_account_len:: <Mint >( & [ ExtensionType :: MintCloseAuthority ] )
1434
+ ) ;
1435
+ let mut buffer = vec ! [ 0 ; account_size + realloc. unwrap( ) ] ;
1436
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1437
+ state. base = TEST_MINT ;
1438
+ state. pack_base ( ) ;
1439
+ state. init_account_type ( ) . unwrap ( ) ;
1440
+ state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1441
+
1376
1442
// buffer exact size of existing extension
1377
1443
let mint_size = ExtensionType :: get_account_len :: < Mint > ( & [ ExtensionType :: TransferFeeConfig ] ) ;
1378
1444
let mut buffer = vec ! [ 0 ; mint_size] ;
@@ -1387,12 +1453,27 @@ mod test {
1387
1453
None
1388
1454
) ;
1389
1455
state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1456
+ let realloc = state
1457
+ . realloc_needed ( ExtensionType :: MintCloseAuthority )
1458
+ . unwrap ( ) ;
1390
1459
assert_eq ! (
1391
- state
1392
- . realloc_needed( ExtensionType :: MintCloseAuthority )
1393
- . unwrap( ) ,
1394
- Some ( ExtensionType :: MintCloseAuthority . get_type_len( ) )
1460
+ realloc,
1461
+ Some ( ExtensionType :: MintCloseAuthority . get_tlv_len( ) )
1395
1462
) ;
1463
+ assert_eq ! (
1464
+ mint_size + realloc. unwrap( ) ,
1465
+ ExtensionType :: get_account_len:: <Account >( & [
1466
+ ExtensionType :: TransferFeeConfig ,
1467
+ ExtensionType :: MintCloseAuthority
1468
+ ] )
1469
+ ) ;
1470
+ let mut buffer = vec ! [ 0 ; mint_size + realloc. unwrap( ) ] ;
1471
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1472
+ state. base = TEST_MINT ;
1473
+ state. pack_base ( ) ;
1474
+ state. init_account_type ( ) . unwrap ( ) ;
1475
+ state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1476
+ state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1396
1477
1397
1478
// buffer with multisig len
1398
1479
let mint_size = ExtensionType :: get_account_len :: < Mint > ( & [ ExtensionType :: MintPaddingTest ] ) ;
@@ -1408,12 +1489,27 @@ mod test {
1408
1489
None
1409
1490
) ;
1410
1491
state. init_extension :: < MintPaddingTest > ( ) . unwrap ( ) ;
1492
+ let realloc = state
1493
+ . realloc_needed ( ExtensionType :: MintCloseAuthority )
1494
+ . unwrap ( ) ;
1411
1495
assert_eq ! (
1412
- state
1413
- . realloc_needed( ExtensionType :: MintCloseAuthority )
1414
- . unwrap( ) ,
1415
- Some ( ExtensionType :: MintCloseAuthority . get_type_len( ) - size_of:: <ExtensionType >( ) )
1496
+ realloc,
1497
+ Some ( ExtensionType :: MintCloseAuthority . get_tlv_len( ) - size_of:: <ExtensionType >( ) )
1498
+ ) ;
1499
+ assert_eq ! (
1500
+ mint_size + realloc. unwrap( ) ,
1501
+ ExtensionType :: get_account_len:: <Account >( & [
1502
+ ExtensionType :: MintPaddingTest ,
1503
+ ExtensionType :: MintCloseAuthority
1504
+ ] )
1416
1505
) ;
1506
+ let mut buffer = vec ! [ 0 ; mint_size + realloc. unwrap( ) ] ;
1507
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1508
+ state. base = TEST_MINT ;
1509
+ state. pack_base ( ) ;
1510
+ state. init_account_type ( ) . unwrap ( ) ;
1511
+ state. init_extension :: < MintPaddingTest > ( ) . unwrap ( ) ;
1512
+ state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1417
1513
1418
1514
// huge buffer
1419
1515
let mut buffer = vec ! [ 0 ; u16 :: MAX . into( ) ] ;
0 commit comments