@@ -6,7 +6,11 @@ use super::stream::write_packets;
66
77use crate :: io:: Decode ;
88use bytes:: Bytes ;
9- use sqlx_rt:: { AsyncRead , AsyncWrite , ReadBuf } ;
9+ use sqlx_rt:: { AsyncRead , AsyncWrite } ;
10+
11+ #[ cfg( feature = "_rt-tokio" ) ]
12+ use sqlx_rt:: ReadBuf ;
13+
1014use std:: cmp;
1115use std:: io;
1216use std:: pin:: Pin ;
@@ -72,12 +76,19 @@ impl<S> TlsPreloginWrapper<S> {
7276 }
7377}
7478
79+ #[ cfg( feature = "_rt-async-std" ) ]
80+ type PollReadOut = usize ;
81+
82+ #[ cfg( feature = "_rt-tokio" ) ]
83+ type PollReadOut = ( ) ;
84+
7585impl < S : AsyncRead + AsyncWrite + Unpin + Send > AsyncRead for TlsPreloginWrapper < S > {
86+ #[ cfg( feature = "_rt-tokio" ) ]
7687 fn poll_read (
7788 mut self : Pin < & mut Self > ,
7889 cx : & mut task:: Context < ' _ > ,
7990 buf : & mut ReadBuf < ' _ > ,
80- ) -> Poll < io:: Result < ( ) > > {
91+ ) -> Poll < io:: Result < PollReadOut > > {
8192 if !self . pending_handshake {
8293 return Pin :: new ( & mut self . stream ) . poll_read ( cx, buf) ;
8394 }
@@ -91,7 +102,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<
91102
92103 let read = header_buf. filled ( ) . len ( ) ;
93104 if read == 0 {
94- return Poll :: Ready ( Ok ( ( ) ) ) ;
105+ return Poll :: Ready ( Ok ( PollReadOut :: default ( ) ) ) ;
95106 }
96107
97108 inner. header_pos += read;
@@ -112,7 +123,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<
112123 let max_read = std:: cmp:: min ( inner. read_remaining , buf. remaining ( ) ) ;
113124 let mut limited_buf = buf. take ( max_read) ;
114125
115- ready ! ( Pin :: new( & mut inner. stream) . poll_read( cx, & mut limited_buf) ) ?;
126+ let res = ready ! ( Pin :: new( & mut inner. stream) . poll_read( cx, & mut limited_buf) ) ?;
116127
117128 let read = limited_buf. filled ( ) . len ( ) ;
118129 buf. advance ( read) ;
@@ -122,7 +133,20 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<
122133 inner. header_pos = 0 ;
123134 }
124135
125- Poll :: Ready ( Ok ( ( ) ) )
136+ Poll :: Ready ( Ok ( res) )
137+ }
138+
139+ #[ cfg( feature = "_rt-async-std" ) ]
140+ fn poll_read (
141+ mut self : Pin < & mut Self > ,
142+ cx : & mut task:: Context < ' _ > ,
143+ buf : & mut [ u8 ] ,
144+ ) -> Poll < io:: Result < usize > > {
145+ if self . pending_handshake {
146+ panic ! ( "TLS not supported on async-std for mssql" ) ;
147+ }
148+
149+ Pin :: new ( & mut self . stream ) . poll_read ( cx, buf)
126150 }
127151}
128152
@@ -174,7 +198,29 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPreloginWrapper
174198 Pin :: new ( & mut inner. stream ) . poll_flush ( cx)
175199 }
176200
201+ #[ cfg( feature = "_rt-tokio" ) ]
177202 fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
178203 Pin :: new ( & mut self . stream ) . poll_shutdown ( cx)
179204 }
205+
206+ #[ cfg( feature = "_rt-async-std" ) ]
207+ fn poll_close ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
208+ Pin :: new ( & mut self . stream ) . poll_close ( cx)
209+ }
210+ }
211+
212+ use std:: ops:: { Deref , DerefMut } ;
213+
214+ impl < S > Deref for TlsPreloginWrapper < S > {
215+ type Target = S ;
216+
217+ fn deref ( & self ) -> & Self :: Target {
218+ & self . stream
219+ }
220+ }
221+
222+ impl < S > DerefMut for TlsPreloginWrapper < S > {
223+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
224+ & mut self . stream
225+ }
180226}
0 commit comments