@@ -413,46 +413,23 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
413
413
}
414
414
}
415
415
416
- fn init_or_get_extension < V : Extension > ( & mut self , init : bool ) -> Result < & mut V , ProgramError > {
416
+ /// Unpack a portion of the TLV data as the desired type that allows modifying the type
417
+ pub fn get_extension_mut < V : Extension > ( & mut self ) -> Result < & mut V , ProgramError > {
417
418
if V :: TYPE . get_account_type ( ) != S :: ACCOUNT_TYPE {
418
419
return Err ( ProgramError :: InvalidAccountData ) ;
419
420
}
420
421
let TlvIndices {
421
422
type_start,
422
423
length_start,
423
424
value_start,
424
- } = get_extension_indices :: < V > ( self . tlv_data , init ) ?;
425
+ } = get_extension_indices :: < V > ( self . tlv_data , false ) ?;
425
426
426
427
if self . tlv_data [ type_start..] . len ( ) < V :: TYPE . get_tlv_len ( ) {
427
428
return Err ( ProgramError :: InvalidAccountData ) ;
428
429
}
429
- if init {
430
- // write extension type
431
- let extension_type_array: [ u8 ; 2 ] = V :: TYPE . into ( ) ;
432
- let extension_type_ref = & mut self . tlv_data [ type_start..length_start] ;
433
- extension_type_ref. copy_from_slice ( & extension_type_array) ;
434
- // write length
435
- let length_ref =
436
- pod_from_bytes_mut :: < Length > ( & mut self . tlv_data [ length_start..value_start] ) ?;
437
- // maybe this becomes smarter later for dynamically sized extensions
438
- let length = pod_get_packed_len :: < V > ( ) ;
439
- * length_ref = Length :: try_from ( length) . unwrap ( ) ;
440
-
441
- let value_end = value_start. saturating_add ( length) ;
442
- let extension_ref =
443
- pod_from_bytes_mut :: < V > ( & mut self . tlv_data [ value_start..value_end] ) ?;
444
- * extension_ref = V :: default ( ) ;
445
- Ok ( extension_ref)
446
- } else {
447
- let length = pod_from_bytes :: < Length > ( & self . tlv_data [ length_start..value_start] ) ?;
448
- let value_end = value_start. saturating_add ( usize:: from ( * length) ) ;
449
- pod_from_bytes_mut :: < V > ( & mut self . tlv_data [ value_start..value_end] )
450
- }
451
- }
452
-
453
- /// Unpack a portion of the TLV data as the desired type that allows modifying the type
454
- pub fn get_extension_mut < V : Extension > ( & mut self ) -> Result < & mut V , ProgramError > {
455
- self . init_or_get_extension ( false )
430
+ let length = pod_from_bytes :: < Length > ( & self . tlv_data [ length_start..value_start] ) ?;
431
+ let value_end = value_start. saturating_add ( usize:: from ( * length) ) ;
432
+ pod_from_bytes_mut :: < V > ( & mut self . tlv_data [ value_start..value_end] )
456
433
}
457
434
458
435
/// Unpack a portion of the TLV data as the desired type
@@ -480,9 +457,48 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
480
457
}
481
458
482
459
/// Packs the default extension data into an open slot if not already found in the
483
- /// data buffer, otherwise overwrites the existing extension with the default state
484
- pub fn init_extension < V : Extension > ( & mut self ) -> Result < & mut V , ProgramError > {
485
- self . init_or_get_extension ( true )
460
+ /// data buffer. If extension is already found in the buffer, it overwrites the existing
461
+ /// extension with the default state if `overwrite` is set. If extension found, but
462
+ /// `overwrite` is not set, it returns error.
463
+ pub fn init_extension < V : Extension > (
464
+ & mut self ,
465
+ overwrite : bool ,
466
+ ) -> Result < & mut V , ProgramError > {
467
+ if V :: TYPE . get_account_type ( ) != S :: ACCOUNT_TYPE {
468
+ return Err ( ProgramError :: InvalidAccountData ) ;
469
+ }
470
+ let TlvIndices {
471
+ type_start,
472
+ length_start,
473
+ value_start,
474
+ } = get_extension_indices :: < V > ( self . tlv_data , true ) ?;
475
+
476
+ if self . tlv_data [ type_start..] . len ( ) < V :: TYPE . get_tlv_len ( ) {
477
+ return Err ( ProgramError :: InvalidAccountData ) ;
478
+ }
479
+ let extension_type = ExtensionType :: try_from ( & self . tlv_data [ type_start..length_start] ) ?;
480
+
481
+ if extension_type == ExtensionType :: Uninitialized || overwrite {
482
+ // write extension type
483
+ let extension_type_array: [ u8 ; 2 ] = V :: TYPE . into ( ) ;
484
+ let extension_type_ref = & mut self . tlv_data [ type_start..length_start] ;
485
+ extension_type_ref. copy_from_slice ( & extension_type_array) ;
486
+ // write length
487
+ let length_ref =
488
+ pod_from_bytes_mut :: < Length > ( & mut self . tlv_data [ length_start..value_start] ) ?;
489
+ // maybe this becomes smarter later for dynamically sized extensions
490
+ let length = pod_get_packed_len :: < V > ( ) ;
491
+ * length_ref = Length :: try_from ( length) . unwrap ( ) ;
492
+
493
+ let value_end = value_start. saturating_add ( length) ;
494
+ let extension_ref =
495
+ pod_from_bytes_mut :: < V > ( & mut self . tlv_data [ value_start..value_end] ) ?;
496
+ * extension_ref = V :: default ( ) ;
497
+ Ok ( extension_ref)
498
+ } else {
499
+ // extension is already initialized, but no overwrite permission
500
+ Err ( TokenError :: ExtensionAlreadyInitialized . into ( ) )
501
+ }
486
502
}
487
503
488
504
/// If `extension_type` is an Account-associated ExtensionType that requires initialization on
@@ -498,14 +514,14 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
498
514
}
499
515
match extension_type {
500
516
ExtensionType :: TransferFeeAmount => {
501
- self . init_extension :: < TransferFeeAmount > ( ) . map ( |_| ( ) )
517
+ self . init_extension :: < TransferFeeAmount > ( true ) . map ( |_| ( ) )
502
518
}
503
519
// ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
504
520
// on InitializeAccount
505
521
ExtensionType :: ConfidentialTransferAccount => Ok ( ( ) ) ,
506
522
#[ cfg( test) ]
507
523
ExtensionType :: AccountPaddingTest => {
508
- self . init_extension :: < AccountPaddingTest > ( ) . map ( |_| ( ) )
524
+ self . init_extension :: < AccountPaddingTest > ( true ) . map ( |_| ( ) )
509
525
}
510
526
_ => unreachable ! ( ) ,
511
527
}
@@ -932,19 +948,27 @@ mod test {
932
948
let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
933
949
// fail init account extension
934
950
assert_eq ! (
935
- state. init_extension:: <TransferFeeAmount >( ) ,
951
+ state. init_extension:: <TransferFeeAmount >( true ) ,
936
952
Err ( ProgramError :: InvalidAccountData ) ,
937
953
) ;
938
954
939
955
// success write extension
940
956
let close_authority = OptionalNonZeroPubkey :: try_from ( Some ( Pubkey :: new ( & [ 1 ; 32 ] ) ) ) . unwrap ( ) ;
941
- let extension = state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
957
+ let extension = state. init_extension :: < MintCloseAuthority > ( true ) . unwrap ( ) ;
942
958
extension. close_authority = close_authority;
943
959
assert_eq ! (
944
960
& state. get_extension_types( ) . unwrap( ) ,
945
961
& [ ExtensionType :: MintCloseAuthority ]
946
962
) ;
947
963
964
+ // fail init extension when already initialized
965
+ assert_eq ! (
966
+ state. init_extension:: <MintCloseAuthority >( false ) ,
967
+ Err ( ProgramError :: Custom (
968
+ TokenError :: ExtensionAlreadyInitialized as u32
969
+ ) )
970
+ ) ;
971
+
948
972
// fail unpack as account, a mint extension was written
949
973
assert_eq ! (
950
974
StateWithExtensionsMut :: <Account >:: unpack_uninitialized( & mut buffer) ,
@@ -1030,7 +1054,7 @@ mod test {
1030
1054
let mut state = StateWithExtensionsMut :: < Mint > :: unpack ( & mut buffer) . unwrap ( ) ;
1031
1055
// init one more extension
1032
1056
let mint_transfer_fee = test_transfer_fee_config ( ) ;
1033
- let new_extension = state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1057
+ let new_extension = state. init_extension :: < TransferFeeConfig > ( true ) . unwrap ( ) ;
1034
1058
new_extension. transfer_fee_config_authority =
1035
1059
mint_transfer_fee. transfer_fee_config_authority ;
1036
1060
new_extension. withdraw_withheld_authority = mint_transfer_fee. withdraw_withheld_authority ;
@@ -1063,7 +1087,7 @@ mod test {
1063
1087
// fail to init one more extension that does not fit
1064
1088
let mut state = StateWithExtensionsMut :: < Mint > :: unpack ( & mut buffer) . unwrap ( ) ;
1065
1089
assert_eq ! (
1066
- state. init_extension:: <MintPaddingTest >( ) ,
1090
+ state. init_extension:: <MintPaddingTest >( true ) ,
1067
1091
Err ( ProgramError :: InvalidAccountData ) ,
1068
1092
) ;
1069
1093
}
@@ -1079,11 +1103,11 @@ mod test {
1079
1103
let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1080
1104
// write extensions
1081
1105
let close_authority = OptionalNonZeroPubkey :: try_from ( Some ( Pubkey :: new ( & [ 1 ; 32 ] ) ) ) . unwrap ( ) ;
1082
- let extension = state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1106
+ let extension = state. init_extension :: < MintCloseAuthority > ( true ) . unwrap ( ) ;
1083
1107
extension. close_authority = close_authority;
1084
1108
1085
1109
let mint_transfer_fee = test_transfer_fee_config ( ) ;
1086
- let extension = state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1110
+ let extension = state. init_extension :: < TransferFeeConfig > ( true ) . unwrap ( ) ;
1087
1111
extension. transfer_fee_config_authority = mint_transfer_fee. transfer_fee_config_authority ;
1088
1112
extension. withdraw_withheld_authority = mint_transfer_fee. withdraw_withheld_authority ;
1089
1113
extension. withheld_amount = mint_transfer_fee. withheld_amount ;
@@ -1115,15 +1139,15 @@ mod test {
1115
1139
1116
1140
// write extensions in a different order
1117
1141
let mint_transfer_fee = test_transfer_fee_config ( ) ;
1118
- let extension = state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1142
+ let extension = state. init_extension :: < TransferFeeConfig > ( true ) . unwrap ( ) ;
1119
1143
extension. transfer_fee_config_authority = mint_transfer_fee. transfer_fee_config_authority ;
1120
1144
extension. withdraw_withheld_authority = mint_transfer_fee. withdraw_withheld_authority ;
1121
1145
extension. withheld_amount = mint_transfer_fee. withheld_amount ;
1122
1146
extension. older_transfer_fee = mint_transfer_fee. older_transfer_fee ;
1123
1147
extension. newer_transfer_fee = mint_transfer_fee. newer_transfer_fee ;
1124
1148
1125
1149
let close_authority = OptionalNonZeroPubkey :: try_from ( Some ( Pubkey :: new ( & [ 1 ; 32 ] ) ) ) . unwrap ( ) ;
1126
- let extension = state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1150
+ let extension = state. init_extension :: < MintCloseAuthority > ( true ) . unwrap ( ) ;
1127
1151
extension. close_authority = close_authority;
1128
1152
1129
1153
assert_eq ! (
@@ -1169,7 +1193,7 @@ mod test {
1169
1193
state. init_account_type ( ) . unwrap ( ) ;
1170
1194
1171
1195
// write padding
1172
- let extension = state. init_extension :: < MintPaddingTest > ( ) . unwrap ( ) ;
1196
+ let extension = state. init_extension :: < MintPaddingTest > ( true ) . unwrap ( ) ;
1173
1197
extension. padding1 = [ 1 ; 128 ] ;
1174
1198
extension. padding2 = [ 1 ; 48 ] ;
1175
1199
extension. padding3 = [ 1 ; 9 ] ;
@@ -1206,12 +1230,12 @@ mod test {
1206
1230
StateWithExtensionsMut :: < Account > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1207
1231
// fail init mint extension
1208
1232
assert_eq ! (
1209
- state. init_extension:: <TransferFeeConfig >( ) ,
1233
+ state. init_extension:: <TransferFeeConfig >( true ) ,
1210
1234
Err ( ProgramError :: InvalidAccountData ) ,
1211
1235
) ;
1212
1236
// success write extension
1213
1237
let withheld_amount = PodU64 :: from ( u64:: MAX ) ;
1214
- let extension = state. init_extension :: < TransferFeeAmount > ( ) . unwrap ( ) ;
1238
+ let extension = state. init_extension :: < TransferFeeAmount > ( true ) . unwrap ( ) ;
1215
1239
extension. withheld_amount = withheld_amount;
1216
1240
1217
1241
assert_eq ! (
@@ -1305,7 +1329,7 @@ mod test {
1305
1329
state. init_account_type ( ) . unwrap ( ) ;
1306
1330
1307
1331
// write padding
1308
- let extension = state. init_extension :: < AccountPaddingTest > ( ) . unwrap ( ) ;
1332
+ let extension = state. init_extension :: < AccountPaddingTest > ( true ) . unwrap ( ) ;
1309
1333
extension. 0 . padding1 = [ 2 ; 128 ] ;
1310
1334
extension. 0 . padding2 = [ 2 ; 48 ] ;
1311
1335
extension. 0 . padding3 = [ 2 ; 9 ] ;
@@ -1341,7 +1365,7 @@ mod test {
1341
1365
let mut state = StateWithExtensionsMut :: < Account > :: unpack ( & mut buffer) . unwrap ( ) ;
1342
1366
assert_eq ! ( state. base, TEST_ACCOUNT ) ;
1343
1367
assert_eq ! ( state. account_type[ 0 ] , AccountType :: Account as u8 ) ;
1344
- state. init_extension :: < ImmutableOwner > ( ) . unwrap ( ) ; // just confirming initialization works
1368
+ state. init_extension :: < ImmutableOwner > ( true ) . unwrap ( ) ; // just confirming initialization works
1345
1369
1346
1370
// account with buffer big enough for AccountType only
1347
1371
let mut buffer = TEST_ACCOUNT_SLICE . to_vec ( ) ;
@@ -1384,7 +1408,7 @@ mod test {
1384
1408
let mut state = StateWithExtensionsMut :: < Mint > :: unpack ( & mut buffer) . unwrap ( ) ;
1385
1409
assert_eq ! ( state. base, TEST_MINT ) ;
1386
1410
assert_eq ! ( state. account_type[ 0 ] , AccountType :: Mint as u8 ) ;
1387
- state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1411
+ state. init_extension :: < MintCloseAuthority > ( true ) . unwrap ( ) ;
1388
1412
1389
1413
// mint with buffer big enough for AccountType only
1390
1414
let mut buffer = TEST_MINT_SLICE . to_vec ( ) ;
@@ -1499,7 +1523,7 @@ mod test {
1499
1523
1500
1524
// fail init extension
1501
1525
assert_eq ! (
1502
- state. init_extension:: <TransferFeeConfig >( ) ,
1526
+ state. init_extension:: <TransferFeeConfig >( true ) ,
1503
1527
Err ( ProgramError :: InvalidAccountData ) ,
1504
1528
) ;
1505
1529
@@ -1514,7 +1538,7 @@ mod test {
1514
1538
state. base = TEST_MINT ;
1515
1539
state. pack_base ( ) ;
1516
1540
state. init_account_type ( ) . unwrap ( ) ;
1517
- let extension = state. init_extension :: < MintPaddingTest > ( ) . unwrap ( ) ;
1541
+ let extension = state. init_extension :: < MintPaddingTest > ( true ) . unwrap ( ) ;
1518
1542
assert_eq ! ( extension. padding1, [ 1 ; 128 ] ) ;
1519
1543
assert_eq ! ( extension. padding2, [ 2 ; 48 ] ) ;
1520
1544
assert_eq ! ( extension. padding3, [ 3 ; 9 ] ) ;
@@ -1526,7 +1550,9 @@ mod test {
1526
1550
ExtensionType :: get_account_len :: < Mint > ( & [ ExtensionType :: MintCloseAuthority ] ) ;
1527
1551
let mut buffer = vec ! [ 0 ; mint_size - 1 ] ;
1528
1552
let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1529
- let err = state. init_extension :: < MintCloseAuthority > ( ) . unwrap_err ( ) ;
1553
+ let err = state
1554
+ . init_extension :: < MintCloseAuthority > ( true )
1555
+ . unwrap_err ( ) ;
1530
1556
assert_eq ! ( err, ProgramError :: InvalidAccountData ) ;
1531
1557
1532
1558
state. tlv_data [ 0 ] = 3 ;
@@ -1556,7 +1582,7 @@ mod test {
1556
1582
state. base = TEST_ACCOUNT ;
1557
1583
state. pack_base ( ) ;
1558
1584
state. init_account_type ( ) . unwrap ( ) ;
1559
- state. init_extension :: < ImmutableOwner > ( ) . unwrap ( ) ;
1585
+ state. init_extension :: < ImmutableOwner > ( true ) . unwrap ( ) ;
1560
1586
1561
1587
assert_eq ! (
1562
1588
get_first_extension_type( state. tlv_data) . unwrap( ) ,
0 commit comments