Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub use streams::StreamsState;
use streams::StreamsState;
pub use streams::{
Chunks, ClosedStream, FinishError, ReadError, ReadableError, RecvStream, SendStream,
ShouldTransmit, StreamEvent, Streams, WriteError, Written,
ShouldTransmit, StreamEvent, Streams, WriteError, Written, stage_buf, stage_chunks,
};

mod timer;
Expand Down
85 changes: 50 additions & 35 deletions quinn-proto/src/connection/streams/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,8 @@ impl<'a> SendStream<'a> {
///
/// Returns the number of bytes successfully written.
pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
self.write_source(|limit, chunks| {
let prefix = &data[..limit.min(data.len())];
chunks.push(prefix.to_vec().into());
prefix.len()
})
self.write_source(|limit, chunks| stage_buf(data, limit, chunks))
.map_err(|(e, _)| e)
}

/// Send data on the given stream
Expand All @@ -234,24 +231,8 @@ impl<'a> SendStream<'a> {
/// [`Written::chunks`] will not count this chunk as fully written. However
/// the chunk will be advanced and contain only non-written data after the call.
pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
self.write_source(|limit, chunks| {
let mut written = Written::default();
for chunk in data {
let prefix = chunk.split_to(chunk.len().min(limit - written.bytes));
written.bytes += prefix.len();
chunks.push(prefix);

if chunk.is_empty() {
written.chunks += 1;
}

debug_assert!(written.bytes <= limit);
if written.bytes == limit {
break;
}
}
written
})
self.write_source(|limit, chunks| stage_chunks(data, limit, chunks))
.map_err(|(e, _)| e)
}

