Skip to content

Commit 710cf25

Browse files
authored
Implement AsyncBufRead (#100)
* Implement AsyncBufRead for TlsStream types & Stream * Implement AsyncRead using AsyncBufRead * Reimplement AsyncRead for {server,client}::TlsStream in terms of AsyncBufRead
1 parent 276625b commit 710cf25

File tree

8 files changed

+211
-99
lines changed

8 files changed

+211
-99
lines changed

Cargo.lock

Lines changed: 10 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ rust-version = "1.71"
1313
exclude = ["/.github", "/examples", "/scripts"]
1414

1515
[dependencies]
16-
rustls = { version = "0.23.15", default-features = false, features = ["std"] }
16+
rustls = { version = "0.23.22", default-features = false, features = ["std"] }
1717
tokio = "1.0"
1818

1919
[features]

src/client.rs

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::io;
1+
use std::io::{self, BufRead as _};
22
#[cfg(unix)]
33
use std::os::unix::io::{AsRawFd, RawFd};
44
#[cfg(windows)]
@@ -9,7 +9,7 @@ use std::task::Waker;
99
use std::task::{Context, Poll};
1010

1111
use rustls::ClientConnection;
12-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12+
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
1313

1414
use crate::common::{IoSession, Stream, TlsState};
1515

@@ -82,50 +82,69 @@ impl<IO> IoSession for TlsStream<IO> {
8282
}
8383
}
8484

85+
#[cfg(feature = "early-data")]
86+
impl<IO> TlsStream<IO>
87+
where
88+
IO: AsyncRead + AsyncWrite + Unpin,
89+
{
90+
fn poll_early_data(&mut self, cx: &mut Context<'_>) {
91+
// In the EarlyData state, we have not really established a Tls connection.
92+
// Before writing data through `AsyncWrite` and completing the tls handshake,
93+
// we ignore read readiness and return to pending.
94+
//
95+
// In order to avoid event loss,
96+
// we need to register a waker and wake it up after tls is connected.
97+
if self
98+
.early_waker
99+
.as_ref()
100+
.filter(|waker| cx.waker().will_wake(waker))
101+
.is_none()
102+
{
103+
self.early_waker = Some(cx.waker().clone());
104+
}
105+
}
106+
}
107+
85108
impl<IO> AsyncRead for TlsStream<IO>
86109
where
87110
IO: AsyncRead + AsyncWrite + Unpin,
88111
{
89112
fn poll_read(
90-
self: Pin<&mut Self>,
113+
mut self: Pin<&mut Self>,
91114
cx: &mut Context<'_>,
92115
buf: &mut ReadBuf<'_>,
93116
) -> Poll<io::Result<()>> {
117+
let data = ready!(self.as_mut().poll_fill_buf(cx))?;
118+
let len = data.len().min(buf.remaining());
119+
buf.put_slice(&data[..len]);
120+
self.consume(len);
121+
Poll::Ready(Ok(()))
122+
}
123+
}
124+
125+
impl<IO> AsyncBufRead for TlsStream<IO>
126+
where
127+
IO: AsyncRead + AsyncWrite + Unpin,
128+
{
129+
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
94130
match self.state {
95131
#[cfg(feature = "early-data")]
96132
TlsState::EarlyData(..) => {
97-
let this = self.get_mut();
98-
99-
// In the EarlyData state, we have not really established a Tls connection.
100-
// Before writing data through `AsyncWrite` and completing the tls handshake,
101-
// we ignore read readiness and return to pending.
102-
//
103-
// In order to avoid event loss,
104-
// we need to register a waker and wake it up after tls is connected.
105-
if this
106-
.early_waker
107-
.as_ref()
108-
.filter(|waker| cx.waker().will_wake(waker))
109-
.is_none()
110-
{
111-
this.early_waker = Some(cx.waker().clone());
112-
}
113-
133+
self.get_mut().poll_early_data(cx);
114134
Poll::Pending
115135
}
116136
TlsState::Stream | TlsState::WriteShutdown => {
117137
let this = self.get_mut();
118-
let mut stream =
138+
let stream =
119139
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120-
let prev = buf.remaining();
121140

122-
match stream.as_mut_pin().poll_read(cx, buf) {
123-
Poll::Ready(Ok(())) => {
124-
if prev == buf.remaining() || stream.eof {
141+
match stream.poll_fill_buf(cx) {
142+
Poll::Ready(Ok(buf)) => {
143+
if buf.is_empty() {
125144
this.state.shutdown_read();
126145
}
127146

128-
Poll::Ready(Ok(()))
147+
Poll::Ready(Ok(buf))
129148
}
130149
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
131150
this.state.shutdown_read();
@@ -134,9 +153,13 @@ where
134153
output => output,
135154
}
136155
}
137-
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
156+
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])),
138157
}
139158
}
159+
160+
fn consume(mut self: Pin<&mut Self>, amt: usize) {
161+
self.session.reader().consume(amt);
162+
}
140163
}
141164

