Skip to content

Commit dd5f27d

Browse files
fix: handshake does not fully flush writes (#112)
* write test to verify that flush does not get resumed * attempt to persist the need_flush state during handshake
1 parent 6a775e1 commit dd5f27d

File tree

8 files changed

+201
-52
lines changed

8 files changed

+201
-52
lines changed

src/client.rs

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub struct TlsStream<IO> {
2020
pub(crate) io: IO,
2121
pub(crate) session: ClientConnection,
2222
pub(crate) state: TlsState,
23+
pub(crate) need_flush: bool,
2324

2425
#[cfg(feature = "early-data")]
2526
pub(crate) early_waker: Option<Waker>,
@@ -72,8 +73,13 @@ impl<IO> IoSession for TlsStream<IO> {
7273
}
7374

7475
#[inline]
75-
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
76-
(&mut self.state, &mut self.io, &mut self.session)
76+
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) {
77+
(
78+
&mut self.state,
79+
&mut self.io,
80+
&mut self.session,
81+
&mut self.need_flush,
82+
)
7783
}
7884

7985
#[inline]
@@ -174,21 +180,27 @@ where
174180
buf: &[u8],
175181
) -> Poll<io::Result<usize>> {
176182
let this = self.get_mut();
177-
let mut stream =
178-
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
183+
let mut stream = Stream::new(&mut this.io, &mut this.session)
184+
.set_eof(!this.state.readable())
185+
.set_need_flush(this.need_flush);
179186

180187
#[cfg(feature = "early-data")]
181188
{
182189
let bufs = [io::IoSlice::new(buf)];
183-
let written = ready!(poll_handle_early_data(
190+
let written = poll_handle_early_data(
184191
&mut this.state,
185192
&mut stream,
186193
&mut this.early_waker,
187194
cx,
188-
&bufs
189-
))?;
190-
if written != 0 {
191-
return Poll::Ready(Ok(written));
195+
&bufs,
196+
)?;
197+
match written {
198+
Poll::Ready(0) => {}
199+
Poll::Ready(written) => return Poll::Ready(Ok(written)),
200+
Poll::Pending => {
201+
this.need_flush = stream.need_flush;
202+
return Poll::Pending;
203+
}
192204
}
193205
}
194206

@@ -203,20 +215,26 @@ where
203215
bufs: &[io::IoSlice<'_>],
204216
) -> Poll<io::Result<usize>> {
205217
let this = self.get_mut();
206-
let mut stream =
207-
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
218+
let mut stream = Stream::new(&mut this.io, &mut this.session)
219+
.set_eof(!this.state.readable())
220+
.set_need_flush(this.need_flush);
208221

209222
#[cfg(feature = "early-data")]
210223
{
211-
let written = ready!(poll_handle_early_data(
224+
let written = poll_handle_early_data(
212225
&mut this.state,
213226
&mut stream,
214227
&mut this.early_waker,
215228
cx,
216-
bufs
217-
))?;
218-
if written != 0 {
219-
return Poll::Ready(Ok(written));
229+
bufs,
230+
)?;
231+
match written {
232+
Poll::Ready(0) => {}
233+
Poll::Ready(written) => return Poll::Ready(Ok(written)),
234+
Poll::Pending => {
235+
this.need_flush = stream.need_flush;
236+
return Poll::Pending;
237+
}
220238
}
221239
}
222240

@@ -230,17 +248,24 @@ where
230248

231249
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
232250
let this = self.get_mut();
233-
let mut stream =
234-
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
251+
let mut stream = Stream::new(&mut this.io, &mut this.session)
252+
.set_eof(!this.state.readable())
253+
.set_need_flush(this.need_flush);
235254

236255
#[cfg(feature = "early-data")]
237-
ready!(poll_handle_early_data(
238-
&mut this.state,
239-
&mut stream,
240-
&mut this.early_waker,
241-
cx,
242-
&[]
243-
))?;
256+
{
257+
let written = poll_handle_early_data(
258+
&mut this.state,
259+
&mut stream,
260+
&mut this.early_waker,
261+
cx,
262+
&[],
263+
)?;
264+
if written.is_pending() {
265+
this.need_flush = stream.need_flush;
266+
return Poll::Pending;
267+
}
268+
}
244269

245270
stream.as_mut_pin().poll_flush(cx)
246271
}

src/common/handshake.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pub(crate) trait IoSession {
1515
type Session;
1616

1717
fn skip_handshake(&self) -> bool;
18-
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session);
18+
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool);
1919
fn into_io(self) -> Self::Io;
2020
}
2121

@@ -67,15 +67,18 @@ where
6767
};
6868