/// Send data on the given stream
Expand All @@ -262,25 +243,23 @@ impl<'a> SendStream<'a> {
/// guaranteed they will all be written. If it provides more bytes than this, it is guaranteed
/// that a prefix of the provided cumulative bytes will be written equal in length to the
/// provided limit.
fn write_source<T>(
&mut self,
source: impl FnOnce(usize, &mut Vec<Bytes>) -> T,
) -> Result<T, WriteError> {
pub fn write_source<F, R>(&mut self, source: F) -> Result<R, (WriteError, F)>
where
F: FnOnce(usize, &mut Vec<Bytes>) -> R,
{
if self.conn_state.is_closed() {
trace!(%self.id, "write blocked; connection draining");
return Err(WriteError::Blocked);
return Err((WriteError::Blocked, source));
}

let limit = self.state.write_limit();

let max_send_data = self.state.max_send_data(self.id);

let stream = self
.state
.send
.get_mut(&self.id)
.map(get_or_insert_send(max_send_data))
.ok_or(WriteError::ClosedStream)?;
let stream = match self.state.send.get_mut(&self.id) {
Some(opt) => opt.get_or_insert_with(|| Send::new(max_send_data)),
None => return Err((WriteError::ClosedStream, source)),
};

if limit == 0 {
trace!(
Expand All @@ -291,7 +270,7 @@ impl<'a> SendStream<'a> {
stream.connection_blocked = true;
self.state.connection_blocked.push(self.id);
}
return Err(WriteError::Blocked);
return Err((WriteError::Blocked, source));
}

let was_pending = stream.is_pending();
Expand Down Expand Up @@ -398,6 +377,42 @@ impl<'a> SendStream<'a> {
}
}

/// Helper function for using [`SendStream::write_source`] with `&[u8]`
///
/// Copies the largest prefix of `data` that is not longer than `limit` to a new `Bytes`
/// allocation, pushes it to `chunks`, and returns how many bytes were transferred.
pub fn stage_buf(data: &[u8], limit: usize, chunks: &mut Vec<Bytes>) -> usize {
let prefix = &data[..limit.min(data.len())];
chunks.push(prefix.to_vec().into());
prefix.len()
}

/// Helper function for using [`SendStream::write_source`] with `&mut [Bytes]`
///
/// Treats `data` as a byte sequences represented as an array of contiguous chunks. Takes the
/// largest prefix of those bytes that is not longer than `limit` and pushes it to `chunks`.
/// Mutates each element of `data` so they no longer contain the parts of the chunks that were
/// taken. Returns a [`Written`] indicating the number of chunks that were *fully* transferred as
/// well as the total number of bytes that were transferred.
pub fn stage_chunks(data: &mut [Bytes], limit: usize, chunks: &mut Vec<Bytes>) -> Written {
let mut written = Written::default();
for chunk in data {
let prefix = chunk.split_to(chunk.len().min(limit - written.bytes));
written.bytes += prefix.len();
chunks.push(prefix);

if chunk.is_empty() {
written.chunks += 1;
}

debug_assert!(written.bytes <= limit);
if written.bytes == limit {
break;
}
}
written
}

/// A queue of streams with pending outgoing data, sorted by priority
struct PendingStreamsQueue {
streams: BinaryHeap<PendingStream>,
Expand Down
15 changes: 9 additions & 6 deletions quinn-proto/src/connection/streams/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,23 @@ impl Send {
}
}

pub(super) fn write<T>(
pub(super) fn write<F, R>(
&mut self,
source: impl FnOnce(usize, &mut Vec<Bytes>) -> T,
source: F,
limit: u64,
) -> Result<(usize, T), WriteError> {
) -> Result<(usize, R), (WriteError, F)>
where
F: FnOnce(usize, &mut Vec<Bytes>) -> R,
{
if !self.is_writable() {
return Err(WriteError::ClosedStream);
return Err((WriteError::ClosedStream, source));
}
if let Some(error_code) = self.stop_reason {
return Err(WriteError::Stopped(error_code));
return Err((WriteError::Stopped(error_code), source));
}
let budget = self.max_data - self.pending.offset();
if budget == 0 {
return Err(WriteError::Blocked);
return Err((WriteError::Blocked, source));
}
let limit = limit.min(budget) as usize;

Expand Down
2 changes: 1 addition & 1 deletion quinn-proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub use crate::connection::{
Chunk, Chunks, ClosedStream, Connection, ConnectionError, ConnectionStats, Datagrams, Event,
FinishError, FrameStats, PathStats, ReadError, ReadableError, RecvStream, RttEstimator,
SendDatagramError, SendStream, ShouldTransmit, StreamEvent, Streams, UdpStats, WriteError,
Written,
Written, stage_buf, stage_chunks,
};

#[cfg(feature = "rustls")]
Expand Down
137 changes: 105 additions & 32 deletions quinn/src/send_stream.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::{
fmt,
future::{Future, poll_fn},
io,
pin::Pin,
pin::{Pin, pin},
task::{Context, Poll},
};

use bytes::Bytes;
use proto::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
use proto::{
ClosedStream, ConnectionError, FinishError, StreamId, Written, stage_buf, stage_chunks,
};
use thiserror::Error;

use crate::{VarInt, connection::ConnectionRef};
Expand Down Expand Up @@ -50,7 +53,9 @@ impl SendStream {
///
/// This operation is cancel-safe.
pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
self.write_source(|limit, chunks| stage_buf(buf, limit, chunks))
.await
.map_err(Into::into)
}

/// Convenience method to write an entire buffer to the stream
Expand All @@ -72,7 +77,9 @@ impl SendStream {
///
/// This operation is cancel-safe.
pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
self.write_source(|limit, chunks| stage_chunks(bufs, limit, chunks))
.await
.map_err(Into::into)
}

/// Convenience method to write a single chunk in its entirety to the stream
Expand All @@ -94,36 +101,64 @@ impl SendStream {
Ok(())
}

fn execute_poll<F, R>(&mut self, cx: &mut Context, write_fn: F) -> Poll<Result<R, WriteError>>
/// Attempts to write bytes into this stream from a byte-providing callback
///
/// This is a low-level writing API that can be used to perform writes in a way that is atomic
/// with respect to congestion and flow control. This method:
///
/// 1. Waits until a non-zero number of bytes can be written (or an error occurs).
/// 2. Locks the internal connection state.
/// 3. Invokes the `source` callback with the number of bytes that can be written immediately,
/// as well as an initially empty `&mut Vec<Bytes>` to which it can push bytes to write.
/// 4. Immediately writes as many of those bytes as can be written immediately.
///
/// If the `source` callback pushes a total number of bytes to its vec less than or equal to
/// its given limit, it is guaranteed they will all be written into the stream immediately.
pub async fn write_source<F, R>(&mut self, source: F) -> Result<R, WriteSourceError<F>>
where
F: FnOnce(&mut proto::SendStream) -> Result<R, proto::WriteError>,
F: FnOnce(usize, &mut Vec<Bytes>) -> R,
{
use proto::WriteError::*;
let mut conn = self.conn.state.lock("SendStream::poll_write");
if self.is_0rtt {
conn.check_0rtt()
.map_err(|()| WriteError::ZeroRttRejected)?;
}
if let Some(ref x) = conn.error {
return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
}

let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
Ok(result) => result,
Err(Blocked) => {
conn.blocked_writers.insert(self.stream, cx.waker().clone());
return Poll::Pending;
let mut source = Some(source);
poll_fn(move |cx| {
let mut conn = self.conn.state.lock("SendStream::write_source");
if self.is_0rtt && conn.check_0rtt() == Err(()) {
return Poll::Ready(Err(WriteSourceError {
error: WriteError::ZeroRttRejected,
source: source.take().unwrap(),
}));
}
Err(Stopped(error_code)) => {
return Poll::Ready(Err(WriteError::Stopped(error_code)));
if let Some(e) = conn.error.clone() {
return Poll::Ready(Err(WriteSourceError {
error: WriteError::ConnectionLost(e),
source: source.take().unwrap(),
}));
}
Err(ClosedStream) => {
return Poll::Ready(Err(WriteError::ClosedStream));
}
};

conn.wake();
Poll::Ready(Ok(result))
match conn
.inner
.send_stream(self.stream)
.write_source(source.take().unwrap())
{
Ok(source_output) => {
conn.wake();
Poll::Ready(Ok(source_output))
}
Err((proto::WriteError::Blocked, returned_source)) => {
source = Some(returned_source);
conn.blocked_writers.insert(self.stream, cx.waker().clone());
Poll::Pending
}
Err((e, returned_source)) => Poll::Ready(Err(WriteSourceError {
error: match e {
proto::WriteError::Blocked => unreachable!(),
proto::WriteError::Stopped(code) => WriteError::Stopped(code),
proto::WriteError::ClosedStream => WriteError::ClosedStream,
},
source: returned_source,
})),
}
})
.await
}

/// Notify the peer that no more data will ever be written to this stream
Expand Down Expand Up @@ -241,14 +276,14 @@ impl SendStream {
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, WriteError>> {
self.get_mut().execute_poll(cx, |stream| stream.write(buf))
pin!(self.get_mut().write(buf)).as_mut().poll(cx)
}
}

#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
Self::execute_poll(self.get_mut(), cx, |stream| stream.write(buf)).map_err(Into::into)
self.poll_write(cx, buf).map_err(Into::into)
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Expand All @@ -266,7 +301,7 @@ impl tokio::io::AsyncWrite for SendStream {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Self::execute_poll(self.get_mut(), cx, |stream| stream.write(buf)).map_err(Into::into)
self.poll_write(cx, buf).map_err(Into::into)
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Expand Down Expand Up @@ -355,6 +390,12 @@ impl From<StoppedError> for WriteError {
}
}

impl<F> From<WriteSourceError<F>> for WriteError {
fn from(e: WriteSourceError<F>) -> Self {
e.error
}
}

impl From<WriteError> for io::Error {
fn from(x: WriteError) -> Self {
use WriteError::*;
Expand Down Expand Up @@ -392,3 +433,35 @@ impl From<StoppedError> for io::Error {
Self::new(kind, x)
}
}

/// Error type for [`SendStream::write_source`]
pub struct WriteSourceError<F> {
/// The underlying write error
pub error: WriteError,
/// The `source` parameter that was passed to `write_source`
pub source: F,
}

impl<F> fmt::Debug for WriteSourceError<F> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.error, f)
}
}

impl<F> fmt::Display for WriteSourceError<F> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.error, f)
}
}

impl<F> std::error::Error for WriteSourceError<F> {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&self.error)
}
}

impl<F> From<WriteSourceError<F>> for io::Error {
fn from(e: WriteSourceError<F>) -> Self {
e.error.into()
}
}
Loading