Skip to content

Commit f430ef1

Browse files
committed
server: enable holding back acceptor alerts
1 parent 7ac70c1 commit f430ef1

File tree

1 file changed

+29
-7
lines changed

1 file changed

+29
-7
lines changed

src/server.rs

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ impl TlsAcceptor {
7070
pub struct LazyConfigAcceptor<IO> {
7171
acceptor: rustls::server::Acceptor,
7272
io: Option<IO>,
73-
alert: Option<(rustls::Error, AcceptedAlert)>,
73+
alert: Option<AlertState>,
74+
send_alert: bool,
7475
}
7576

7677
impl<IO> LazyConfigAcceptor<IO>
@@ -83,9 +84,16 @@ where
8384
acceptor,
8485
io: Some(io),
8586
alert: None,
87+
send_alert: true,
8688
}
8789
}
8890

91+
/// Configure whether to send a TLS alert on failure.
92+
pub fn send_alert(mut self, send: bool) -> Self {
93+
self.send_alert = send;
94+
self
95+
}
96+
8997
/// Takes back the client connection. Will return `None` if called more than once or if the
9098
/// connection has been accepted.
9199
///
@@ -130,6 +138,14 @@ where
130138
pub fn take_io(&mut self) -> Option<IO> {
131139
self.io.take()
132140
}
141+
142+
pub fn take_alert(&mut self) -> Option<AcceptedAlert> {
143+
match self.alert.take() {
144+
Some(AlertState::Sending(_, alert)) => Some(alert),
145+
Some(AlertState::Saved(alert)) => Some(alert),
146+
None => None,
147+
}
148+
}
133149
}
134150

135151
impl<IO> Future for LazyConfigAcceptor<IO>
@@ -151,17 +167,17 @@ where
151167
}
152168
};
153169

154-
if let Some((err, mut alert)) = this.alert.take() {
170+
if let Some(AlertState::Sending(err, mut alert)) = this.alert.take() {
155171
match alert.write(&mut SyncWriteAdapter { io, cx }) {
156172
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
157-
this.alert = Some((err, alert));
173+
this.alert = Some(AlertState::Sending(err, alert));
158174
return Poll::Pending;
159175
}
160176
Ok(0) | Err(_) => {
161177
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
162178
}
163179
Ok(_) => {
164-
this.alert = Some((err, alert));
180+
this.alert = Some(AlertState::Sending(err, alert));
165181
continue;
166182
}
167183
};
@@ -181,14 +197,20 @@ where
181197
return Poll::Ready(Ok(StartHandshake { accepted, io }));
182198
}
183199
Ok(None) => {}
184-
Err((err, alert)) => {
185-
this.alert = Some((err, alert));
186-
}
200+
Err((err, alert)) => match this.send_alert {
201+
true => this.alert = Some(AlertState::Sending(err, alert)),
202+
false => this.alert = Some(AlertState::Saved(alert)),
203+
},
187204
}
188205
}
189206
}
190207
}
191208

209+
enum AlertState {
210+
Sending(rustls::Error, AcceptedAlert),
211+
Saved(AcceptedAlert),
212+
}
213+
192214
/// An incoming connection received through [`LazyConfigAcceptor`].
193215
///
194216
/// This contains the generic `IO` asynchronous transport,

0 commit comments

Comments
 (0)