From ab6684815c07f781f6c3b57a05a37d4b6fbcb2ca Mon Sep 17 00:00:00 2001 From: Phoenix Kahlo Date: Tue, 13 May 2025 00:16:04 -0500 Subject: [PATCH] proto: Replace write_source with WriteGuard<'_> This commit removes the ByteSource trait and the write_source method. Instead, a new method is introduced as such: fn SendStream::write_guard(&mut self) -> Result, WriteError> The resultant WriteGuard contains the limit of how many bytes could be written now, as well as borrows of the fields necessary to perform those writes. The write and write_chunks methods are written to more simply use this. --- quinn-proto/src/connection/streams/mod.rs | 68 +++++-- quinn-proto/src/connection/streams/send.rs | 222 +-------------------- 2 files changed, 56 insertions(+), 234 deletions(-) diff --git a/quinn-proto/src/connection/streams/mod.rs b/quinn-proto/src/connection/streams/mod.rs index 53e42815ee..ced9fe24e8 100644 --- a/quinn-proto/src/connection/streams/mod.rs +++ b/quinn-proto/src/connection/streams/mod.rs @@ -19,9 +19,8 @@ use recv::Recv; pub use recv::{Chunks, ReadError, ReadableError}; mod send; -pub(crate) use send::{ByteSlice, BytesArray}; -use send::{BytesSource, Send, SendState}; pub use send::{FinishError, WriteError, Written}; +use send::{Send, SendState}; mod state; #[allow(unreachable_pub)] // fuzzing only @@ -221,7 +220,10 @@ impl<'a> SendStream<'a> { /// /// Returns the number of bytes successfully written. pub fn write(&mut self, data: &[u8]) -> Result { - Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes) + let mut guard = self.write_guard()?; + let written = data.len().min(guard.limit); + guard.write(data[..written].to_vec().into()); + Ok(written) } /// Send data on the given stream @@ -231,10 +233,25 @@ impl<'a> SendStream<'a> { /// [`Written::chunks`] will not count this chunk as fully written. However /// the chunk will be advanced and contain only non-written data after the call. pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result { - self.write_source(&mut BytesArray::from_chunks(data)) + let mut guard = self.write_guard()?; + let mut written = Written::default(); + for chunk in data { + let prefix = chunk.split_to(chunk.len().min(guard.limit)); + written.bytes += prefix.len(); + guard.write(prefix); + + if chunk.is_empty() { + written.chunks += 1; + } + + if guard.limit == 0 { + break; + } + } + Ok(written) } - fn write_source(&mut self, source: &mut B) -> Result { + fn write_guard(&mut self) -> Result { if self.conn_state.is_closed() { trace!(%self.id, "write blocked; connection draining"); return Err(WriteError::Blocked); @@ -263,15 +280,16 @@ impl<'a> SendStream<'a> { return Err(WriteError::Blocked); } - let was_pending = stream.is_pending(); - let written = stream.write(source, limit)?; - self.state.data_sent += written.bytes as u64; - self.state.unacked_data += written.bytes as u64; - trace!(stream = %self.id, "wrote {} bytes", written.bytes); - if !was_pending { - self.state.pending.push_pending(self.id, stream.priority); - } - Ok(written) + let limit = stream.write_limit(limit)?; + + Ok(WriteGuard { + limit, + stream, + id: self.id, + data_sent: &mut self.state.data_sent, + unacked_data: &mut self.state.unacked_data, + pending: &mut self.state.pending, + }) } /// Check if this stream was stopped, get the reason if it was @@ -367,6 +385,28 @@ impl<'a> SendStream<'a> { } } +struct WriteGuard<'a> { + limit: usize, + id: StreamId, + stream: &'a mut Send, + data_sent: &'a mut u64, + unacked_data: &'a mut u64, + pending: &'a mut PendingStreamsQueue, +} + +impl<'a> WriteGuard<'a> { + fn write(&mut self, bytes: Bytes) { + self.limit -= bytes.len(); + *self.data_sent += bytes.len() as u64; + *self.unacked_data += bytes.len() as u64; + let was_pending = self.stream.is_pending(); + self.stream.pending.write(bytes); + if !was_pending { + self.pending.push_pending(self.id, self.stream.priority); + } + } +} + /// A queue of streams with pending outgoing data, sorted by priority struct PendingStreamsQueue { streams: BinaryHeap, diff --git a/quinn-proto/src/connection/streams/send.rs b/quinn-proto/src/connection/streams/send.rs index 7b3db809a1..759c04c48d 100644 --- a/quinn-proto/src/connection/streams/send.rs +++ b/quinn-proto/src/connection/streams/send.rs @@ -1,4 +1,3 @@ -use bytes::Bytes; use thiserror::Error; use crate::{VarInt, connection::send_buffer::SendBuffer, frame}; @@ -49,11 +48,7 @@ impl Send { } } - pub(super) fn write( - &mut self, - source: &mut S, - limit: u64, - ) -> Result { + pub(super) fn write_limit(&self, limit: u64) -> Result { if !self.is_writable() { return Err(WriteError::ClosedStream); } @@ -64,23 +59,7 @@ impl Send { if budget == 0 { return Err(WriteError::Blocked); } - let mut limit = limit.min(budget) as usize; - - let mut result = Written::default(); - loop { - let (chunk, chunks_consumed) = source.pop_chunk(limit); - result.chunks += chunks_consumed; - result.bytes += chunk.len(); - - if chunk.is_empty() { - break; - } - - limit -= chunk.len(); - self.pending.write(chunk); - } - - Ok(result) + Ok(limit.min(budget) as usize) } /// Update stream state due to a reset sent by the local application @@ -143,106 +122,6 @@ impl Send { } } -/// A [`BytesSource`] implementation for `&'a mut [Bytes]` -/// -/// The type allows to dequeue [`Bytes`] chunks from an array of chunks, up to -/// a configured limit. -pub(crate) struct BytesArray<'a> { - /// The wrapped slice of `Bytes` - chunks: &'a mut [Bytes], - /// The amount of chunks consumed from this source - consumed: usize, -} - -impl<'a> BytesArray<'a> { - pub(crate) fn from_chunks(chunks: &'a mut [Bytes]) -> Self { - Self { - chunks, - consumed: 0, - } - } -} - -impl BytesSource for BytesArray<'_> { - fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) { - // The loop exists to skip empty chunks while still marking them as - // consumed - let mut chunks_consumed = 0; - - while self.consumed < self.chunks.len() { - let chunk = &mut self.chunks[self.consumed]; - - if chunk.len() <= limit { - let chunk = std::mem::take(chunk); - self.consumed += 1; - chunks_consumed += 1; - if chunk.is_empty() { - continue; - } - return (chunk, chunks_consumed); - } else if limit > 0 { - let chunk = chunk.split_to(limit); - return (chunk, chunks_consumed); - } else { - break; - } - } - - (Bytes::new(), chunks_consumed) - } -} - -/// A [`BytesSource`] implementation for `&[u8]` -/// -/// The type allows to dequeue a single [`Bytes`] chunk, which will be lazily -/// created from a reference. This allows to defer the allocation until it is -/// known how much data needs to be copied. -pub(crate) struct ByteSlice<'a> { - /// The wrapped byte slice - data: &'a [u8], -} - -impl<'a> ByteSlice<'a> { - pub(crate) fn from_slice(data: &'a [u8]) -> Self { - Self { data } - } -} - -impl BytesSource for ByteSlice<'_> { - fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) { - let limit = limit.min(self.data.len()); - if limit == 0 { - return (Bytes::new(), 0); - } - - let chunk = Bytes::from(self.data[..limit].to_owned()); - self.data = &self.data[chunk.len()..]; - - let chunks_consumed = usize::from(self.data.is_empty()); - (chunk, chunks_consumed) - } -} - -/// A source of one or more buffers which can be converted into `Bytes` buffers on demand -/// -/// The purpose of this data type is to defer conversion as long as possible, -/// so that no heap allocation is required in case no data is writable. -pub(super) trait BytesSource { - /// Returns the next chunk from the source of owned chunks. - /// - /// This method will consume parts of the source. - /// Calling it will yield `Bytes` elements up to the configured `limit`. - /// - /// The method returns a tuple: - /// - The first item is the yielded `Bytes` element. The element will be - /// empty if the limit is zero or no more data is available. - /// - The second item returns how many complete chunks inside the source had - /// had been consumed. This can be less than 1, if a chunk inside the - /// source had been truncated in order to adhere to the limit. It can also - /// be more than 1, if zero-length chunks had been skipped. - fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize); -} - /// Indicates how many bytes and chunks had been transferred in a write operation #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct Written { @@ -303,100 +182,3 @@ pub enum FinishError { #[error("closed stream")] ClosedStream, } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn bytes_array() { - let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned(); - for limit in 0..full.len() { - let mut chunks = [ - Bytes::from_static(b""), - Bytes::from_static(b"Hello "), - Bytes::from_static(b"Wo"), - Bytes::from_static(b""), - Bytes::from_static(b"r"), - Bytes::from_static(b"ld"), - Bytes::from_static(b""), - Bytes::from_static(b" 12345678"), - Bytes::from_static(b"9 ABCDE"), - Bytes::from_static(b"F"), - Bytes::from_static(b"GHJIJKLMNOPQRSTUVWXYZ"), - ]; - let num_chunks = chunks.len(); - let last_chunk_len = chunks[chunks.len() - 1].len(); - - let mut array = BytesArray::from_chunks(&mut chunks); - - let mut buf = Vec::new(); - let mut chunks_popped = 0; - let mut chunks_consumed = 0; - let mut remaining = limit; - loop { - let (chunk, consumed) = array.pop_chunk(remaining); - chunks_consumed += consumed; - - if !chunk.is_empty() { - buf.extend_from_slice(&chunk); - remaining -= chunk.len(); - chunks_popped += 1; - } else { - break; - } - } - - assert_eq!(&buf[..], &full[..limit]); - - if limit == full.len() { - // Full consumption of the last chunk - assert_eq!(chunks_consumed, num_chunks); - // Since there are empty chunks, we consume more than there are popped - assert_eq!(chunks_consumed, chunks_popped + 3); - } else if limit > full.len() - last_chunk_len { - // Partial consumption of the last chunk - assert_eq!(chunks_consumed, num_chunks - 1); - assert_eq!(chunks_consumed, chunks_popped + 2); - } - } - } - - #[test] - fn byte_slice() { - let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned(); - for limit in 0..full.len() { - let mut array = ByteSlice::from_slice(&full[..]); - - let mut buf = Vec::new(); - let mut chunks_popped = 0; - let mut chunks_consumed = 0; - let mut remaining = limit; - loop { - let (chunk, consumed) = array.pop_chunk(remaining); - chunks_consumed += consumed; - - if !chunk.is_empty() { - buf.extend_from_slice(&chunk); - remaining -= chunk.len(); - chunks_popped += 1; - } else { - break; - } - } - - assert_eq!(&buf[..], &full[..limit]); - if limit != 0 { - assert_eq!(chunks_popped, 1); - } else { - assert_eq!(chunks_popped, 0); - } - - if limit == full.len() { - assert_eq!(chunks_consumed, 1); - } else { - assert_eq!(chunks_consumed, 0); - } - } - } -}