@@ -70,7 +70,8 @@ impl TlsAcceptor {
7070pub struct LazyConfigAcceptor < IO > {
7171 acceptor : rustls:: server:: Acceptor ,
7272 io : Option < IO > ,
73- alert : Option < ( rustls:: Error , AcceptedAlert ) > ,
73+ alert : Option < AlertState > ,
74+ send_alert : bool ,
7475}
7576
7677impl < IO > LazyConfigAcceptor < IO >
8384 acceptor,
8485 io : Some ( io) ,
8586 alert : None ,
87+ send_alert : true ,
8688 }
8789 }
8890
91+ /// Configure whether to send a TLS alert on failure.
92+ pub fn send_alert ( mut self , send : bool ) -> Self {
93+ self . send_alert = send;
94+ self
95+ }
96+
97+ /// Writes a stored alert, consuming the alert (if any) and IO.
98+ pub async fn write_alert ( & mut self ) -> io:: Result < ( ) > {
99+ let Some ( alert) = self . take_alert ( ) else {
100+ return Ok ( ( ) ) ;
101+ } ;
102+
103+ let Some ( io) = self . take_io ( ) else {
104+ return Ok ( ( ) ) ;
105+ } ;
106+
107+ WritingAlert {
108+ io,
109+ alert : Some ( alert) ,
110+ }
111+ . await
112+ }
113+
89114 /// Takes back the client connection. Will return `None` if called more than once or if the
90115 /// connection has been accepted.
91116 ///
@@ -130,6 +155,14 @@ where
130155 pub fn take_io ( & mut self ) -> Option < IO > {
131156 self . io . take ( )
132157 }
158+
159+ pub fn take_alert ( & mut self ) -> Option < AcceptedAlert > {
160+ match self . alert . take ( ) {
161+ Some ( AlertState :: Sending ( _, alert) ) => Some ( alert) ,
162+ Some ( AlertState :: Saved ( alert) ) => Some ( alert) ,
163+ None => None ,
164+ }
165+ }
133166}
134167
135168impl < IO > Future for LazyConfigAcceptor < IO >
@@ -151,17 +184,17 @@ where
151184 }
152185 } ;
153186
154- if let Some ( ( err, mut alert) ) = this. alert . take ( ) {
187+ if let Some ( AlertState :: Sending ( err, mut alert) ) = this. alert . take ( ) {
155188 match alert. write ( & mut SyncWriteAdapter { io, cx } ) {
156189 Err ( e) if e. kind ( ) == io:: ErrorKind :: WouldBlock => {
157- this. alert = Some ( ( err, alert) ) ;
190+ this. alert = Some ( AlertState :: Sending ( err, alert) ) ;
158191 return Poll :: Pending ;
159192 }
160193 Ok ( 0 ) | Err ( _) => {
161194 return Poll :: Ready ( Err ( io:: Error :: new ( io:: ErrorKind :: InvalidData , err) ) )
162195 }
163196 Ok ( _) => {
164- this. alert = Some ( ( err, alert) ) ;
197+ this. alert = Some ( AlertState :: Sending ( err, alert) ) ;
165198 continue ;
166199 }
167200 } ;
@@ -181,9 +214,49 @@ where
181214 return Poll :: Ready ( Ok ( StartHandshake { accepted, io } ) ) ;
182215 }
183216 Ok ( None ) => { }
184- Err ( ( err, alert) ) => {
185- this. alert = Some ( ( err, alert) ) ;
217+ Err ( ( err, alert) ) => match this. send_alert {
218+ true => this. alert = Some ( AlertState :: Sending ( err, alert) ) ,
219+ false => {
220+ this. alert = Some ( AlertState :: Saved ( alert) ) ;
221+ return Poll :: Ready ( Err ( io:: Error :: new ( io:: ErrorKind :: InvalidData , err) ) ) ;
222+ }
223+ } ,
224+ }
225+ }
226+ }
227+ }
228+
229+ enum AlertState {
230+ Sending ( rustls:: Error , AcceptedAlert ) ,
231+ Saved ( AcceptedAlert ) ,
232+ }
233+
234+ struct WritingAlert < IO > {
235+ io : IO ,
236+ alert : Option < AcceptedAlert > ,
237+ }
238+
239+ impl < IO > Future for WritingAlert < IO >
240+ where
241+ IO : AsyncRead + AsyncWrite + Unpin ,
242+ {
243+ type Output = Result < ( ) , io:: Error > ;
244+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
245+ let this = self . get_mut ( ) ;
246+ let io = & mut this. io ;
247+ loop {
248+ let Some ( mut alert) = this. alert . take ( ) else {
249+ return Poll :: Ready ( Ok ( ( ) ) ) ;
250+ } ;
251+
252+ match alert. write ( & mut SyncWriteAdapter { io, cx } ) {
253+ Ok ( 0 ) => return Poll :: Ready ( Ok ( ( ) ) ) ,
254+ Ok ( _) => continue ,
255+ Err ( e) if e. kind ( ) == io:: ErrorKind :: WouldBlock => {
256+ this. alert = Some ( alert) ;
257+ return Poll :: Pending ;
186258 }
259+ Err ( e) => return Poll :: Ready ( Err ( io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ) ,
187260 }
188261 }
189262 }
0 commit comments