diff --git a/tokio-util/src/io/sync_bridge.rs b/tokio-util/src/io/sync_bridge.rs index 6821c7ede71..b9566db6a7a 100644 --- a/tokio-util/src/io/sync_bridge.rs +++ b/tokio-util/src/io/sync_bridge.rs @@ -279,6 +279,12 @@ impl BufRead for SyncIoBridge { self.rt .block_on(AsyncBufReadExt::read_until(src, byte, buf)) } + + fn skip_until(&mut self, byte: u8) -> std::io::Result { + let src = &mut self.src; + self.rt.block_on(AsyncBufReadExt::skip_until(src, byte)) + } + fn read_line(&mut self, buf: &mut String) -> std::io::Result { let src = &mut self.src; self.rt.block_on(AsyncBufReadExt::read_line(src, buf)) diff --git a/tokio/src/io/util/async_buf_read_ext.rs b/tokio/src/io/util/async_buf_read_ext.rs index 1e9da4c8c4d..ee313c8b503 100644 --- a/tokio/src/io/util/async_buf_read_ext.rs +++ b/tokio/src/io/util/async_buf_read_ext.rs @@ -2,6 +2,7 @@ use crate::io::util::fill_buf::{fill_buf, FillBuf}; use crate::io::util::lines::{lines, Lines}; use crate::io::util::read_line::{read_line, ReadLine}; use crate::io::util::read_until::{read_until, ReadUntil}; +use crate::io::util::skip_until::{skip_until, SkipUntil}; use crate::io::util::split::{split, Split}; use crate::io::AsyncBufRead; @@ -100,6 +101,86 @@ cfg_io_util! { read_until(self, byte, buf) } + /// Skips all bytes until the delimiter `byte` or EOF is reached. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn skip_until(&mut self, byte: u8) -> io::Result; + /// ``` + /// + /// This function will read bytes from the underlying stream until the + /// delimiter or EOF is found, discarding all bytes read. + /// + /// If successful, this function will return the total number of bytes read. + /// + /// If this function returns `Ok(0)`, the stream has reached EOF. + /// + /// # Errors + /// + /// This function will ignore all instances of [`ErrorKind::Interrupted`] and + /// will otherwise return any errors returned by [`fill_buf`]. + /// + /// [`fill_buf`]: AsyncBufRead::poll_fill_buf + /// [`ErrorKind::Interrupted`]: std::io::ErrorKind::Interrupted + /// + /// # Cancel safety + /// + /// If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then some data may have been partially read. Any + /// partially read bytes are skipped, and the method can be called again + /// to continue reading until `byte`. + /// + /// This method returns the total number of bytes read. If you cancel + /// the call to `skip_until` and then call it again to continue reading, + /// the counter is reset. + /// + /// # Examples + /// + /// [`std::io::Cursor`][`Cursor`] is a type that implements `BufRead`. In + /// this example, we use [`Cursor`] to read all the bytes in a byte slice + /// in hyphen delimited segments: + /// + /// [`Cursor`]: std::io::Cursor + /// + /// ``` + /// use tokio::io::AsyncBufReadExt; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut cursor = Cursor::new(b"lorem-ipsum"); + /// + /// // cursor is at 'l' + /// let num_bytes = cursor.skip_until(b'-') + /// .await + /// .expect("reading from cursor won't fail"); + /// + /// assert_eq!(num_bytes, 6); + /// + /// // cursor is at 'i' + /// let num_bytes = cursor.skip_until(b'-') + /// .await + /// .expect("reading from cursor won't fail"); + /// + /// assert_eq!(num_bytes, 5); + /// + /// // cursor is at EOF + /// let num_bytes = cursor.skip_until(b'-') + /// .await + /// .expect("reading from cursor won't fail"); + /// assert_eq!(num_bytes, 0); + /// } + /// ``` + fn skip_until<'a>(&'a mut self, byte: u8) -> SkipUntil<'a, Self> + where + Self: Unpin, + { + skip_until(self, byte) + } + /// Reads all bytes until a newline (the 0xA byte) is reached, and append /// them to the provided buffer. /// diff --git a/tokio/src/io/util/mod.rs b/tokio/src/io/util/mod.rs index e8658b4326b..917fb5506c4 100644 --- a/tokio/src/io/util/mod.rs +++ b/tokio/src/io/util/mod.rs @@ -59,6 +59,7 @@ cfg_io_util! { mod read_to_string; mod read_until; + mod skip_until; mod repeat; pub use repeat::{repeat, Repeat}; diff --git a/tokio/src/io/util/skip_until.rs b/tokio/src/io/util/skip_until.rs new file mode 100644 index 00000000000..ec448eb9b9d --- /dev/null +++ b/tokio/src/io/util/skip_until.rs @@ -0,0 +1,70 @@ +use crate::io::AsyncBufRead; +use crate::util::memchr; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::mem; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +pin_project! { + /// Future for the [`skip_until`](crate::io::AsyncBufReadExt::skip_until) method. + /// The delimiter is included in the resulting vector. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct SkipUntil<'a, R: ?Sized> { + reader: &'a mut R, + delimiter: u8, + // The number of bytes skipped. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn skip_until<'a, R>(reader: &'a mut R, delimiter: u8) -> SkipUntil<'a, R> +where + R: AsyncBufRead + ?Sized + Unpin, +{ + SkipUntil { + reader, + delimiter, + read: 0, + _pin: PhantomPinned, + } +} + +pub(super) fn skip_until_internal( + mut reader: Pin<&mut R>, + cx: &mut Context<'_>, + delimiter: u8, + read: &mut usize, +) -> Poll> { + loop { + let (done, used) = { + let available = ready!(reader.as_mut().poll_fill_buf(cx))?; + if let Some(i) = memchr::memchr(delimiter, available) { + (true, i + 1) + } else { + (false, available.len()) + } + }; + reader.as_mut().consume(used); + *read += used; + if done || used == 0 { + return Poll::Ready(Ok(mem::replace(read, 0))); + } + } +} + +impl Future for SkipUntil<'_, R> { + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let me = self.project(); + skip_until_internal(Pin::new(*me.reader), cx, *me.delimiter, me.read) + } +} diff --git a/tokio/tests/async_send_sync.rs b/tokio/tests/async_send_sync.rs index c9cedc38b02..1a9224d496e 100644 --- a/tokio/tests/async_send_sync.rs +++ b/tokio/tests/async_send_sync.rs @@ -604,6 +604,7 @@ async_assert_fn!(tokio::io::stdout(): Send & Sync & Unpin); async_assert_fn!(tokio::io::Split>::next_segment(_): Send & Sync & !Unpin); async_assert_fn!(tokio::io::Lines>::next_line(_): Send & Sync & !Unpin); async_assert_fn!(tokio::io::AsyncBufReadExt::read_until(&mut BoxAsyncRead, u8, &mut Vec): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncBufReadExt::skip_until(&mut BoxAsyncRead, u8): Send & Sync & !Unpin); async_assert_fn!( tokio::io::AsyncBufReadExt::read_line(&mut BoxAsyncRead, &mut String): Send & Sync & !Unpin ); diff --git a/tokio/tests/io_skip_until.rs b/tokio/tests/io_skip_until.rs new file mode 100644 index 00000000000..d4c6880bee8 --- /dev/null +++ b/tokio/tests/io_skip_until.rs @@ -0,0 +1,55 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use std::io::ErrorKind; +use tokio::io::{AsyncBufReadExt, BufReader, Error}; +use tokio_test::{assert_ok, io::Builder}; + +#[tokio::test] +async fn skip_until() { + let mut rd: &[u8] = b"hello world"; + + let n = assert_ok!(rd.skip_until(b' ').await); + assert_eq!(n, 6); + let n = assert_ok!(rd.skip_until(b' ').await); + assert_eq!(n, 5); + let n = assert_ok!(rd.skip_until(b' ').await); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn skip_until_not_all_ready() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"ld#Fizz\xffBuz") + .read(b"z#1#2") + .build(); + + let mut read = BufReader::new(mock); + + let bytes = read.skip_until(b'#').await.unwrap(); + assert_eq!(bytes, b"Hello World#".len()); + + let bytes = read.skip_until(b'#').await.unwrap(); + assert_eq!(bytes, b"Fizz\xffBuzz\n".len()); + + let bytes = read.skip_until(b'#').await.unwrap(); + assert_eq!(bytes, 2); + + let bytes = read.skip_until(b'#').await.unwrap(); + assert_eq!(bytes, 1); +} + +#[tokio::test] +async fn skip_until_fail() { + let mock = Builder::new() + .read(b"Hello \xffWor") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let err = read.skip_until(b'#').await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); +}