@@ -9,8 +9,10 @@ use core::ops::Index;
99
1010#[ cfg( feature = "arbitrary" ) ]
1111use arbitrary:: { Arbitrary , Unstructured } ;
12+ use encoding:: { Encodable , Encoder } ;
1213#[ cfg( feature = "hex" ) ]
1314use hex:: { error:: HexToBytesError , FromHex } ;
15+ use internals:: array_vec:: ArrayVec ;
1416use internals:: compact_size;
1517use internals:: slice:: SliceExt ;
1618use internals:: wrap_debug:: WrapDebug ;
@@ -260,6 +262,56 @@ fn decode_cursor(bytes: &[u8], start_of_indices: usize, index: usize) -> Option<
260262 bytes. get_array :: < 4 > ( start) . map ( |index_bytes| u32:: from_ne_bytes ( * index_bytes) as usize )
261263}
262264
265+ /// The maximum length of a compact size encoding.
266+ const SIZE : usize = compact_size:: MAX_ENCODING_SIZE ;
267+
268+ /// The encoder for the [`Witness`] type.
269+ // This is basically an exact copy of the `encoding::BytesEncoder` except we prefix
270+ // with the number of witness elements not the byte slice length.
271+ pub struct WitnessEncoder < ' a > {
272+ /// A slice of all the elements without the initial length prefix
273+ /// but with the length prefix on each element.
274+ witness_elements : Option < & ' a [ u8 ] > ,
275+ /// Encoding of the number of witness elements.
276+ num_elements : Option < ArrayVec < u8 , SIZE > > ,
277+ }
278+
279+ impl Encodable for Witness {
280+ type Encoder < ' a >
281+ = WitnessEncoder < ' a >
282+ where
283+ Self : ' a ;
284+
285+ fn encoder ( & self ) -> Self :: Encoder < ' _ > {
286+ let num_elements = Some ( compact_size:: encode ( self . len ( ) ) ) ;
287+ let witness_elements = Some ( & self . content [ ..self . indices_start ] ) ;
288+
289+ WitnessEncoder { witness_elements, num_elements }
290+ }
291+ }
292+
293+ impl < ' a > Encoder for WitnessEncoder < ' a > {
294+ #[ inline]
295+ fn current_chunk ( & self ) -> Option < & [ u8 ] > {
296+ if let Some ( num_elements) = self . num_elements . as_ref ( ) {
297+ Some ( num_elements)
298+ } else {
299+ self . witness_elements
300+ }
301+ }
302+
303+ #[ inline]
304+ fn advance ( & mut self ) -> bool {
305+ if self . num_elements . is_some ( ) {
306+ self . num_elements = None ;
307+ true
308+ } else {
309+ self . witness_elements = None ;
310+ false
311+ }
312+ }
313+ }
314+
263315// Note: we use `Borrow` in the following `PartialEq` impls specifically because of its additional
264316// constraints on equality semantics.
265317impl < T : core:: borrow:: Borrow < [ u8 ] > > PartialEq < [ T ] > for Witness {
@@ -913,4 +965,51 @@ mod test {
913965 let witness = Witness :: from_hex ( hex_strings) . unwrap ( ) ;
914966 assert_eq ! ( witness. len( ) , 2 ) ;
915967 }
968+
969+ #[ test]
970+ fn encode ( ) {
971+ let bytes1 = [ 1u8 , 2 , 3 ] ;
972+ let bytes2 = [ 4u8 , 5 ] ;
973+ let bytes3 = [ 6u8 , 7 , 8 , 9 ] ;
974+ let data = [ & bytes1[ ..] , & bytes2[ ..] , & bytes3[ ..] ] ;
975+
976+ // Use FromIterator directly
977+ let witness = Witness :: from_iter ( data) ;
978+
979+ let want = [ 0x03 , 0x03 , 0x01 , 0x02 , 0x03 , 0x02 , 0x04 , 0x05 , 0x04 , 0x06 , 0x07 , 0x08 , 0x09 ] ;
980+ let got = encoding:: encode_to_vec ( & witness) ;
981+
982+ assert_eq ! ( & got, & want) ;
983+ }
984+
985+ #[ test]
986+ fn encodes_using_correct_chunks ( ) {
987+ let bytes1 = [ 1u8 , 2 , 3 ] ;
988+ let bytes2 = [ 4u8 , 5 ] ;
989+ let data = [ & bytes1[ ..] , & bytes2[ ..] ] ;
990+
991+ // Use FromIterator directly
992+ let witness = Witness :: from_iter ( data) ;
993+
994+ // Should have length prefix chunk, then the content slice, then exhausted.
995+ let mut encoder = witness. encoder ( ) ;
996+
997+ assert_eq ! ( encoder. current_chunk( ) , Some ( & [ 2u8 ] [ ..] ) ) ;
998+ assert ! ( encoder. advance( ) ) ;
999+
1000+ // We don't encode one element at a time, rather we encode the whole content slice at once.
1001+ assert_eq ! ( encoder. current_chunk( ) , Some ( & [ 3u8 , 1 , 2 , 3 , 2 , 4 , 5 ] [ ..] ) ) ;
1002+ assert ! ( !encoder. advance( ) ) ;
1003+ assert_eq ! ( encoder. current_chunk( ) , None ) ;
1004+ }
1005+
1006+ #[ test]
1007+ fn encode_empty ( ) {
1008+ let witness = Witness :: default ( ) ;
1009+
1010+ let want = [ 0x00 ] ;
1011+ let got = encoding:: encode_to_vec ( & witness) ;
1012+
1013+ assert_eq ! ( & got, & want) ;
1014+ }
9161015}
0 commit comments