1- use std:: future:: Future ;
1+ use std:: future:: { poll_fn , Future } ;
22use std:: io:: { self , BufRead as _} ;
33#[ cfg( unix) ]
44use std:: os:: unix:: io:: { AsRawFd , RawFd } ;
@@ -10,7 +10,7 @@ use std::task::{Context, Poll};
1010
1111use rustls:: server:: AcceptedAlert ;
1212use rustls:: { ServerConfig , ServerConnection } ;
13- use tokio:: io:: { AsyncBufRead , AsyncRead , AsyncWrite , ReadBuf } ;
13+ use tokio:: io:: { AsyncBufRead , AsyncRead , AsyncWrite , AsyncWriteExt , ReadBuf } ;
1414
1515use crate :: common:: { IoSession , MidHandshake , Stream , SyncReadAdapter , SyncWriteAdapter , TlsState } ;
1616
@@ -111,7 +111,7 @@ where
111111 /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
112112 /// let (stream, _) = listener.accept().await.unwrap();
113113 ///
114- /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
114+ /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream).send_alert(false) ;
115115 /// tokio::pin!(acceptor);
116116 ///
117117 /// match acceptor.as_mut().await {
@@ -146,6 +146,57 @@ where
146146 None => None ,
147147 }
148148 }
149+
150+ /// Writes a stored alert, consuming the alert (if any) and IO.
151+ pub async fn write_alert ( & mut self ) -> io:: Result < ( ) > {
152+ let Some ( alert) = self . take_alert ( ) else {
153+ return Ok ( ( ) ) ;
154+ } ;
155+ let Some ( io) = self . take_io ( ) else {
156+ return Ok ( ( ) ) ;
157+ } ;
158+ WritingAlert {
159+ io,
160+ alert : Some ( alert) ,
161+ }
162+ . await
163+ }
164+ }
165+
166+ struct WritingAlert < IO > {
167+ io : IO ,
168+ alert : Option < AcceptedAlert > ,
169+ }
170+
171+ impl < IO > Future for WritingAlert < IO >
172+ where
173+ IO : AsyncRead + AsyncWrite + Unpin ,
174+ {
175+ type Output = Result < ( ) , io:: Error > ;
176+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
177+ let this = self . get_mut ( ) ;
178+ let io = & mut this. io ;
179+ loop {
180+ if let Some ( mut alert) = this. alert . take ( ) {
181+ match alert. write ( & mut SyncWriteAdapter { io, cx } ) {
182+ Err ( e) if e. kind ( ) == io:: ErrorKind :: WouldBlock => {
183+ this. alert = Some ( alert) ;
184+ return Poll :: Pending ;
185+ }
186+ Err ( e) => {
187+ return Poll :: Ready ( Err ( io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ) ;
188+ }
189+ Ok ( 0 ) => {
190+ return Poll :: Ready ( Ok ( ( ) ) ) ;
191+ }
192+ Ok ( n) => {
193+ this. alert = Some ( alert) ;
194+ continue ;
195+ }
196+ } ;
197+ }
198+ }
199+ }
149200}
150201
151202impl < IO > Future for LazyConfigAcceptor < IO >
@@ -199,7 +250,10 @@ where
199250 Ok ( None ) => { }
200251 Err ( ( err, alert) ) => match this. send_alert {
201252 true => this. alert = Some ( AlertState :: Sending ( err, alert) ) ,
202- false => this. alert = Some ( AlertState :: Saved ( alert) ) ,
253+ false => {
254+ this. alert = Some ( AlertState :: Saved ( alert) ) ;
255+ return Poll :: Ready ( Err ( io:: Error :: new ( io:: ErrorKind :: InvalidData , err) ) ) ;
256+ }
203257 } ,
204258 }
205259 }
0 commit comments