6969
if !stream.skip_handshake() {
70-
let (state, io, session) = stream.get_mut();
71-
let mut tls_stream = Stream::new(io, session).set_eof(!state.readable());
70+
let (state, io, session, need_flush) = stream.get_mut();
71+
let mut tls_stream = Stream::new(io, session)
72+
.set_eof(!state.readable())
73+
.set_need_flush(*need_flush);
7274

7375
macro_rules! try_poll {
7476
( $e:expr ) => {
7577
match $e {
76-
Poll::Ready(Ok(_)) => (),
78+
Poll::Ready(Ok(x)) => x,
7779
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
7880
Poll::Pending => {
81+
*need_flush = tls_stream.need_flush;
7982
*this = MidHandshake::Handshaking(stream);
8083
return Poll::Pending;
8184
}

src/common/mod.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ pub(crate) struct Stream<'a, IO, C> {
6363
pub(crate) io: &'a mut IO,
6464
pub(crate) session: &'a mut C,
6565
pub(crate) eof: bool,
66+
pub(crate) need_flush: bool,
6667
}
6768

6869
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
@@ -77,6 +78,8 @@ where
7778
// The state so far is only used to detect EOF, so either Stream
7879
// or EarlyData state should both be all right.
7980
eof: false,
81+
// Whether a previous flush returned pending, or a write occured without a flush.
82+
need_flush: false,
8083
}
8184
}
8285

@@ -85,6 +88,11 @@ where
8588
self
8689
}
8790

91+
pub(crate) fn set_need_flush(mut self, need_flush: bool) -> Self {
92+
self.need_flush = need_flush;
93+
self
94+
}
95+
8896
pub(crate) fn as_mut_pin(&mut self) -> Pin<&mut Self> {
8997
Pin::new(self)
9098
}
@@ -126,14 +134,13 @@ where
126134
loop {
127135
let mut write_would_block = false;
128136
let mut read_would_block = false;
129-
let mut need_flush = false;
130137

131138
while self.session.wants_write() {
132139
match self.write_io(cx) {
133140
Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
134141
Poll::Ready(Ok(n)) => {
135142
wrlen += n;
136-
need_flush = true;
143+
self.need_flush = true;
137144
}
138145
Poll::Pending => {
139146
write_would_block = true;
@@ -143,9 +150,9 @@ where
143150
}
144151
}
145152

146-
if need_flush {
153+
if self.need_flush {
147154
match Pin::new(&mut self.io).poll_flush(cx) {
148-
Poll::Ready(Ok(())) => (),
155+
Poll::Ready(Ok(())) => self.need_flush = false,
149156
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
150157
Poll::Pending => write_would_block = true,
151158
}

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ impl TlsConnector {
149149
TlsState::Stream
150150
},
151151

152+
need_flush: false,
153+
152154
#[cfg(feature = "early-data")]
153155
early_waker: None,
154156

@@ -193,6 +195,7 @@ impl TlsAcceptor {
193195
session,
194196
io: stream,
195197
state: TlsState::Stream,
198+
need_flush: false,
196199
}))
197200
}
198201

@@ -363,6 +366,7 @@ where
363366
session: conn,
364367
io: self.io,
365368
state: TlsState::Stream,
369+
need_flush: false,
366370
}))
367371
}
368372
}

