@@ -9,7 +9,7 @@ use std::fmt;
99use bitcoin:: address:: FromScriptError ;
1010use bitcoin:: psbt:: Psbt ;
1111use bitcoin:: transaction:: InputWeightPrediction ;
12- use bitcoin:: { bip32, psbt, Address , AddressType , Network , TxIn , TxOut , Weight } ;
12+ use bitcoin:: { bip32, psbt, Address , AddressType , Network , TapSighashType , TxIn , TxOut , Weight } ;
1313
1414#[ derive( Debug , PartialEq ) ]
1515pub ( crate ) enum InconsistentPsbt {
@@ -207,7 +207,27 @@ impl InternalInputPair<'_> {
207207 }
208208 P2wpkh => Ok ( InputWeightPrediction :: P2WPKH_MAX ) ,
209209 P2wsh => Err ( InputWeightError :: NotSupported ) ,
210- P2tr => Ok ( InputWeightPrediction :: P2TR_KEY_DEFAULT_SIGHASH ) ,
210+ P2tr => {
211+ match self . psbtin . tap_key_sig {
212+ // An input weight can only be predicted for a taproot key spend
213+ None => Err ( InputWeightError :: NotSupported ) ,
214+ Some ( tap_key_sig) => {
215+ // There's a chance that this input contains both a key path and a script path,
216+ // but we can only predict taproot key spends
217+ if !self . psbtin . tap_scripts . is_empty ( )
218+ || !self . psbtin . tap_script_sigs . is_empty ( )
219+ || self . psbtin . tap_merkle_root . is_some ( )
220+ {
221+ return Err ( InputWeightError :: NotSupported ) ;
222+ }
223+ match tap_key_sig. sighash_type {
224+ TapSighashType :: Default =>
225+ Ok ( InputWeightPrediction :: P2TR_KEY_DEFAULT_SIGHASH ) ,
226+ _ => Ok ( InputWeightPrediction :: P2TR_KEY_NON_DEFAULT_SIGHASH ) ,
227+ }
228+ }
229+ }
230+ }
211231 _ => Err ( AddressTypeError :: UnknownAddressType . into ( ) ) ,
212232 } ?;
213233
@@ -376,3 +396,132 @@ impl std::error::Error for InputWeightError {
376396impl From < AddressTypeError > for InputWeightError {
377397 fn from ( value : AddressTypeError ) -> Self { Self :: AddressType ( value) }
378398}
399+
400+ #[ cfg( test) ]
401+ mod tests {
402+ use bitcoin:: key:: Secp256k1 ;
403+ use bitcoin:: taproot:: { ControlBlock , LeafVersion } ;
404+ use bitcoin:: { psbt, secp256k1, taproot, PublicKey , ScriptBuf , TapNodeHash , XOnlyPublicKey } ;
405+
406+ use super :: * ;
407+ use crate :: core:: psbt:: InternalInputPair ;
408+ use crate :: receive:: InputPair ;
409+
410+ /// Lengths of txid, index and sequence: (32, 4, 4)
411+ const TXID_INDEX_SEQUENCE_WEIGHT : Weight = Weight :: from_non_witness_data_size ( 32 + 4 + 4 ) ;
412+
413+ #[ test]
414+ fn expected_weight_for_p2tr ( ) {
415+ let pubkey_string = "0347ff3dacd07a1f43805ec6808e801505a6e18245178609972a68afbc2777ff2b" ;
416+ let pubkey = pubkey_string. parse :: < PublicKey > ( ) . expect ( "valid pubkey" ) ;
417+ let xonly_pubkey = XOnlyPublicKey :: from ( pubkey. inner ) ;
418+ let p2tr_utxo = TxOut {
419+ value : Default :: default ( ) ,
420+ script_pubkey : ScriptBuf :: new_p2tr ( & Secp256k1 :: new ( ) , xonly_pubkey, None ) ,
421+ } ;
422+ let default_sighash_pair = InputPair {
423+ txin : Default :: default ( ) ,
424+ psbtin : psbt:: Input {
425+ tap_key_sig : Some (
426+ taproot:: Signature :: from_slice (
427+ & [ 0 ; secp256k1:: constants:: SCHNORR_SIGNATURE_SIZE ] ,
428+ )
429+ . unwrap ( ) ,
430+ ) ,
431+ witness_utxo : Some ( p2tr_utxo. clone ( ) ) ,
432+ ..Default :: default ( )
433+ } ,
434+ } ;
435+ assert_eq ! (
436+ InternalInputPair :: from( & default_sighash_pair) . expected_input_weight( ) . unwrap( ) ,
437+ InputWeightPrediction :: P2TR_KEY_DEFAULT_SIGHASH . weight( ) + TXID_INDEX_SEQUENCE_WEIGHT
438+ ) ;
439+
440+ // Add a sighash byte
441+ let mut sig_bytes = [ 0 ; secp256k1:: constants:: SCHNORR_SIGNATURE_SIZE + 1 ] ;
442+ sig_bytes[ sig_bytes. len ( ) - 1 ] = 1 ;
443+ let non_default_sighash_pair = InputPair {
444+ txin : Default :: default ( ) ,
445+ psbtin : psbt:: Input {
446+ tap_key_sig : Some ( taproot:: Signature :: from_slice ( & sig_bytes) . unwrap ( ) ) ,
447+ witness_utxo : Some ( p2tr_utxo) ,
448+ ..Default :: default ( )
449+ } ,
450+ } ;
451+ assert_eq ! (
452+ InternalInputPair :: from( & non_default_sighash_pair) . expected_input_weight( ) . unwrap( ) ,
453+ InputWeightPrediction :: P2TR_KEY_NON_DEFAULT_SIGHASH . weight( )
454+ + TXID_INDEX_SEQUENCE_WEIGHT
455+ ) ;
456+ }
457+
458+ #[ test]
459+ fn not_supported_p2tr_expected_weights ( ) {
460+ let pubkey_string = "0347ff3dacd07a1f43805ec6808e801505a6e18245178609972a68afbc2777ff2b" ;
461+ let pubkey = pubkey_string. parse :: < PublicKey > ( ) . expect ( "valid pubkey" ) ;
462+ let xonly_pubkey = XOnlyPublicKey :: from ( pubkey. inner ) ;
463+ let p2tr_script = ScriptBuf :: new_p2tr ( & Secp256k1 :: new ( ) , xonly_pubkey. clone ( ) , None ) ;
464+ let p2tr_utxo = TxOut { value : Default :: default ( ) , script_pubkey : p2tr_script. clone ( ) } ;
465+
466+ let mut tap_scripts = BTreeMap :: new ( ) ;
467+ let leaf_version: u8 = 0xC0 ;
468+ let mut control_block_vec = Vec :: with_capacity ( 33 ) ;
469+ control_block_vec. push ( leaf_version) ;
470+ control_block_vec. extend_from_slice ( & xonly_pubkey. serialize ( ) ) ;
471+ let control_block = ControlBlock :: decode ( control_block_vec. as_slice ( ) ) . unwrap ( ) ;
472+ tap_scripts
473+ . insert ( control_block. clone ( ) , ( p2tr_script. clone ( ) , control_block. leaf_version ) ) ;
474+
475+ let pair_with_tapscripts = InputPair {
476+ txin : Default :: default ( ) ,
477+ psbtin : psbt:: Input {
478+ tap_scripts,
479+ witness_utxo : Some ( p2tr_utxo. clone ( ) ) ,
480+ ..Default :: default ( )
481+ } ,
482+ } ;
483+ assert_eq ! (
484+ InternalInputPair :: from( & pair_with_tapscripts) . expected_input_weight( ) . err( ) . unwrap( ) ,
485+ InputWeightError :: NotSupported
486+ ) ;
487+
488+ let mut tap_script_sigs = BTreeMap :: new ( ) ;
489+ tap_script_sigs. insert (
490+ ( xonly_pubkey. clone ( ) , p2tr_script. tapscript_leaf_hash ( ) ) ,
491+ taproot:: Signature :: from_slice ( & [ 0 ; secp256k1:: constants:: SCHNORR_SIGNATURE_SIZE ] )
492+ . unwrap ( ) ,
493+ ) ;
494+ let pair_with_tap_script_sigs = InputPair {
495+ txin : Default :: default ( ) ,
496+ psbtin : psbt:: Input {
497+ tap_script_sigs,
498+ witness_utxo : Some ( p2tr_utxo. clone ( ) ) ,
499+ ..Default :: default ( )
500+ } ,
501+ } ;
502+ assert_eq ! (
503+ InternalInputPair :: from( & pair_with_tap_script_sigs)
504+ . expected_input_weight( )
505+ . err( )
506+ . unwrap( ) ,
507+ InputWeightError :: NotSupported
508+ ) ;
509+
510+ let tap_merkle_root = TapNodeHash :: from_script ( & p2tr_script, LeafVersion :: TapScript ) ;
511+ let pair_with_tap_merkle_root = InputPair {
512+ txin : Default :: default ( ) ,
513+ psbtin : psbt:: Input {
514+ tap_merkle_root : Some ( tap_merkle_root) ,
515+ witness_utxo : Some ( p2tr_utxo. clone ( ) ) ,
516+ ..Default :: default ( )
517+ } ,
518+ } ;
519+ assert_eq ! (
520+ InternalInputPair :: from( & pair_with_tap_merkle_root)
521+ . expected_input_weight( )
522+ . err( )
523+ . unwrap( ) ,
524+ InputWeightError :: NotSupported
525+ ) ;
526+ }
527+ }
0 commit comments