@@ -491,6 +491,24 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
491
491
fn get_first_extension_type ( & self ) -> Result < Option < ExtensionType > , ProgramError > {
492
492
get_first_extension_type ( self . tlv_data )
493
493
}
494
+
495
+ /// Compares the length of an extension with the currently used TLV buffer to determine if
496
+ /// reallocation is needed. If so, returns Some(v), where v is the difference between current
497
+ /// space and needed.
498
+ #[ allow( dead_code) ]
499
+ pub ( crate ) fn realloc_needed (
500
+ & self ,
501
+ new_extension : ExtensionType ,
502
+ ) -> Result < Option < usize > , ProgramError > {
503
+ let current_extensions = self . get_extension_types ( ) ?;
504
+ let needed_tlv_len = ExtensionType :: get_total_tlv_len ( & current_extensions) ;
505
+ let new_needed_tlv_len = needed_tlv_len. saturating_add ( new_extension. get_type_len ( ) ) ;
506
+ if self . tlv_data . len ( ) >= new_needed_tlv_len {
507
+ Ok ( None )
508
+ } else {
509
+ Ok ( Some ( new_needed_tlv_len - self . tlv_data . len ( ) ) ) // arithmetic safe because of if clause
510
+ }
511
+ }
494
512
}
495
513
496
514
/// Different kinds of accounts. Note that `Mint`, `Account`, and `Multisig` types
@@ -1344,4 +1362,69 @@ mod test {
1344
1362
1345
1363
assert_eq ! ( state. get_extension_types( ) . unwrap( ) , vec![ ] ) ;
1346
1364
}
1365
+
1366
+ #[ test]
1367
+ fn test_realloc_needed ( ) {
1368
+ // buffer exact size of existing extension
1369
+ let mint_size = ExtensionType :: get_account_len :: < Mint > ( & [ ExtensionType :: TransferFeeConfig ] ) ;
1370
+ let mut buffer = vec ! [ 0 ; mint_size] ;
1371
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1372
+ state. base = TEST_MINT ;
1373
+ state. pack_base ( ) ;
1374
+ state. init_account_type ( ) . unwrap ( ) ;
1375
+ assert_eq ! (
1376
+ state
1377
+ . realloc_needed( ExtensionType :: TransferFeeConfig )
1378
+ . unwrap( ) ,
1379
+ None
1380
+ ) ;
1381
+ state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1382
+ assert_eq ! (
1383
+ state
1384
+ . realloc_needed( ExtensionType :: MintCloseAuthority )
1385
+ . unwrap( ) ,
1386
+ Some ( ExtensionType :: MintCloseAuthority . get_type_len( ) )
1387
+ ) ;
1388
+
1389
+ // buffer with multisig len
1390
+ let mint_size = ExtensionType :: get_account_len :: < Mint > ( & [ ExtensionType :: MintPaddingTest ] ) ;
1391
+ let mut buffer = vec ! [ 0 ; mint_size] ;
1392
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1393
+ state. base = TEST_MINT ;
1394
+ state. pack_base ( ) ;
1395
+ state. init_account_type ( ) . unwrap ( ) ;
1396
+ assert_eq ! (
1397
+ state
1398
+ . realloc_needed( ExtensionType :: MintPaddingTest )
1399
+ . unwrap( ) ,
1400
+ None
1401
+ ) ;
1402
+ state. init_extension :: < MintPaddingTest > ( ) . unwrap ( ) ;
1403
+ assert_eq ! (
1404
+ state
1405
+ . realloc_needed( ExtensionType :: MintCloseAuthority )
1406
+ . unwrap( ) ,
1407
+ Some ( ExtensionType :: MintCloseAuthority . get_type_len( ) - size_of:: <ExtensionType >( ) )
1408
+ ) ;
1409
+
1410
+ // huge buffer
1411
+ let mut buffer = vec ! [ 0 ; u16 :: MAX . into( ) ] ;
1412
+ let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1413
+ state. base = TEST_MINT ;
1414
+ state. pack_base ( ) ;
1415
+ state. init_account_type ( ) . unwrap ( ) ;
1416
+ assert_eq ! (
1417
+ state
1418
+ . realloc_needed( ExtensionType :: TransferFeeConfig )
1419
+ . unwrap( ) ,
1420
+ None
1421
+ ) ;
1422
+ state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1423
+ assert_eq ! (
1424
+ state
1425
+ . realloc_needed( ExtensionType :: MintCloseAuthority )
1426
+ . unwrap( ) ,
1427
+ None
1428
+ ) ;
1429
+ }
1347
1430
}
0 commit comments