@@ -4,8 +4,8 @@ use rand::{rng, Rng};
44
55use crate :: {
66 attestation:: {
7- Attestation , AttestationConfig , Body , EncodingCommitment , FieldId , FieldKind , Header ,
8- ServerCertCommitment , VERSION ,
7+ Attestation , AttestationConfig , Body , EncodingCommitment , Extension , FieldId , FieldKind ,
8+ Header , ServerCertCommitment , VERSION ,
99 } ,
1010 connection:: { ConnectionInfo , ServerEphemKey } ,
1111 hash:: { HashAlgId , TypedHash } ,
@@ -17,8 +17,10 @@ use crate::{
1717} ;
1818
1919/// Attestation builder state for accepting a request.
20+ #[ derive( Debug ) ]
2021pub struct Accept { }
2122
23+ #[ derive( Debug ) ]
2224pub struct Sign {
2325 signature_alg : SignatureAlgId ,
2426 hash_alg : HashAlgId ,
@@ -27,9 +29,11 @@ pub struct Sign {
2729 cert_commitment : ServerCertCommitment ,
2830 encoding_commitment_root : Option < TypedHash > ,
2931 encoder_secret : Option < EncoderSecret > ,
32+ extensions : Vec < Extension > ,
3033}
3134
3235/// An attestation builder.
36+ #[ derive( Debug ) ]
3337pub struct AttestationBuilder < ' a , T = Accept > {
3438 config : & ' a AttestationConfig ,
3539 state : T ,
@@ -56,6 +60,7 @@ impl<'a> AttestationBuilder<'a, Accept> {
5660 hash_alg,
5761 server_cert_commitment : cert_commitment,
5862 encoding_commitment_root,
63+ extensions,
5964 } = request;
6065
6166 if !config. supported_signature_algs ( ) . contains ( & signature_alg) {
@@ -83,6 +88,11 @@ impl<'a> AttestationBuilder<'a, Accept> {
8388 ) ) ;
8489 }
8590
91+ if let Some ( validator) = config. extension_validator ( ) {
92+ validator ( & extensions)
93+ . map_err ( |err| AttestationBuilderError :: new ( ErrorKind :: Extension , err) ) ?;
94+ }
95+
8696 Ok ( AttestationBuilder {
8797 config : self . config ,
8898 state : Sign {
@@ -93,6 +103,7 @@ impl<'a> AttestationBuilder<'a, Accept> {
93103 cert_commitment,
94104 encoding_commitment_root,
95105 encoder_secret : None ,
106+ extensions,
96107 } ,
97108 } )
98109 }
@@ -117,6 +128,12 @@ impl AttestationBuilder<'_, Sign> {
117128 self
118129 }
119130
131+ /// Adds an extension to the attestation.
132+ pub fn extension ( & mut self , extension : Extension ) -> & mut Self {
133+ self . state . extensions . push ( extension) ;
134+ self
135+ }
136+
120137 /// Builds the attestation.
121138 pub fn build ( self , provider : & CryptoProvider ) -> Result < Attestation , AttestationBuilderError > {
122139 let Sign {
@@ -127,6 +144,7 @@ impl AttestationBuilder<'_, Sign> {
127144 cert_commitment,
128145 encoding_commitment_root,
129146 encoder_secret,
147+ extensions,
130148 } = self . state ;
131149
132150 let hasher = provider. hash . get ( & hash_alg) . map_err ( |_| {
@@ -170,6 +188,10 @@ impl AttestationBuilder<'_, Sign> {
170188 cert_commitment : field_id. next ( cert_commitment) ,
171189 encoding_commitment : encoding_commitment. map ( |commitment| field_id. next ( commitment) ) ,
172190 plaintext_hashes : Default :: default ( ) ,
191+ extensions : extensions
192+ . into_iter ( )
193+ . map ( |extension| field_id. next ( extension) )
194+ . collect ( ) ,
173195 } ;
174196
175197 let header = Header {
@@ -203,6 +225,7 @@ enum ErrorKind {
203225 Config ,
204226 Field ,
205227 Signature ,
228+ Extension ,
206229}
207230
208231impl AttestationBuilderError {
@@ -229,6 +252,7 @@ impl std::fmt::Display for AttestationBuilderError {
229252 ErrorKind :: Config => f. write_str ( "config error" ) ?,
230253 ErrorKind :: Field => f. write_str ( "field error" ) ?,
231254 ErrorKind :: Signature => f. write_str ( "signature error" ) ?,
255+ ErrorKind :: Extension => f. write_str ( "extension error" ) ?,
232256 }
233257
234258 if let Some ( source) = & self . source {
@@ -282,6 +306,7 @@ mod test {
282306 encoding_provider ( GET_WITH_HEADER , OK_JSON ) ,
283307 connection,
284308 Blake3 :: default ( ) ,
309+ Vec :: new ( ) ,
285310 ) ;
286311
287312 let attestation_config = AttestationConfig :: builder ( )
@@ -306,6 +331,7 @@ mod test {
306331 encoding_provider ( GET_WITH_HEADER , OK_JSON ) ,
307332 connection,
308333 Blake3 :: default ( ) ,
334+ Vec :: new ( ) ,
309335 ) ;
310336
311337 let attestation_config = AttestationConfig :: builder ( )
@@ -331,6 +357,7 @@ mod test {
331357 encoding_provider ( GET_WITH_HEADER , OK_JSON ) ,
332358 connection,
333359 Blake3 :: default ( ) ,
360+ Vec :: new ( ) ,
334361 ) ;
335362
336363 let attestation_config = AttestationConfig :: builder ( )
@@ -360,6 +387,7 @@ mod test {
360387 encoding_provider ( GET_WITH_HEADER , OK_JSON ) ,
361388 connection,
362389 Blake3 :: default ( ) ,
390+ Vec :: new ( ) ,
363391 ) ;
364392
365393 let attestation_builder = Attestation :: builder ( attestation_config)
@@ -386,6 +414,7 @@ mod test {
386414 encoding_provider ( GET_WITH_HEADER , OK_JSON ) ,
387415 connection. clone ( ) ,
388416 Blake3 :: default ( ) ,
417+ Vec :: new ( ) ,
389418 ) ;
390419
391420 let mut attestation_builder = Attestation :: builder ( attestation_config)
@@ -424,6 +453,7 @@ mod test {
424453 encoding_provider ( GET_WITH_HEADER , OK_JSON ) ,
425454 connection. clone ( ) ,
426455 Blake3 :: default ( ) ,
456+ Vec :: new ( ) ,
427457 ) ;
428458
429459 let mut attestation_builder = Attestation :: builder ( attestation_config)
@@ -455,6 +485,7 @@ mod test {
455485 encoding_provider ( GET_WITH_HEADER , OK_JSON ) ,
456486 connection. clone ( ) ,
457487 Blake3 :: default ( ) ,
488+ Vec :: new ( ) ,
458489 ) ;
459490
460491 let mut attestation_builder = Attestation :: builder ( attestation_config)
@@ -477,4 +508,76 @@ mod test {
477508 let err = attestation_builder. build ( crypto_provider) . err ( ) . unwrap ( ) ;
478509 assert ! ( matches!( err. kind, ErrorKind :: Field ) ) ;
479510 }
511+
512+ #[ rstest]
513+ fn test_attestation_builder_reject_extensions_by_default (
514+ attestation_config : & AttestationConfig ,
515+ ) {
516+ let transcript = Transcript :: new ( GET_WITH_HEADER , OK_JSON ) ;
517+ let connection = ConnectionFixture :: tlsnotary ( transcript. length ( ) ) ;
518+
519+ let RequestFixture { request, .. } = request_fixture (
520+ transcript,
521+ encoding_provider ( GET_WITH_HEADER , OK_JSON ) ,
522+ connection. clone ( ) ,
523+ Blake3 :: default ( ) ,
524+ vec ! [ Extension {
525+ id: b"foo" . to_vec( ) ,
526+ value: b"bar" . to_vec( ) ,
527+ } ] ,
528+ ) ;
529+
530+ let err = Attestation :: builder ( attestation_config)
531+ . accept_request ( request)
532+ . unwrap_err ( ) ;
533+
534+ assert ! ( matches!( err. kind, ErrorKind :: Extension ) ) ;
535+ }
536+
537+ #[ rstest]
538+ fn test_attestation_builder_accept_extension ( crypto_provider : & CryptoProvider ) {
539+ let attestation_config = AttestationConfig :: builder ( )
540+ . supported_signature_algs ( [ SignatureAlgId :: SECP256K1 ] )
541+ . extension_validator ( |_| Ok ( ( ) ) )
542+ . build ( )
543+ . unwrap ( ) ;
544+
545+ let transcript = Transcript :: new ( GET_WITH_HEADER , OK_JSON ) ;
546+ let connection = ConnectionFixture :: tlsnotary ( transcript. length ( ) ) ;
547+
548+ let RequestFixture { request, .. } = request_fixture (
549+ transcript,
550+ encoding_provider ( GET_WITH_HEADER , OK_JSON ) ,
551+ connection. clone ( ) ,
552+ Blake3 :: default ( ) ,
553+ vec ! [ Extension {
554+ id: b"foo" . to_vec( ) ,
555+ value: b"bar" . to_vec( ) ,
556+ } ] ,
557+ ) ;
558+
559+ let mut attestation_builder = Attestation :: builder ( & attestation_config)
560+ . accept_request ( request)
561+ . unwrap ( ) ;
562+
563+ let ConnectionFixture {
564+ server_cert_data,
565+ connection_info,
566+ ..
567+ } = connection;
568+
569+ let HandshakeData :: V1_2 ( HandshakeDataV1_2 {
570+ server_ephemeral_key,
571+ ..
572+ } ) = server_cert_data. handshake ;
573+
574+ attestation_builder
575+ . connection_info ( connection_info)
576+ . server_ephemeral_key ( server_ephemeral_key)
577+ . encoder_secret ( encoder_secret ( ) ) ;
578+
579+ let attestation = attestation_builder. build ( crypto_provider) . unwrap ( ) ;
580+
581+ assert_eq ! ( attestation. body. extensions( ) . count( ) , 1 ) ;
582+ }
480583}
0 commit comments