142165
impl<IO> AsyncWrite for TlsStream<IO>

src/common/mod.rs

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
use std::io::{self, IoSlice, Read, Write};
1+
use std::io::{self, BufRead as _, IoSlice, Read, Write};
22
use std::ops::{Deref, DerefMut};
33
use std::pin::Pin;
44
use std::task::{Context, Poll};
55

66
use rustls::{ConnectionCommon, SideData};
7-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7+
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
88

99
mod handshake;
1010
pub(crate) use handshake::{IoSession, MidHandshake};
@@ -180,18 +180,11 @@ where
180180
};
181181
}
182182
}
183-
}
184183

185-
impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'_, IO, C>
186-
where
187-
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
188-
SD: SideData,
189-
{
190-
fn poll_read(
191-
mut self: Pin<&mut Self>,
192-
cx: &mut Context<'_>,
193-
buf: &mut ReadBuf<'_>,
194-
) -> Poll<io::Result<()>> {
184+
pub(crate) fn poll_fill_buf(mut self, cx: &mut Context<'_>) -> Poll<io::Result<&'a [u8]>>
185+
where
186+
SD: 'a,
187+
{
195188
let mut io_pending = false;
196189

197190
// read a packet
@@ -209,22 +202,13 @@ where
209202
}
210203
}
211204

212-
match self.session.reader().read(buf.initialize_unfilled()) {
213-
// If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
214-
// connection with a `CloseNotify` message and no more data will be forthcoming.
215-
//
216-
// Rustls yielded more data: advance the buffer, then see if more data is coming.
217-
//
218-
// We don't need to modify `self.eof` here, because it is only a temporary mark.
219-
// rustls will only return 0 if is has received `CloseNotify`,
220-
// in which case no additional processing is required.
221-
Ok(n) => {
222-
buf.advance(n);
223-
Poll::Ready(Ok(()))
205+
match self.session.reader().into_first_chunk() {
206+
Ok(buf) => {
207+
// Note that this could be empty (i.e. EOF) if a `CloseNotify` has been
208+
// received and there is no more buffered data.
209+
Poll::Ready(Ok(buf))
224210
}
225-
226-
// Rustls doesn't have more data to yield, but it believes the connection is open.
227-
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
211+
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
228212
if !io_pending {
229213
// If `wants_read()` is satisfied, rustls will not return `WouldBlock`.
230214
// but if it does, we can try again.
@@ -236,9 +220,47 @@ where
236220

237221
Poll::Pending
238222
}
223+
Err(e) => Poll::Ready(Err(e)),
224+
}
225+
}
226+
}
239227

240-
Err(err) => Poll::Ready(Err(err)),
228+
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
229+
where
230+
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
231+
SD: SideData + 'a,
232+
{
233+
fn poll_read(
234+
mut self: Pin<&mut Self>,
235+
cx: &mut Context<'_>,
236+
buf: &mut ReadBuf<'_>,
237+
) -> Poll<io::Result<()>> {
238+
let data = ready!(self.as_mut().poll_fill_buf(cx))?;
239+
let amount = buf.remaining().min(data.len());
240+
buf.put_slice(&data[..amount]);
241+
self.session.reader().consume(amount);
242+
Poll::Ready(Ok(()))
243+
}
244+
}
245+
246+
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncBufRead for Stream<'a, IO, C>
247+
where
248+
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
249+
SD: SideData + 'a,
250+
{
251+
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
252+
let this = self.get_mut();
253+
Stream {
254+
// reborrow
255+
io: this.io,
256+
session: this.session,
257+
..*this
241258
}
259+
.poll_fill_buf(cx)
260+
}
261+
262+
fn consume(mut self: Pin<&mut Self>, amt: usize) {
263+
self.session.reader().consume(amt);
242264
}
243265
}
244266

src/common/test_stream.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,20 @@ impl AsyncWrite for Eof {
154154

155155
#[tokio::test]
156156
async fn stream_good() -> io::Result<()> {
157-
stream_good_impl(false).await
157+
stream_good_impl(false, false).await
158158
}
159159

160160
#[tokio::test]
161161
async fn stream_good_vectored() -> io::Result<()> {
162-
stream_good_impl(true).await
162+
stream_good_impl(true, false).await
163163
}
164164

165-
async fn stream_good_impl(vectored: bool) -> io::Result<()> {
165+
#[tokio::test]
166+
async fn stream_good_bufread() -> io::Result<()> {
167+
stream_good_impl(false, true).await
168+
}
169+
170+
async fn stream_good_impl(vectored: bool, bufread: bool) -> io::Result<()> {
166171
const FILE: &[u8] = include_bytes!("../../README.md");
167172

168173
let (server, mut client) = make_pair();
@@ -177,7 +182,11 @@ async fn stream_good_impl(vectored: bool) -> io::Result<()> {
177182
let mut stream = Stream::new(&mut good, &mut client);
178183

179184
let mut buf = Vec::new();
180-
dbg!(stream.read_to_end(&mut buf).await)?;
185+
if bufread {
186+
dbg!(tokio::io::copy_buf(&mut stream, &mut buf).await)?;
187+
} else {
188+
dbg!(stream.read_to_end(&mut buf).await)?;
189+
}
181190
assert_eq!(buf, FILE);
182191

183192
dbg!(utils::write(&mut stream, b"Hello World!", vectored).await)?;

src/lib.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub use rustls;
5151
use rustls::pki_types::ServerName;
5252
use rustls::server::AcceptedAlert;
5353
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
54-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
54+
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
5555

5656
macro_rules! ready {
5757
( $e:expr ) => {
@@ -545,6 +545,27 @@ where
545545
}
546546
}
547547

548+
impl<T> AsyncBufRead for TlsStream<T>
549+
where
550+
T: AsyncRead + AsyncWrite + Unpin,
551+
{
552+
#[inline]
553+
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
554+
match self.get_mut() {
555+
TlsStream::Client(x) => Pin::new(x).poll_fill_buf(cx),
556+
TlsStream::Server(x) => Pin::new(x).poll_fill_buf(cx),
557+
}
558+
}
559+
560+
#[inline]
561+
fn consume(self: Pin<&mut Self>, amt: usize) {
562+
match self.get_mut() {
563+
TlsStream::Client(x) => Pin::new(x).consume(amt),
564+
TlsStream::Server(x) => Pin::new(x).consume(amt),
565+
}
566+
}
567+
}
568+
548569
impl<T> AsyncWrite for TlsStream<T>
549570
where
550571
T: AsyncRead + AsyncWrite + Unpin,

0 commit comments

Comments
 (0)