src/server.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub struct TlsStream<IO> {
1818
pub(crate) io: IO,
1919
pub(crate) session: ServerConnection,
2020
pub(crate) state: TlsState,
21+
pub(crate) need_flush: bool,
2122
}
2223

2324
impl<IO> TlsStream<IO> {
@@ -47,8 +48,13 @@ impl<IO> IoSession for TlsStream<IO> {
4748
}
4849

4950
#[inline]
50-
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
51-
(&mut self.state, &mut self.io, &mut self.session)
51+
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) {
52+
(
53+
&mut self.state,
54+
&mut self.io,
55+
&mut self.session,
56+
&mut self.need_flush,
57+
)
5258
}
5359

5460
#[inline]

tests/early-data.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use std::thread;
1010
use futures_util::{future::Future, ready};
1111
use rustls::pki_types::ServerName;
1212
use rustls::{self, ClientConfig, ServerConnection, Stream};
13-
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
13+
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
1414
use tokio::net::TcpStream;
1515
use tokio_rustls::client::TlsStream;
1616
use tokio_rustls::TlsConnector;
@@ -35,14 +35,15 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
3535
}
3636
}
3737

38-
async fn send(
38+
async fn send<S: AsyncRead + AsyncWrite + Unpin>(
3939
config: Arc<ClientConfig>,
4040
addr: SocketAddr,
41+
wrapper: impl Fn(TcpStream) -> S,
4142
data: &[u8],
4243
vectored: bool,
43-
) -> io::Result<(TlsStream<TcpStream>, Vec<u8>)> {
44+
) -> io::Result<(TlsStream<S>, Vec<u8>)> {
4445
let connector = TlsConnector::from(config).early_data(true);
45-
let stream = TcpStream::connect(&addr).await?;
46+
let stream = wrapper(TcpStream::connect(&addr).await?);
4647
let domain = ServerName::try_from("foobar.com").unwrap();
4748

4849
let mut stream = connector.connect(domain, stream).await?;
@@ -58,15 +59,23 @@ async fn send(
5859

5960
#[tokio::test]
6061
async fn test_0rtt() -> io::Result<()> {
61-
test_0rtt_impl(false).await
62+
test_0rtt_impl(|s| s, false).await
6263
}
6364

6465
#[tokio::test]
6566
async fn test_0rtt_vectored() -> io::Result<()> {
66-
test_0rtt_impl(true).await
67+
test_0rtt_impl(|s| s, true).await
6768
}
6869

69-
async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
70+
#[tokio::test]
71+
async fn test_0rtt_vectored_flush_pending() -> io::Result<()> {
72+
test_0rtt_impl(utils::FlushWrapper::new, false).await
73+
}
74+
75+
async fn test_0rtt_impl<S: AsyncRead + AsyncWrite + Unpin>(
76+
wrapper: impl Fn(TcpStream) -> S,
77+
vectored: bool,
78+
) -> io::Result<()> {
7079
let (mut server, mut client) = utils::make_configs();
7180
server.max_early_data_size = 8192;
7281
let server = Arc::new(server);
@@ -108,11 +117,11 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
108117
let client = Arc::new(client);
109118
let addr = SocketAddr::from(([127, 0, 0, 1], server_port));
110119

111-
let (io, buf) = send(client.clone(), addr, b"hello", vectored).await?;
120+
let (io, buf) = send(client.clone(), addr, &wrapper, b"hello", vectored).await?;
112121
assert!(!io.get_ref().1.is_early_data_accepted());
113122
assert_eq!("LATE:hello", String::from_utf8_lossy(&buf));
114123

115-
let (io, buf) = send(client, addr, b"world!", vectored).await?;
124+
let (io, buf) = send(client, addr, wrapper, b"world!", vectored).await?;
116125
assert!(io.get_ref().1.is_early_data_accepted());
117126
assert_eq!("EARLY:world!LATE:", String::from_utf8_lossy(&buf));
118127

0 commit comments

Comments
 (0)