@@ -142,11 +142,48 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<
142142 cx : & mut task:: Context < ' _ > ,
143143 buf : & mut [ u8 ] ,
144144 ) -> Poll < io:: Result < usize > > {
145- if self . pending_handshake {
146- unimplemented ! ( "TLS not supported on async-std for mssql" ) ;
145+ if !self . pending_handshake {
146+ return Pin :: new ( & mut self . stream ) . poll_read ( cx, buf) ;
147+ }
148+
149+ let inner = self . get_mut ( ) ;
150+
151+ if !inner. header_buf [ inner. header_pos ..] . is_empty ( ) {
152+ while !inner. header_buf [ inner. header_pos ..] . is_empty ( ) {
153+ let header_buf = & mut inner. header_buf [ inner. header_pos ..] ;
154+ let read = ready ! ( Pin :: new( & mut inner. stream) . poll_read( cx, header_buf) ) ?;
155+
156+ if read == 0 {
157+ return Poll :: Ready ( Ok ( PollReadOut :: default ( ) ) ) ;
158+ }
159+
160+ inner. header_pos += read;
161+ }
162+
163+ let header: PacketHeader = Decode :: decode ( Bytes :: copy_from_slice ( & inner. header_buf ) )
164+ . map_err ( |err| io:: Error :: new ( io:: ErrorKind :: Other , err) ) ?;
165+
166+ inner. read_remaining = usize:: from ( header. length ) - HEADER_BYTES ;
167+
168+ log:: trace!(
169+ "Discarding header ({:?}), reading packet of {} bytes" ,
170+ header,
171+ inner. read_remaining,
172+ ) ;
173+ }
174+
175+ let max_read = std:: cmp:: min ( inner. read_remaining , buf. len ( ) ) ;
176+ let limited_buf = & mut buf[ ..max_read] ;
177+
178+ let read = ready ! ( Pin :: new( & mut inner. stream) . poll_read( cx, limited_buf) ) ?;
179+
180+ inner. read_remaining -= read;
181+
182+ if inner. read_remaining == 0 {
183+ inner. header_pos = 0 ;
147184 }
148185
149- Pin :: new ( & mut self . stream ) . poll_read ( cx , buf )
186+ Poll :: Ready ( Ok ( read ) )
150187 }
151188}
152189
0 commit comments