Skip to content

Commit e6e3525

Browse files
djchowardjohn
andcommitted
server: enable holding back acceptor alerts
Co-authored-by: John Howard <[email protected]>
1 parent 4d79ba7 commit e6e3525

File tree

1 file changed

+79
-6
lines changed

1 file changed

+79
-6
lines changed

src/server.rs

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ impl TlsAcceptor {
7070
pub struct LazyConfigAcceptor<IO> {
7171
acceptor: rustls::server::Acceptor,
7272
io: Option<IO>,
73-
alert: Option<(rustls::Error, AcceptedAlert)>,
73+
alert: Option<AlertState>,
74+
send_alert: bool,
7475
}
7576

7677
impl<IO> LazyConfigAcceptor<IO>
@@ -83,9 +84,33 @@ where
8384
acceptor,
8485
io: Some(io),
8586
alert: None,
87+
send_alert: true,
8688
}
8789
}
8890

91+
/// Configure whether to send a TLS alert on failure.
92+
pub fn send_alert(mut self, send: bool) -> Self {
93+
self.send_alert = send;
94+
self
95+
}
96+
97+
/// Writes a stored alert, consuming the alert (if any) and IO.
98+
pub async fn write_alert(&mut self) -> io::Result<()> {
99+
let Some(alert) = self.take_alert() else {
100+
return Ok(());
101+
};
102+
103+
let Some(io) = self.take_io() else {
104+
return Ok(());
105+
};
106+
107+
WritingAlert {
108+
io,
109+
alert: Some(alert),
110+
}
111+
.await
112+
}
113+
89114
/// Takes back the client connection. Will return `None` if called more than once or if the
90115
/// connection has been accepted.
91116
///
@@ -130,6 +155,14 @@ where
130155
pub fn take_io(&mut self) -> Option<IO> {
131156
self.io.take()
132157
}
158+
159+
pub fn take_alert(&mut self) -> Option<AcceptedAlert> {
160+
match self.alert.take() {
161+
Some(AlertState::Sending(_, alert)) => Some(alert),
162+
Some(AlertState::Saved(alert)) => Some(alert),
163+
None => None,
164+
}
165+
}
133166
}
134167

135168
impl<IO> Future for LazyConfigAcceptor<IO>
@@ -151,17 +184,17 @@ where
151184
}
152185
};
153186

154-
if let Some((err, mut alert)) = this.alert.take() {
187+
if let Some(AlertState::Sending(err, mut alert)) = this.alert.take() {
155188
match alert.write(&mut SyncWriteAdapter { io, cx }) {
156189
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
157-
this.alert = Some((err, alert));
190+
this.alert = Some(AlertState::Sending(err, alert));
158191
return Poll::Pending;
159192
}
160193
Ok(0) | Err(_) => {
161194
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
162195
}
163196
Ok(_) => {
164-
this.alert = Some((err, alert));
197+
this.alert = Some(AlertState::Sending(err, alert));
165198
continue;
166199
}
167200
};
@@ -181,9 +214,49 @@ where
181214
return Poll::Ready(Ok(StartHandshake { accepted, io }));
182215
}
183216
Ok(None) => {}
184-
Err((err, alert)) => {
185-
this.alert = Some((err, alert));
217+
Err((err, alert)) => match this.send_alert {
218+
true => this.alert = Some(AlertState::Sending(err, alert)),
219+
false => {
220+
this.alert = Some(AlertState::Saved(alert));
221+
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)));
222+
}
223+
},
224+
}
225+
}
226+
}
227+
}
228+
229+
enum AlertState {
230+
Sending(rustls::Error, AcceptedAlert),
231+
Saved(AcceptedAlert),
232+
}
233+
234+
struct WritingAlert<IO> {
235+
io: IO,
236+
alert: Option<AcceptedAlert>,
237+
}
238+
239+
impl<IO> Future for WritingAlert<IO>
240+
where
241+
IO: AsyncRead + AsyncWrite + Unpin,
242+
{
243+
type Output = Result<(), io::Error>;
244+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
245+
let this = self.get_mut();
246+
let io = &mut this.io;
247+
loop {
248+
let Some(mut alert) = this.alert.take() else {
249+
return Poll::Ready(Ok(()));
250+
};
251+
252+
match alert.write(&mut SyncWriteAdapter { io, cx }) {
253+
Ok(0) => return Poll::Ready(Ok(())),
254+
Ok(_) => continue,
255+
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
256+
this.alert = Some(alert);
257+
return Poll::Pending;
186258
}
259+
Err(e) => return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, e))),
187260
}
188261
}
189262
}

0 commit comments

Comments
 (0)