diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 10a4fc0efd..484b130b83 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -81,7 +81,7 @@ pub use streams::StreamsState; use streams::StreamsState; pub use streams::{ Chunks, ClosedStream, FinishError, ReadError, ReadableError, RecvStream, SendStream, - ShouldTransmit, StreamEvent, Streams, WriteError, Written, + ShouldTransmit, StreamEvent, Streams, WriteError, Written, stage_buf, stage_chunks, }; mod timer; diff --git a/quinn-proto/src/connection/streams/mod.rs b/quinn-proto/src/connection/streams/mod.rs index 70f51e39ba..d7c037c175 100644 --- a/quinn-proto/src/connection/streams/mod.rs +++ b/quinn-proto/src/connection/streams/mod.rs @@ -220,11 +220,8 @@ impl<'a> SendStream<'a> { /// /// Returns the number of bytes successfully written. pub fn write(&mut self, data: &[u8]) -> Result { - self.write_source(|limit, chunks| { - let prefix = &data[..limit.min(data.len())]; - chunks.push(prefix.to_vec().into()); - prefix.len() - }) + self.write_source(|limit, chunks| stage_buf(data, limit, chunks)) + .map_err(|(e, _)| e) } /// Send data on the given stream @@ -234,24 +231,8 @@ impl<'a> SendStream<'a> { /// [`Written::chunks`] will not count this chunk as fully written. However /// the chunk will be advanced and contain only non-written data after the call. pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result { - self.write_source(|limit, chunks| { - let mut written = Written::default(); - for chunk in data { - let prefix = chunk.split_to(chunk.len().min(limit - written.bytes)); - written.bytes += prefix.len(); - chunks.push(prefix); - - if chunk.is_empty() { - written.chunks += 1; - } - - debug_assert!(written.bytes <= limit); - if written.bytes == limit { - break; - } - } - written - }) + self.write_source(|limit, chunks| stage_chunks(data, limit, chunks)) + .map_err(|(e, _)| e) } /// Send data on the given stream @@ -262,25 +243,23 @@ impl<'a> SendStream<'a> { /// guaranteed they will all be written. If it provides more bytes than this, it is guaranteed /// that a prefix of the provided cumulative bytes will be written equal in length to the /// provided limit. - fn write_source( - &mut self, - source: impl FnOnce(usize, &mut Vec) -> T, - ) -> Result { + pub fn write_source(&mut self, source: F) -> Result + where + F: FnOnce(usize, &mut Vec) -> R, + { if self.conn_state.is_closed() { trace!(%self.id, "write blocked; connection draining"); - return Err(WriteError::Blocked); + return Err((WriteError::Blocked, source)); } let limit = self.state.write_limit(); let max_send_data = self.state.max_send_data(self.id); - let stream = self - .state - .send - .get_mut(&self.id) - .map(get_or_insert_send(max_send_data)) - .ok_or(WriteError::ClosedStream)?; + let stream = match self.state.send.get_mut(&self.id) { + Some(opt) => opt.get_or_insert_with(|| Send::new(max_send_data)), + None => return Err((WriteError::ClosedStream, source)), + }; if limit == 0 { trace!( @@ -291,7 +270,7 @@ impl<'a> SendStream<'a> { stream.connection_blocked = true; self.state.connection_blocked.push(self.id); } - return Err(WriteError::Blocked); + return Err((WriteError::Blocked, source)); } let was_pending = stream.is_pending(); @@ -398,6 +377,42 @@ impl<'a> SendStream<'a> { } } +/// Helper function for using [`SendStream::write_source`] with `&[u8]` +/// +/// Copies the largest prefix of `data` that is not longer than `limit` to a new `Bytes` +/// allocation, pushes it to `chunks`, and returns how many bytes were transferred. +pub fn stage_buf(data: &[u8], limit: usize, chunks: &mut Vec) -> usize { + let prefix = &data[..limit.min(data.len())]; + chunks.push(prefix.to_vec().into()); + prefix.len() +} + +/// Helper function for using [`SendStream::write_source`] with `&mut [Bytes]` +/// +/// Treats `data` as a byte sequences represented as an array of contiguous chunks. Takes the +/// largest prefix of those bytes that is not longer than `limit` and pushes it to `chunks`. +/// Mutates each element of `data` so they no longer contain the parts of the chunks that were +/// taken. Returns a [`Written`] indicating the number of chunks that were *fully* transferred as +/// well as the total number of bytes that were transferred. +pub fn stage_chunks(data: &mut [Bytes], limit: usize, chunks: &mut Vec) -> Written { + let mut written = Written::default(); + for chunk in data { + let prefix = chunk.split_to(chunk.len().min(limit - written.bytes)); + written.bytes += prefix.len(); + chunks.push(prefix); + + if chunk.is_empty() { + written.chunks += 1; + } + + debug_assert!(written.bytes <= limit); + if written.bytes == limit { + break; + } + } + written +} + /// A queue of streams with pending outgoing data, sorted by priority struct PendingStreamsQueue { streams: BinaryHeap, diff --git a/quinn-proto/src/connection/streams/send.rs b/quinn-proto/src/connection/streams/send.rs index 52a9b7140d..3a211960b0 100644 --- a/quinn-proto/src/connection/streams/send.rs +++ b/quinn-proto/src/connection/streams/send.rs @@ -52,20 +52,23 @@ impl Send { } } - pub(super) fn write( + pub(super) fn write( &mut self, - source: impl FnOnce(usize, &mut Vec) -> T, + source: F, limit: u64, - ) -> Result<(usize, T), WriteError> { + ) -> Result<(usize, R), (WriteError, F)> + where + F: FnOnce(usize, &mut Vec) -> R, + { if !self.is_writable() { - return Err(WriteError::ClosedStream); + return Err((WriteError::ClosedStream, source)); } if let Some(error_code) = self.stop_reason { - return Err(WriteError::Stopped(error_code)); + return Err((WriteError::Stopped(error_code), source)); } let budget = self.max_data - self.pending.offset(); if budget == 0 { - return Err(WriteError::Blocked); + return Err((WriteError::Blocked, source)); } let limit = limit.min(budget) as usize; diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 74bd7e60a6..d477316ad2 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -47,7 +47,7 @@ pub use crate::connection::{ Chunk, Chunks, ClosedStream, Connection, ConnectionError, ConnectionStats, Datagrams, Event, FinishError, FrameStats, PathStats, ReadError, ReadableError, RecvStream, RttEstimator, SendDatagramError, SendStream, ShouldTransmit, StreamEvent, Streams, UdpStats, WriteError, - Written, + Written, stage_buf, stage_chunks, }; #[cfg(feature = "rustls")] diff --git a/quinn/src/send_stream.rs b/quinn/src/send_stream.rs index 91c10c6bc0..fcea07cf98 100644 --- a/quinn/src/send_stream.rs +++ b/quinn/src/send_stream.rs @@ -1,12 +1,15 @@ use std::{ + fmt, future::{Future, poll_fn}, io, - pin::Pin, + pin::{Pin, pin}, task::{Context, Poll}, }; use bytes::Bytes; -use proto::{ClosedStream, ConnectionError, FinishError, StreamId, Written}; +use proto::{ + ClosedStream, ConnectionError, FinishError, StreamId, Written, stage_buf, stage_chunks, +}; use thiserror::Error; use crate::{VarInt, connection::ConnectionRef}; @@ -50,7 +53,9 @@ impl SendStream { /// /// This operation is cancel-safe. pub async fn write(&mut self, buf: &[u8]) -> Result { - poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await + self.write_source(|limit, chunks| stage_buf(buf, limit, chunks)) + .await + .map_err(Into::into) } /// Convenience method to write an entire buffer to the stream @@ -72,7 +77,9 @@ impl SendStream { /// /// This operation is cancel-safe. pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result { - poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await + self.write_source(|limit, chunks| stage_chunks(bufs, limit, chunks)) + .await + .map_err(Into::into) } /// Convenience method to write a single chunk in its entirety to the stream @@ -94,36 +101,64 @@ impl SendStream { Ok(()) } - fn execute_poll(&mut self, cx: &mut Context, write_fn: F) -> Poll> + /// Attempts to write bytes into this stream from a byte-providing callback + /// + /// This is a low-level writing API that can be used to perform writes in a way that is atomic + /// with respect to congestion and flow control. This method: + /// + /// 1. Waits until a non-zero number of bytes can be written (or an error occurs). + /// 2. Locks the internal connection state. + /// 3. Invokes the `source` callback with the number of bytes that can be written immediately, + /// as well as an initially empty `&mut Vec` to which it can push bytes to write. + /// 4. Immediately writes as many of those bytes as can be written immediately. + /// + /// If the `source` callback pushes a total number of bytes to its vec less than or equal to + /// its given limit, it is guaranteed they will all be written into the stream immediately. + pub async fn write_source(&mut self, source: F) -> Result> where - F: FnOnce(&mut proto::SendStream) -> Result, + F: FnOnce(usize, &mut Vec) -> R, { - use proto::WriteError::*; - let mut conn = self.conn.state.lock("SendStream::poll_write"); - if self.is_0rtt { - conn.check_0rtt() - .map_err(|()| WriteError::ZeroRttRejected)?; - } - if let Some(ref x) = conn.error { - return Poll::Ready(Err(WriteError::ConnectionLost(x.clone()))); - } - - let result = match write_fn(&mut conn.inner.send_stream(self.stream)) { - Ok(result) => result, - Err(Blocked) => { - conn.blocked_writers.insert(self.stream, cx.waker().clone()); - return Poll::Pending; + let mut source = Some(source); + poll_fn(move |cx| { + let mut conn = self.conn.state.lock("SendStream::write_source"); + if self.is_0rtt && conn.check_0rtt() == Err(()) { + return Poll::Ready(Err(WriteSourceError { + error: WriteError::ZeroRttRejected, + source: source.take().unwrap(), + })); } - Err(Stopped(error_code)) => { - return Poll::Ready(Err(WriteError::Stopped(error_code))); + if let Some(e) = conn.error.clone() { + return Poll::Ready(Err(WriteSourceError { + error: WriteError::ConnectionLost(e), + source: source.take().unwrap(), + })); } - Err(ClosedStream) => { - return Poll::Ready(Err(WriteError::ClosedStream)); - } - }; - conn.wake(); - Poll::Ready(Ok(result)) + match conn + .inner + .send_stream(self.stream) + .write_source(source.take().unwrap()) + { + Ok(source_output) => { + conn.wake(); + Poll::Ready(Ok(source_output)) + } + Err((proto::WriteError::Blocked, returned_source)) => { + source = Some(returned_source); + conn.blocked_writers.insert(self.stream, cx.waker().clone()); + Poll::Pending + } + Err((e, returned_source)) => Poll::Ready(Err(WriteSourceError { + error: match e { + proto::WriteError::Blocked => unreachable!(), + proto::WriteError::Stopped(code) => WriteError::Stopped(code), + proto::WriteError::ClosedStream => WriteError::ClosedStream, + }, + source: returned_source, + })), + } + }) + .await } /// Notify the peer that no more data will ever be written to this stream @@ -241,14 +276,14 @@ impl SendStream { cx: &mut Context, buf: &[u8], ) -> Poll> { - self.get_mut().execute_poll(cx, |stream| stream.write(buf)) + pin!(self.get_mut().write(buf)).as_mut().poll(cx) } } #[cfg(feature = "futures-io")] impl futures_io::AsyncWrite for SendStream { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { - Self::execute_poll(self.get_mut(), cx, |stream| stream.write(buf)).map_err(Into::into) + self.poll_write(cx, buf).map_err(Into::into) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { @@ -266,7 +301,7 @@ impl tokio::io::AsyncWrite for SendStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - Self::execute_poll(self.get_mut(), cx, |stream| stream.write(buf)).map_err(Into::into) + self.poll_write(cx, buf).map_err(Into::into) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { @@ -355,6 +390,12 @@ impl From for WriteError { } } +impl From> for WriteError { + fn from(e: WriteSourceError) -> Self { + e.error + } +} + impl From for io::Error { fn from(x: WriteError) -> Self { use WriteError::*; @@ -392,3 +433,35 @@ impl From for io::Error { Self::new(kind, x) } } + +/// Error type for [`SendStream::write_source`] +pub struct WriteSourceError { + /// The underlying write error + pub error: WriteError, + /// The `source` parameter that was passed to `write_source` + pub source: F, +} + +impl fmt::Debug for WriteSourceError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(&self.error, f) + } +} + +impl fmt::Display for WriteSourceError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.error, f) + } +} + +impl std::error::Error for WriteSourceError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.error) + } +} + +impl From> for io::Error { + fn from(e: WriteSourceError) -> Self { + e.error.into() + } +}