11use core:: iter;
22use std:: io;
33
4- use aws_lc_rs:: rand;
4+ use aws_lc_rs:: { cipher :: StreamingDecryptingKey , constant_time , hmac , rand} ;
55use tokio:: io:: AsyncReadExt ;
66use tracing:: debug;
77
@@ -75,25 +75,36 @@ impl From<u8> for MessageType {
7575/// v
7676/// |read|unread and not yet decrypted|
7777/// ```
78- // FIXME implement actual decryption
7978pub ( crate ) struct DecryptingReader < R : AsyncReadExt + Unpin > {
8079 stream : R ,
8180 buf : Vec < u8 > ,
81+ decrypted_buf : Vec < u8 > ,
8282 unread_start : usize ,
83+
84+ packet_number : u32 ,
85+ decryption_key : Option < ( StreamingDecryptingKey , hmac:: Key ) > ,
8386}
8487
8588impl < R : AsyncReadExt + Unpin > DecryptingReader < R > {
8689 pub ( crate ) fn new ( stream : R ) -> Self {
8790 Self {
8891 stream,
8992 buf : Vec :: with_capacity ( 16_384 ) ,
93+ decrypted_buf : Vec :: with_capacity ( 16_384 ) ,
9094 unread_start : 0 ,
95+ packet_number : 0 ,
96+ decryption_key : None ,
9197 }
9298 }
9399
94- async fn ensure_at_least ( & mut self , n : u32 ) -> Result < ( ) , Error > {
95- while self . buf . len ( ) - self . unread_start < n as usize {
96- let read = self . stream . read_buf ( & mut self . buf ) . await ?;
100+ async fn ensure_at_least (
101+ stream : & mut R ,
102+ buf : & mut Vec < u8 > ,
103+ unread_start : & mut usize ,
104+ n : u32 ,
105+ ) -> Result < ( ) , Error > {
106+ while buf. len ( ) - * unread_start < n as usize {
107+ let read = stream. read_buf ( buf) . await ?;
97108 debug ! ( bytes = read, "read from stream" ) ;
98109 if read == 0 {
99110 return Err ( Error :: Io ( io:: Error :: new (
@@ -109,49 +120,149 @@ impl<R: AsyncReadExt + Unpin> DecryptingReader<R> {
109120 ///
110121 /// This should only be used for reading the identification string.
111122 pub ( crate ) async fn read_u8_cleartext ( & mut self ) -> Result < u8 , Error > {
112- self . ensure_at_least ( 1 ) . await ?;
123+ assert ! ( self . decryption_key. is_none( ) ) ;
124+
125+ Self :: ensure_at_least ( & mut self . stream , & mut self . buf , & mut self . unread_start , 1 ) . await ?;
113126
114127 let byte = self . buf [ self . unread_start ] ;
115128 self . unread_start += 1 ;
116129 Ok ( byte)
117130 }
118131
132+ pub ( crate ) fn set_decryption_key (
133+ & mut self ,
134+ decryption_key : StreamingDecryptingKey ,
135+ integrity_key : hmac:: Key ,
136+ ) {
137+ self . decrypted_buf . clear ( ) ;
138+ self . decryption_key = Some ( ( decryption_key, integrity_key) ) ;
139+ }
140+
119141 pub ( crate ) async fn read_packet < ' a > ( & ' a mut self ) -> Result < Packet < ' a > , Error > {
120142 // Compact the internal buffer
121143 if self . unread_start > 0 {
122144 self . buf . copy_within ( self . unread_start .., 0 ) ;
123145 }
124146 self . buf . truncate ( self . buf . len ( ) - self . unread_start ) ;
147+ self . decrypted_buf . clear ( ) ;
125148 self . unread_start = 0 ;
126149
127- self . ensure_at_least ( 4 ) . await ?;
128- let Decoded {
129- value : packet_length,
130- next,
131- } = PacketLength :: decode ( & self . buf [ self . unread_start ..self . unread_start + 4 ] ) ?;
132- assert ! ( next. is_empty( ) ) ;
133-
134- self . ensure_at_least ( 4 + packet_length. inner ) . await ?;
135- let Decoded {
136- value : packet,
137- next,
138- } = Packet :: decode (
139- & self . buf [ self . unread_start ..self . unread_start + 4 + packet_length. inner as usize ] ,
140- ) ?;
141- assert ! ( next. is_empty( ) ) ;
142-
143- self . unread_start += 4 + packet_length. inner as usize ;
144-
145- Ok ( packet)
150+ let packet_number = self . packet_number ;
151+ self . packet_number = self . packet_number . wrapping_add ( 1 ) ;
152+
153+ if let Some ( ( decrypting_key, integrity_key) ) = & mut self . decryption_key {
154+ let block_len = decrypting_key. algorithm ( ) . block_len ( ) ;
155+
156+ Self :: ensure_at_least (
157+ & mut self . stream ,
158+ & mut self . buf ,
159+ & mut self . unread_start ,
160+ block_len as u32 ,
161+ )
162+ . await ?;
163+ self . decrypted_buf . resize ( self . buf . len ( ) + block_len, 0 ) ;
164+
165+ let update = decrypting_key
166+ . update (
167+ & self . buf [ self . unread_start ..self . unread_start + block_len] ,
168+ & mut self . decrypted_buf [ self . unread_start ..self . unread_start + 2 * block_len] ,
169+ )
170+ . unwrap ( ) ;
171+ assert_eq ! ( update. remainder( ) . len( ) , block_len) ;
172+
173+ let Decoded {
174+ value : packet_length,
175+ next,
176+ } = PacketLength :: decode (
177+ & self . decrypted_buf [ self . unread_start ..self . unread_start + 4 ] ,
178+ ) ?;
179+ assert ! ( next. is_empty( ) ) ;
180+
181+ Self :: ensure_at_least (
182+ & mut self . stream ,
183+ & mut self . buf ,
184+ & mut self . unread_start ,
185+ 4 + packet_length. inner
186+ + integrity_key. algorithm ( ) . digest_algorithm ( ) . output_len as u32 ,
187+ )
188+ . await ?;
189+
190+ let update = decrypting_key
191+ . update (
192+ & self . buf [ self . unread_start + block_len
193+ ..self . unread_start + 4 + packet_length. inner as usize ] ,
194+ & mut self . decrypted_buf [ self . unread_start + block_len
195+ ..self . unread_start + 4 + packet_length. inner as usize + block_len] ,
196+ )
197+ . unwrap ( ) ;
198+ assert_eq ! ( update. remainder( ) . len( ) , block_len) ;
199+
200+ let mut hmac_ctx = hmac:: Context :: with_key ( integrity_key) ;
201+ hmac_ctx. update ( & packet_number. to_be_bytes ( ) ) ;
202+ hmac_ctx. update (
203+ & self . decrypted_buf
204+ [ self . unread_start ..self . unread_start + 4 + packet_length. inner as usize ] ,
205+ ) ;
206+ let actual_mac = hmac_ctx. sign ( ) ;
207+ let expected_mac = & self . buf [ self . unread_start + 4 + packet_length. inner as usize
208+ ..self . unread_start
209+ + 4
210+ + packet_length. inner as usize
211+ + integrity_key. algorithm ( ) . digest_algorithm ( ) . output_len ] ;
212+ constant_time:: verify_slices_are_equal ( actual_mac. as_ref ( ) , expected_mac) . unwrap ( ) ; // FIXME report error
213+
214+ let Decoded {
215+ value : packet,
216+ next,
217+ } = Packet :: decode (
218+ & self . decrypted_buf
219+ [ self . unread_start ..self . unread_start + 4 + packet_length. inner as usize ] ,
220+ ) ?;
221+ assert ! ( next. is_empty( ) ) ;
222+
223+ self . unread_start += 4
224+ + packet_length. inner as usize
225+ + integrity_key. algorithm ( ) . digest_algorithm ( ) . output_len ;
226+
227+ Ok ( packet)
228+ } else {
229+ Self :: ensure_at_least ( & mut self . stream , & mut self . buf , & mut self . unread_start , 4 )
230+ . await ?;
231+ let Decoded {
232+ value : packet_length,
233+ next,
234+ } = PacketLength :: decode ( & self . buf [ self . unread_start ..self . unread_start + 4 ] ) ?;
235+ assert ! ( next. is_empty( ) ) ;
236+
237+ Self :: ensure_at_least (
238+ & mut self . stream ,
239+ & mut self . buf ,
240+ & mut self . unread_start ,
241+ 4 + packet_length. inner ,
242+ )
243+ . await ?;
244+
245+ let Decoded {
246+ value : packet,
247+ next,
248+ } = Packet :: decode (
249+ & self . buf [ self . unread_start ..self . unread_start + 4 + packet_length. inner as usize ] ,
250+ ) ?;
251+ assert ! ( next. is_empty( ) ) ;
252+
253+ self . unread_start += 4 + packet_length. inner as usize ;
254+
255+ Ok ( packet)
256+ }
146257 }
147258}
148259
149260pub ( crate ) struct Packet < ' a > {
150261 pub ( crate ) payload : & ' a [ u8 ] ,
151262}
152263
153- impl Packet < ' _ > {
154- pub ( crate ) fn builder ( buf : & mut Vec < u8 > ) -> PacketBuilder < ' _ > {
264+ impl < ' a > Packet < ' a > {
265+ pub ( crate ) fn builder ( buf : & ' a mut Vec < u8 > ) -> PacketBuilder < ' a > {
155266 let start = buf. len ( ) ;
156267 buf. extend_from_slice ( & [ 0 , 0 , 0 , 0 ] ) ; // packet_length
157268 buf. push ( 0 ) ; // padding_length
0 commit comments