diff --git a/mux/SPEC.md b/mux/SPEC.md index 7ae41e0..9833e7d 100644 --- a/mux/SPEC.md +++ b/mux/SPEC.md @@ -8,7 +8,7 @@ MUX is a symmetric stream multiplexing protocol designed to run over reliable, o - Lightweight framing with minimal overhead (14-byte headers) - Per-stream flow control with automatic window tuning -- Graceful and abrupt stream termination +- Fast stream open/close with ID reuse - Connection-level resource limits - Symmetric operation (no client/server distinction) - Deterministic stream IDs derived from user-defined identifiers @@ -98,15 +98,14 @@ Flags modify frame behavior. Multiple flags MAY be set simultaneously by combini | Flag | Value | Applicable Types | Description | |------|-------|------------------|-------------| -| FIN | `0x01` | Data, Window Update | Half-closes the stream in the sender's direction. | -| RST | `0x02` | Data, Window Update | Immediately resets (terminates) the stream. | +| FIN | `0x01` | Data | Signals end-of-stream marker (see Section 5.4). | | SYN | `0x04` | Ping | Ping request. | | ACK | `0x08` | Ping | Ping response. | ### 4.1 Flag Constraints -- FIN and RST MUST NOT be set together. If both are set, RST takes precedence. - SYN and ACK are only valid on Ping frames. +- FIN is only valid on Data frames. ## 5. Stream Management @@ -141,46 +140,45 @@ Streams are created implicitly when the first frame for a Stream ID is sent or r ### 5.3 Stream Lifecycle -``` - +-------+ - | Open | - +-------+ - / | \ - FIN← / | \ →FIN - / | \ - +----------+ | +----------+ - |RecvClosed| | |SendClosed| - +----------+ | +----------+ - \ | / - FIN→ \ | / ←FIN - \ | / - +-------+ - |Closed | - +-------+ - ↑ - RST (any state) -``` +Streams have a simplified lifecycle optimized for fast open/close with ID reuse: + +**Local stream states:** +- **Active**: Application holds a handle to the stream. Can read/write. +- **Buffering**: No handle, but buffered data exists. Data continues to be received. 
+- **Removed**: No handle and no buffered data. Stream ID available for reuse. + +**Opening a stream:** +1. Compute Stream ID from user identifier. +2. If stream exists in Buffering state, transition to Active (access buffered data). +3. If stream does not exist, create new stream in Active state. -| State | Description | -|-------|-------------| -| Open | Bidirectional data flow. | -| SendClosed | Local side sent FIN. Can still receive data. | -| RecvClosed | Remote side sent FIN. Can still send data. | -| Closed | Both directions closed. Stream resources may be released. | +**Closing a stream:** +1. Application drops the stream handle. +2. If buffer is empty, remove stream (ID available for reuse). +3. If buffer is non-empty, transition to Buffering state. -### 5.4 Half-Close (FIN) +**Key properties:** +- Stream IDs can be reused after the stream is removed. +- Dropping a handle does NOT send any protocol message. +- Remote peer is not notified when local handle is dropped. +- Flow control remains active for Buffering streams. -Sending FIN indicates the sender will transmit no more data on this stream. The stream transitions: -- `Open` → `SendClosed` -- `RecvClosed` → `Closed` +### 5.4 End-of-Stream Marker (FIN) -Receiving FIN transitions: -- `Open` → `RecvClosed` -- `SendClosed` → `Closed` +FIN is an in-band marker signaling the end of a logical message or stream segment. Unlike traditional half-close: -### 5.5 Reset (RST) +- FIN is buffered in-order with data frames. +- Reading FIN returns EOF (0 bytes). +- **Data MAY be sent after FIN.** FIN does not prevent further transmission. +- Applications use FIN for framing; the protocol does not enforce termination. + +**Example usage:** +``` +Sender: Data("hello") → FIN → Data("world") → FIN +Reader: reads "hello" → reads EOF → reads "world" → reads EOF +``` -RST immediately terminates a stream from any state. Both sides SHOULD release stream resources upon sending or receiving RST. 
No further frames SHOULD be sent on a reset stream. +This allows FIN to delimit messages within a long-lived stream. ## 6. Flow Control @@ -195,6 +193,7 @@ Each stream maintains an independent receive window representing the number of b **Behavior:** - Senders MUST NOT send more data than the receiver's advertised window. - Each byte of Data payload consumes one byte of window. +- Each FIN marker consumes 32 bytes of window (to prevent FIN spam attacks). - Window Update frames replenish the window. ### 6.2 Window Updates @@ -267,7 +266,7 @@ Upon detecting a protocol violation, implementations MUST: ### 8.2 Stream Errors vs Connection Errors -- **Stream errors** (e.g., application-level errors) SHOULD be handled with RST on that stream. +- **Stream errors** (e.g., application-level errors) are handled by the application dropping the stream handle. The remote peer is not explicitly notified. - **Connection errors** (e.g., protocol violations) MUST be handled with GoAway and connection closure. ## 9. Constants Summary diff --git a/mux/mux/src/chunks.rs b/mux/mux/src/chunks.rs index aa299ce..c761227 100644 --- a/mux/mux/src/chunks.rs +++ b/mux/mux/src/chunks.rs @@ -12,14 +12,21 @@ use std::{collections::VecDeque, io}; -/// A sequence of [`Chunk`] values. +/// An element in the buffer - either data or a FIN marker. +#[derive(Debug)] +pub(crate) enum ChunkOrFin { + Chunk(Chunk), + Fin, +} + +/// A sequence of [`ChunkOrFin`] values. /// /// [`Chunks::len`] considers all [`Chunk`] elements and computes the total /// result, i.e. the length of all bytes, by summing up the lengths of all -/// [`Chunk`] elements. +/// [`Chunk`] elements. FIN markers don't contribute to length. #[derive(Debug)] pub(crate) struct Chunks { - seq: VecDeque, + seq: VecDeque, len: usize, } @@ -34,29 +41,61 @@ impl Chunks { /// The total length of bytes yet-to-be-read in all `Chunk`s. 
pub(crate) fn len(&self) -> usize { - self.len - self.seq.front().map(|c| c.offset()).unwrap_or(0) + let front_offset = self + .seq + .front() + .and_then(|e| match e { + ChunkOrFin::Chunk(c) => Some(c.offset()), + ChunkOrFin::Fin => None, + }) + .unwrap_or(0); + self.len - front_offset + } + + /// Returns true if there is no data in the buffer. + /// + /// Note: A buffer with only FIN markers is considered empty. + pub(crate) fn is_empty(&self) -> bool { + self.len() == 0 } /// Add another chunk of bytes to the end. pub(crate) fn push(&mut self, x: Vec<u8>) { self.len += x.len(); if !x.is_empty() { - self.seq.push_back(Chunk { + self.seq.push_back(ChunkOrFin::Chunk(Chunk { cursor: io::Cursor::new(x), - }) + })) + } + } + + /// Add a FIN marker to the end. + pub(crate) fn push_fin(&mut self) { + self.seq.push_back(ChunkOrFin::Fin); + } + + /// Remove and return the first element. + pub(crate) fn pop(&mut self) -> Option<ChunkOrFin> { + let elem = self.seq.pop_front(); + if let Some(ChunkOrFin::Chunk(ref c)) = elem { + self.len -= c.len() + c.offset(); } + elem } - /// Remove and return the first chunk. - pub(crate) fn pop(&mut self) -> Option<Chunk> { - let chunk = self.seq.pop_front(); - self.len -= chunk.as_ref().map(|c| c.len() + c.offset()).unwrap_or(0); - chunk + /// Get a reference to the first element. + pub(crate) fn front(&self) -> Option<&ChunkOrFin> { + self.seq.front() } - /// Get a mutable reference to the first chunk. - pub(crate) fn front_mut(&mut self) -> Option<&mut Chunk> { - self.seq.front_mut() + /// Get a mutable reference to the first chunk, if it is a chunk. + /// + /// Returns None if buffer is empty or front is a FIN marker. 
+ pub(crate) fn front_chunk_mut(&mut self) -> Option<&mut Chunk> { + match self.seq.front_mut() { + Some(ChunkOrFin::Chunk(c)) => Some(c), + _ => None, + } } } diff --git a/mux/mux/src/connection.rs b/mux/mux/src/connection.rs index 9962c5b..718d79b 100644 --- a/mux/mux/src/connection.rs +++ b/mux/mux/src/connection.rs @@ -84,7 +84,7 @@ impl Connection { matches!(self.inner, ConnectionState::Closed(_)) } - /// Get a handle for creating streams concurrently. + /// Get a handle for obtaining streams concurrently. /// /// The handle can be cloned and used from multiple tasks while the /// Connection is being polled. @@ -95,17 +95,21 @@ impl Connection { } } - /// Create a new stream with the given user ID. + /// Get a stream handle for the given user ID. + /// + /// Streams are implicit - if data has already been received for this ID, + /// returns a handle to the existing stream with buffered data. Otherwise, + /// creates a new stream. /// /// The stream ID is computed from the user ID using BLAKE3. - /// Either side can create streams with the same user ID - they will - /// automatically merge into the same stream. + /// Either side can get streams with the same user ID - they will + /// automatically refer to the same stream. /// - /// The `user_id` parameter is a required user-defined stream identifier - /// (1-256 bytes). User IDs must be unique within the session. - pub fn new_stream(&mut self, user_id: &[u8]) -> Result<Stream, ConnectionError> { + /// The `user_id` parameter is a user-defined stream identifier (1-256 + /// bytes). 
+ pub fn get_stream(&mut self, user_id: &[u8]) -> Result<Stream, ConnectionError> { match &mut self.inner { - ConnectionState::Active(active) => active.new_stream(user_id), + ConnectionState::Active(active) => active.get_stream(user_id), _ => Err(ConnectionError::Closed), } } } diff --git a/mux/mux/src/connection/active.rs b/mux/mux/src/connection/active.rs index 714207f..9cc9b5f 100644 --- a/mux/mux/src/connection/active.rs +++ b/mux/mux/src/connection/active.rs @@ -28,16 +28,22 @@ use super::{ cleanup::Cleanup, closing::Closing, rtt, - stream::{self, State, Stream}, + stream::{self, Stream}, }; +/// Entry in the stream registry. +struct StreamEntry { + shared: Arc<Mutex<stream::Shared>>, + has_handle: bool, +} + /// Shared state for stream management. /// /// This struct holds state that can be accessed by both the Connection's /// poll loop and Handle for concurrent stream creation. pub(crate) struct StreamRegistry { id: Id, - streams: IntMap<StreamId, Arc<Mutex<stream::Shared>>>, + streams: IntMap<StreamId, StreamEntry>, new_receiver_tx: mpsc::UnboundedSender>>, waker: Option<Waker>, config: Arc<Config>, @@ -66,7 +72,7 @@ impl StreamRegistry { } } - fn new_stream(&mut self, user_id: &[u8]) -> Result<Stream, ConnectionError> { + fn get_stream(&mut self, user_id: &[u8]) -> Result<Stream, ConnectionError> { let user_id = UserId::new(user_id)?; let stream_id = StreamId::new(user_id.as_bytes()); @@ -75,15 +81,23 @@ impl StreamRegistry { return Err(ConnectionError::TooManyStreams); } - // Check if stream already exists (created implicitly by remote) - if let Some(existing) = self.streams.get(&stream_id) { + // Check if stream already exists (created implicitly by remote or reopening) + if let Some(entry) = self.streams.get_mut(&stream_id) { log::trace!("{}/{}: merging with existing stream", self.id, stream_id); - let stream = self.make_stream_with_shared(stream_id, user_id, existing.clone()); + entry.has_handle = true; + let shared = entry.shared.clone(); + let stream = self.make_stream_with_shared(stream_id, user_id, shared); return Ok(stream); } let stream = self.make_stream(stream_id, user_id); - self.streams.insert(stream_id, 
stream.clone_shared()); + self.streams.insert( + stream_id, + StreamEntry { + shared: stream.clone_shared(), + has_handle: true, + }, + ); log::debug!("{}: new stream {}", self.id, stream); @@ -133,7 +147,6 @@ impl StreamRegistry { fn make_implicit_stream_shared(&mut self) -> Arc> { Arc::new(Mutex::new(stream::Shared::new( - State::Open, crate::DEFAULT_CREDIT, crate::DEFAULT_CREDIT, self.accumulated_max_stream_windows.clone(), @@ -143,7 +156,7 @@ impl StreamRegistry { } } -/// A handle for creating streams concurrently. +/// A handle for obtaining streams concurrently. /// /// This type can be cloned and used from multiple tasks while the /// Connection is being polled. @@ -153,11 +166,11 @@ pub struct Handle { } impl Handle { - /// Create a new stream with the given user ID. + /// Get a stream handle for the given user ID. /// /// The stream ID is computed from the user ID using BLAKE3. - pub fn new_stream(&self, user_id: &[u8]) -> Result { - self.registry.lock().new_stream(user_id) + pub fn get_stream(&self, user_id: &[u8]) -> Result { + self.registry.lock().get_stream(user_id) } } @@ -166,8 +179,6 @@ impl Handle { pub(crate) enum StreamCommand { /// A new frame should be sent to the remote. SendFrame(Frame<()>), - /// Close a stream. - CloseStream { stream_id: StreamId }, } /// Possible actions as a result of incoming frame handling. @@ -246,7 +257,7 @@ impl Active { } } - /// Get a handle for creating streams concurrently. + /// Get a handle for obtaining streams concurrently. 
pub(super) fn handle(&self) -> Handle { Handle { registry: self.registry.clone(), @@ -340,17 +351,9 @@ impl Active { self.pending_write_frame.replace(frame); continue; } - Poll::Ready(Some((_, Some(StreamCommand::CloseStream { stream_id })))) => { - log::trace!("{}/{}: sending close", self.id, stream_id); - self.pending_write_frame - .replace(Frame::close_stream(stream_id).into()); - continue; - } Poll::Ready(Some((id, None))) => { - if let Some(frame) = self.on_drop_stream(id) { - log::trace!("{}/{}: sending: {}", self.id, id, frame.header()); - self.pending_write_frame.replace(frame); - }; + // Handle dropped - transition to buffering or remove + self.on_drop_stream(id); continue; } Poll::Ready(None) => { @@ -387,11 +390,11 @@ impl Active { } } - /// Create a new stream. + /// Get a stream handle for the given user ID. /// /// The stream ID is computed from the user ID using BLAKE3. - pub(super) fn new_stream(&mut self, user_id: &[u8]) -> Result { - let stream = self.registry.lock().new_stream(user_id)?; + pub(super) fn get_stream(&mut self, user_id: &[u8]) -> Result { + let stream = self.registry.lock().get_stream(user_id)?; // Drain new receivers immediately so they're available before poll while let Ok(Some(receiver)) = self.new_receiver_rx.try_next() { self.stream_receivers.push(receiver); @@ -400,37 +403,40 @@ impl Active { } fn on_drop_stream(&mut self, stream_id: StreamId) -> Option> { - let Some(s) = self.registry.lock().streams.remove(&stream_id) else { + let mut registry = self.registry.lock(); + let Some(entry) = registry.streams.get_mut(&stream_id) else { log::warn!("{}: stream {} not found on drop", self.id, stream_id); return None; }; - log::trace!("{}: removing dropped stream {}", self.id, stream_id); - let frame = { - let mut shared = s.lock(); - let frame = match shared.update_state(self.id, stream_id, State::Closed) { - State::Open => { - let mut header = Header::data(stream_id, 0); - header.rst(); - Some(Frame::new(header)) - } - 
State::RecvClosed => { - let mut header = Header::data(stream_id, 0); - header.fin(); - Some(Frame::new(header)) - } - State::SendClosed => None, - State::Closed => None, - }; - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() + // Mark as no longer having a handle + entry.has_handle = false; + + // Check if buffer is empty - if so, remove from registry + let shared = entry.shared.lock(); + if shared.buffer.is_empty() { + drop(shared); + log::trace!( + "{}: removing dropped stream {} (buffer empty)", + self.id, + stream_id + ); + registry.streams.remove(&stream_id); + } else { + log::trace!( + "{}: stream {} transitioned to buffering", + self.id, + stream_id + ); + // Wake any waiting readers/writers to let them know handle is gone + if let Some(w) = shared.reader.clone() { + drop(shared); + w.wake(); } - frame - }; - frame.map(Into::into) + } + + // No protocol message sent - closing is just dropping the handle + None } fn on_frame(&mut self, frame: Frame<()>) -> Result { @@ -449,20 +455,6 @@ impl Active { let stream_id = frame.header().stream_id(); let mut registry = self.registry.lock(); - if frame.header().flags().contains(header::RST) { - if let Some(s) = registry.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - shared.update_state(self.id, stream_id, State::Closed); - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() - } - } - return Action::None; - } - let is_finish = frame.header().flags().contains(header::FIN); // SYN flag on Data frames is not allowed @@ -490,24 +482,37 @@ impl Active { stream_id ); let shared = registry.make_implicit_stream_shared(); - registry.streams.insert(stream_id, shared); + registry.streams.insert( + stream_id, + StreamEntry { + shared, + has_handle: false, + }, + ); } - if let Some(s) = registry.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - if frame.body_len() > shared.receive_window() { 
+ if let Some(entry) = registry.streams.get_mut(&stream_id) { + let mut shared = entry.shared.lock(); + // FIN markers consume 32 bytes of window (size of ChunkOrFin) to prevent FIN + // spam DoS + let fin_cost = if is_finish { 32 } else { 0 }; + let total_cost = frame.body_len() + fin_cost; + if total_cost > shared.receive_window() { log::error!( - "{}/{}: frame body larger than window of stream", + "{}/{}: frame cost {} exceeds window {}", self.id, - stream_id + stream_id, + total_cost, + shared.receive_window() ); return Action::Terminate(Frame::protocol_error()); } + shared.consume_receive_window(total_cost); + shared.buffer.push(frame.into_body()); + // Push FIN marker to buffer (in-order with data) if is_finish { - shared.update_state(self.id, stream_id, State::RecvClosed); + shared.buffer.push_fin(); } - shared.consume_receive_window(frame.body_len()); - shared.buffer.push(frame.into_body()); if let Some(w) = shared.reader.take() { w.wake() } @@ -520,22 +525,6 @@ impl Active { let stream_id = frame.header().stream_id(); let mut registry = self.registry.lock(); - if frame.header().flags().contains(header::RST) { - if let Some(s) = registry.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - shared.update_state(self.id, stream_id, State::Closed); - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() - } - } - return Action::None; - } - - let is_finish = frame.header().flags().contains(header::FIN); - // SYN flag on WindowUpdate frames is not allowed if frame.header().flags().contains(header::SYN) { log::error!("{}: SYN flag on WindowUpdate frame is not allowed", self.id); @@ -559,18 +548,18 @@ impl Active { stream_id ); let shared = registry.make_implicit_stream_shared(); - registry.streams.insert(stream_id, shared); + registry.streams.insert( + stream_id, + StreamEntry { + shared, + has_handle: false, + }, + ); } - if let Some(s) = registry.streams.get_mut(&stream_id) { - let mut shared = 
s.lock(); + if let Some(entry) = registry.streams.get_mut(&stream_id) { + let mut shared = entry.shared.lock(); shared.increase_send_window_by(frame.header().credit()); - if is_finish { - shared.update_state(self.id, stream_id, State::RecvClosed); - if let Some(w) = shared.reader.take() { - w.wake() - } - } if let Some(w) = shared.writer.take() { w.wake() } @@ -603,9 +592,8 @@ impl Active { /// Close and drop all `Stream`s and wake any pending `Waker`s. pub(super) fn drop_all_streams(&mut self) { let mut registry = self.registry.lock(); - for (id, s) in registry.streams.drain() { - let mut shared = s.lock(); - shared.update_state(self.id, id, State::Closed); + for (_id, entry) in registry.streams.drain() { + let mut shared = entry.shared.lock(); if let Some(w) = shared.reader.take() { w.wake() } diff --git a/mux/mux/src/connection/closing.rs b/mux/mux/src/connection/closing.rs index 1edbe28..75fc96e 100644 --- a/mux/mux/src/connection/closing.rs +++ b/mux/mux/src/connection/closing.rs @@ -77,10 +77,6 @@ where Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => { this.pending_frames.push_back(frame); } - Poll::Ready(Some((_, Some(StreamCommand::CloseStream { stream_id })))) => { - this.pending_frames - .push_back(Frame::close_stream(stream_id).into()); - } Poll::Ready(Some((_, None))) => {} Poll::Pending | Poll::Ready(None) => { // No more frames from streams, append `Term` frame and flush them all. 
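The FIN-as-marker semantics introduced above (SPEC §5.4 together with the `ChunkOrFin` buffer in chunks.rs) can be sketched as a standalone model. The `ChunkOrFin` name and `push_fin` follow the diff, but the `Buffer`/`read` wrapper here is hypothetical: it ignores flow control, wakers, and partial reads, and models one EOF per FIN by returning an empty `Vec`:

```rust
use std::collections::VecDeque;

// Simplified stand-in for the crate's buffer element.
enum ChunkOrFin {
    Chunk(Vec<u8>),
    Fin,
}

struct Buffer {
    seq: VecDeque<ChunkOrFin>,
}

impl Buffer {
    fn new() -> Self {
        Buffer { seq: VecDeque::new() }
    }

    fn push(&mut self, data: Vec<u8>) {
        if !data.is_empty() {
            self.seq.push_back(ChunkOrFin::Chunk(data));
        }
    }

    // FIN is buffered in-order with data, so data pushed afterwards
    // is still delivered (SPEC: "Data MAY be sent after FIN").
    fn push_fin(&mut self) {
        self.seq.push_back(ChunkOrFin::Fin);
    }

    /// One read call: drains data up to the next FIN marker.
    /// An empty Vec models a 0-byte read, i.e. EOF.
    fn read(&mut self) -> Vec<u8> {
        let mut out = Vec::new();
        while let Some(front) = self.seq.pop_front() {
            match front {
                ChunkOrFin::Chunk(mut c) => out.append(&mut c),
                ChunkOrFin::Fin => {
                    if out.is_empty() {
                        return out; // FIN consumed: this read is EOF
                    }
                    // Data was read first; leave FIN for the next call.
                    self.seq.push_front(ChunkOrFin::Fin);
                    return out;
                }
            }
        }
        out
    }
}
```

This reproduces the spec's example sequence: `Data("hello") → FIN → Data("world") → FIN` reads back as "hello", EOF, "world", EOF, which is exactly why FIN can delimit messages inside a long-lived stream.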
@@ -211,7 +207,6 @@ mod tests { let frame_data = Frame::data(StreamId::new(b"stream3"), vec![4]) .unwrap() .into(); - let frame_close = Frame::close_stream(StreamId::new(b"stream5")).into(); let frame_term = Frame::term().into(); fn encode(buf: &mut Vec, frame: &Frame<()>) { buf.extend_from_slice(&frame::header::encode(frame.header())); @@ -222,7 +217,6 @@ mod tests { let mut expected_written = vec![]; encode(&mut expected_written, &frame_pending); encode(&mut expected_written, &frame_data); - encode(&mut expected_written, &frame_close); encode(&mut expected_written, &frame_term); let receiver = |frame: &Frame<_>, command: StreamCommand| { @@ -238,12 +232,6 @@ mod tests { &frame_data, StreamCommand::SendFrame(frame_data.clone()), )); - stream_receivers.push(receiver( - &frame_close, - StreamCommand::CloseStream { - stream_id: StreamId::new(b"stream5"), - }, - )); let pending_frames = vec![frame_pending]; let mut socket = Socket { written: vec![], diff --git a/mux/mux/src/connection/stream.rs b/mux/mux/src/connection/stream.rs index ca898ce..94d6d5e 100644 --- a/mux/mux/src/connection/stream.rs +++ b/mux/mux/src/connection/stream.rs @@ -12,7 +12,7 @@ use crate::{ Config, DEFAULT_CREDIT, - chunks::Chunks, + chunks::{ChunkOrFin, Chunks}, connection::{self, StreamCommand, UserId, rtt, rtt::Rtt}, frame::{Frame, header::StreamId}, }; @@ -33,34 +33,9 @@ use std::{ mod flow_control; -/// The state of a stream. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum State { - /// Open bidirectionally. - Open, - /// Open for incoming messages (local sent FIN). - SendClosed, - /// Open for outgoing messages (remote sent FIN). - RecvClosed, - /// Closed (terminal state). - Closed, -} - -impl State { - /// Can we receive messages over this stream? - pub fn can_read(self) -> bool { - matches!(self, State::Open | State::SendClosed) - } - - /// Can we send messages over this stream? 
- pub fn can_write(self) -> bool { - matches!(self, State::Open | State::RecvClosed) - } -} - /// A multiplexed stream. /// -/// Streams are created via [`crate::Connection::new_stream`]. +/// Stream handles are obtained via [`crate::Connection::get_stream`]. /// /// `Stream` implements [`AsyncRead`] and [`AsyncWrite`]. pub struct Stream { @@ -106,7 +81,6 @@ impl Stream { config: config.clone(), sender, shared: Arc::new(Mutex::new(Shared::new( - State::Open, DEFAULT_CREDIT, DEFAULT_CREDIT, accumulated_max_stream_windows, @@ -146,14 +120,6 @@ impl Stream { self.stream_id } - pub fn is_write_closed(&self) -> bool { - matches!(self.shared().state(), State::SendClosed) - } - - pub fn is_closed(&self) -> bool { - matches!(self.shared().state(), State::Closed) - } - pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> { self.shared.lock() } @@ -170,10 +136,6 @@ impl Stream { /// Send new credit to the sending side via a window update message if /// permitted. fn send_window_update(&mut self, cx: &mut Context) -> Poll> { - if !self.shared.lock().state.can_read() { - return Poll::Ready(Ok(())); - } - ready!( self.sender .poll_ready(cx) @@ -212,11 +174,22 @@ impl AsyncRead for Stream { let mut shared = self.shared(); + // Check for FIN marker at front + if matches!(shared.buffer.front(), Some(ChunkOrFin::Fin)) { + shared.buffer.pop(); + log::debug!("{}: eof (FIN marker)", self); + return Poll::Ready(Ok(0)); + } + // Copy data from stream buffer let mut n = 0; - while let Some(chunk) = shared.buffer.front_mut() { + while let Some(chunk) = shared.buffer.front_chunk_mut() { if chunk.is_empty() { shared.buffer.pop(); + // Check if next element is FIN + if matches!(shared.buffer.front(), Some(ChunkOrFin::Fin)) { + break; + } continue; } let k = std::cmp::min(chunk.len(), buf.len() - n); @@ -233,14 +206,15 @@ impl AsyncRead for Stream { return Poll::Ready(Ok(n)); } - // Buffer is empty, check if sender is closed - if !self.config.read_after_close && self.sender.is_closed() { + 
// Check for FIN marker again (may have been exposed after popping empty chunks) + if matches!(shared.buffer.front(), Some(ChunkOrFin::Fin)) { + shared.buffer.pop(); + log::debug!("{}: eof (FIN marker)", self); return Poll::Ready(Ok(0)); } - // Buffer is empty, check if we can expect to read more data - if !shared.state().can_read() { - log::debug!("{}: eof", self); + // Buffer is empty, check if sender is closed + if !self.config.read_after_close && self.sender.is_closed() { return Poll::Ready(Ok(0)); } @@ -265,10 +239,6 @@ impl AsyncWrite for Stream { let stream_id = self.stream_id; let body = { let mut shared = self.shared(); - if !shared.state().can_write() { - log::debug!("{}: can no longer write", self); - return Poll::Ready(Err(self.write_zero_err())); - } if shared.send_window() == 0 { log::trace!("{}: no more credit left", self); shared.writer = Some(cx.waker().clone()); @@ -303,32 +273,26 @@ impl AsyncWrite for Stream { } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - if self.is_closed() { - return Poll::Ready(Ok(())); - } - ready!( self.sender .poll_ready(cx) .map_err(|_| self.write_zero_err())? 
); - log::trace!("{}: close", self); - let cmd = StreamCommand::CloseStream { - stream_id: self.stream_id, - }; + // Send a FIN frame to signal half-close to remote + log::trace!("{}: sending FIN", self); + let frame = Frame::close_stream(self.stream_id); + let cmd = StreamCommand::SendFrame(frame.into()); self.sender .start_send(cmd) .map_err(|_| self.write_zero_err())?; - self.shared() - .update_state(self.conn, self.stream_id, State::SendClosed); + Poll::Ready(Ok(())) } } #[derive(Debug)] pub(crate) struct Shared { - pub(super) state: State, flow_controller: FlowController, pub(crate) buffer: Chunks, pub(crate) reader: Option, @@ -337,7 +301,6 @@ pub(crate) struct Shared { impl Shared { pub(crate) fn new( - initial_state: State, receive_window: u32, send_window: u32, accumulated_max_stream_windows: Arc>, @@ -345,7 +308,6 @@ impl Shared { config: Arc, ) -> Self { Shared { - state: initial_state, flow_controller: FlowController::new( receive_window, send_window, @@ -359,46 +321,6 @@ impl Shared { } } - pub(crate) fn state(&self) -> State { - self.state - } - - /// Update the stream state and return the state before it was updated. 
- pub(crate) fn update_state( - &mut self, - cid: connection::Id, - sid: StreamId, - next: State, - ) -> State { - use self::State::*; - - let current = self.state; - - match (current, next) { - (Closed, _) => {} - (Open, _) => self.state = next, - (RecvClosed, Closed) => self.state = Closed, - (RecvClosed, Open) => {} - (RecvClosed, RecvClosed) => {} - (RecvClosed, SendClosed) => self.state = Closed, - (SendClosed, Closed) => self.state = Closed, - (SendClosed, Open) => {} - (SendClosed, RecvClosed) => self.state = Closed, - (SendClosed, SendClosed) => {} - } - - log::trace!( - "{}/{}: update state: (from {:?} to {:?} -> {:?})", - cid, - sid, - current, - next, - self.state - ); - - current - } - pub(crate) fn next_window_update(&mut self) -> Option { self.flow_controller.next_window_update(self.buffer.len()) } diff --git a/mux/mux/src/frame/header.rs b/mux/mux/src/frame/header.rs index 0e5b9df..c60006b 100644 --- a/mux/mux/src/frame/header.rs +++ b/mux/mux/src/frame/header.rs @@ -114,13 +114,6 @@ impl Header { } } -impl Header { - /// Set the [`RST`] flag. - pub fn rst(&mut self) { - self.flags.0 |= RST.0 - } -} - impl Header { /// Create a new data frame header. pub fn data(id: StreamId, len: u32) -> Self { @@ -223,11 +216,6 @@ pub trait HasFin: private::Sealed {} impl HasFin for Data {} impl HasFin for WindowUpdate {} -/// Types which have a `rst` method. -pub trait HasRst: private::Sealed {} -impl HasRst for Data {} -impl HasRst for WindowUpdate {} - pub(super) mod private { pub trait Sealed {} @@ -320,9 +308,6 @@ impl Flags { /// Indicates the half-closing of a stream. pub const FIN: Flags = Flags(0x01); -/// Indicates an immediate stream reset. -pub const RST: Flags = Flags(0x02); - /// Indicates a ping request. 
pub const SYN: Flags = Flags(0x04); diff --git a/mux/test-harness/src/lib.rs b/mux/test-harness/src/lib.rs index 5191637..f71a1d8 100644 --- a/mux/test-harness/src/lib.rs +++ b/mux/test-harness/src/lib.rs @@ -175,7 +175,7 @@ where let mut streams = Vec::with_capacity(nstreams); for i in 0..nstreams { let id = format!("stream-{i}"); - streams.push(conn.new_stream(id.as_bytes()).unwrap()); + streams.push(conn.get_stream(id.as_bytes()).unwrap()); } // Spawn connection poll loop @@ -249,7 +249,7 @@ where let mut streams = Vec::with_capacity(stream_count); for i in 0..stream_count { let id = format!("stream-{i}"); - streams.push(self.conn.new_stream(id.as_bytes())?); + streams.push(self.conn.get_stream(id.as_bytes())?); } // Spawn connection poll loop diff --git a/mux/test-harness/tests/poll_api.rs b/mux/test-harness/tests/poll_api.rs index f0917a6..b97d477 100644 --- a/mux/test-harness/tests/poll_api.rs +++ b/mux/test-harness/tests/poll_api.rs @@ -10,19 +10,12 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. 
-use futures::{
-    executor::LocalPool,
-    future,
-    future::join,
-    prelude::*,
-    task::{Spawn, SpawnExt},
-    AsyncReadExt, AsyncWriteExt, FutureExt,
-};
+use futures::{future, future::join, prelude::*, AsyncReadExt, AsyncWriteExt, FutureExt};
 use quickcheck::QuickCheck;
-use std::{panic::panic_any, pin::pin};
+use std::{panic::panic_any, pin::pin, time::Duration};
 use test_harness::*;
 use tlsn_mux::{Config, Connection, ConnectionError};
-use tokio::{net::TcpStream, task};
+use tokio::{net::TcpStream, task, time::timeout};
 use tokio_util::compat::TokioAsyncReadCompatExt;
 
 #[test]
@@ -41,30 +34,29 @@ fn prop_config_send_recv_multi() {
         let socket = listener.accept().await.expect("accept").0.compat();
         let mut connection = Connection::new(socket, cfg1);
 
-        // Pre-register streams
        let mut streams = Vec::new();
        for i in 0..num_messages {
            let id = format!("stream-{}", i);
-            streams.push(connection.new_stream(id.as_bytes()).unwrap());
+            streams.push(connection.get_stream(id.as_bytes()).unwrap());
        }
 
-        // Spawn connection poll loop
        task::spawn(async move {
            future::poll_fn(|cx| connection.poll(cx)).await.ok();
        });
 
-        // Echo each stream
-        let mut tasks = Vec::new();
-        for mut stream in streams {
-            tasks.push(task::spawn(async move {
-                {
-                    let (mut r, mut w) = AsyncReadExt::split(&mut stream);
-                    futures::io::copy(&mut r, &mut w).await?;
-                }
-                stream.close().await?;
-                Ok::<_, ConnectionError>(())
-            }));
-        }
+        let tasks: Vec<_> = streams
+            .into_iter()
+            .map(|mut stream| {
+                task::spawn(async move {
+                    {
+                        let (mut r, mut w) = AsyncReadExt::split(&mut stream);
+                        futures::io::copy(&mut r, &mut w).await?;
+                    }
+                    stream.close().await?;
+                    Ok::<_, ConnectionError>(())
+                })
+            })
+            .collect();
 
        for task in tasks {
            task.await.unwrap().unwrap();
@@ -75,34 +67,33 @@ fn prop_config_send_recv_multi() {
        let socket = TcpStream::connect(address).await.expect("connect").compat();
        let mut connection = Connection::new(socket, cfg2);
 
-        // Create streams
        let mut streams = Vec::new();
        for i in 0..num_messages {
            let id = format!("stream-{}", i);
-            streams.push(connection.new_stream(id.as_bytes()).unwrap());
+            streams.push(connection.get_stream(id.as_bytes()).unwrap());
        }
 
-        // Spawn connection poll loop
        task::spawn(async move {
            future::poll_fn(|cx| connection.poll(cx)).await.ok();
        });
 
-        // Send/recv on each stream
-        let mut tasks = Vec::new();
-        for (stream, msg) in streams.into_iter().zip(msgs.into_iter()) {
-            tasks.push(task::spawn(async move {
-                let mut stream = stream;
-                send_recv_message(&mut stream, &msg).await.unwrap();
-                stream.close().await.unwrap();
-            }));
-        }
+        let tasks: Vec<_> = streams
+            .into_iter()
+            .zip(msgs)
+            .map(|(mut stream, msg)| {
+                task::spawn(async move {
+                    send_recv_message(&mut stream, &msg).await.unwrap();
+                    stream.close().await.unwrap();
+                })
+            })
+            .collect();
 
        for task in tasks {
            task.await.unwrap();
        }
    };
 
-    futures::future::join(server, client).await;
+    join(server, client).await;
 }
 
 fn prop(msgs: Vec<Msg>, TestConfig(cfg1): TestConfig, TestConfig(cfg2): TestConfig) {
@@ -134,21 +125,14 @@ fn concurrent_streams() {
        .await
        .unwrap();
 
-    // Pre-register streams on server
    let mut server_streams = Vec::new();
-    for i in 0..n_streams {
-        let id = format!("stream-{}", i);
-        server_streams.push(server.new_stream(id.as_bytes()).unwrap());
-    }
-
-    // Create streams on client
    let mut client_streams = Vec::new();
    for i in 0..n_streams {
        let id = format!("stream-{}", i);
-        client_streams.push(client.new_stream(id.as_bytes()).unwrap());
+        server_streams.push(server.get_stream(id.as_bytes()).unwrap());
+        client_streams.push(client.get_stream(id.as_bytes()).unwrap());
    }
 
-    // Spawn connection poll loops
    task::spawn(async move {
        future::poll_fn(|cx| server.poll(cx)).await.ok();
    });
@@ -156,7 +140,6 @@ fn concurrent_streams() {
        future::poll_fn(|cx| client.poll(cx)).await.ok();
    });
 
-    // Server echoes
    let server_tasks: Vec<_> = server_streams
        .into_iter()
        .map(|mut stream| {
@@ -171,7 +154,6 @@ fn concurrent_streams() {
        })
        .collect();
 
-    // Client send/recv
    let client_tasks: Vec<_> = client_streams
        .into_iter()
        .map(|mut stream| {
@@ -183,12 +165,9 @@ fn concurrent_streams() {
        })
        .collect();
 
-    // Wait for all client tasks
    for task in client_tasks {
        task.await.unwrap();
    }
-
-    // Wait for all server tasks
    for task in server_tasks {
        task.await.unwrap().unwrap();
    }
@@ -210,26 +189,19 @@ fn prop_max_streams() {
    async fn run_test(n: usize) -> Result<bool, ConnectionError> {
        let max_streams = n % 100;
        if max_streams == 0 {
-            return Ok(true); // Skip zero streams
+            return Ok(true);
        }
 
        let mut cfg = Config::default();
        cfg.set_max_num_streams(max_streams);
 
        let (mut server, mut client) = connected_peers(cfg.clone(), cfg.clone(), None).await?;
 
-        // Pre-register streams on server
        for i in 0..max_streams {
            let id = format!("stream-{}", i);
-            server.new_stream(id.as_bytes())?;
+            server.get_stream(id.as_bytes())?;
+            client.get_stream(id.as_bytes())?;
        }
 
-        // Create streams on client
-        for i in 0..max_streams {
-            let id = format!("stream-{}", i);
-            client.new_stream(id.as_bytes())?;
-        }
-
-        // Spawn connection poll loops
        task::spawn(async move {
            future::poll_fn(|cx| server.poll(cx)).await.ok();
        });
@@ -237,20 +209,16 @@ fn prop_max_streams() {
            future::poll_fn(|cx| client.poll(cx)).await.ok();
        });
 
-        // Can't open more on a fresh connection since we've already created max streams
-        // But we need a fresh connection to test this
        let (mut _server2, mut client2) = connected_peers(cfg.clone(), cfg, None).await?;
 
-        // Open max_streams on client2
        for i in 0..max_streams {
            let id = format!("stream-{}", i);
-            client2.new_stream(id.as_bytes())?;
+            client2.get_stream(id.as_bytes())?;
        }
 
-        // Try to open one more stream - should fail
        let extra_id = format!("stream-{}", max_streams);
-        let open_result = client2.new_stream(extra_id.as_bytes());
-        Ok(matches!(open_result, Err(ConnectionError::TooManyStreams)))
+        let result = client2.get_stream(extra_id.as_bytes());
+        Ok(matches!(result, Err(ConnectionError::TooManyStreams)))
    }
 
    fn prop(n: usize) -> Result<bool, ConnectionError> {
@@ -264,63 +232,66 @@ fn prop_max_streams() {
    QuickCheck::new().tests(7).quickcheck(prop as fn(_) -> _)
 }
 
+/// Test half-close: client sends FIN, server echoes and sends FIN, client reads
+/// echo then EOF.
 #[test]
-fn prop_send_recv_half_closed() {
-    async fn run_test(msg: Msg) -> Result<(), ConnectionError> {
-        let msg_len = msg.0.len();
-        let stream_id = b"test-stream";
-
-        let (mut server, mut client) =
-            connected_peers(Config::default(), Config::default(), None).await?;
-
-        // Create streams before spawning connections
-        let mut server_stream = server.new_stream(stream_id)?;
-        let mut client_stream = client.new_stream(stream_id)?;
-
-        // Spawn connection poll loops
-        task::spawn(async move {
-            future::poll_fn(|cx| server.poll(cx)).await.ok();
-        });
-        task::spawn(async move {
-            future::poll_fn(|cx| client.poll(cx)).await.ok();
-        });
-
-        // Server echoes back
-        let server_task = task::spawn(async move {
-            let mut buf = vec![0; msg_len];
-            server_stream.read_exact(&mut buf).await?;
-            server_stream.write_all(&buf).await?;
-            server_stream.close().await?;
-            Ok::<_, ConnectionError>(())
-        });
+fn half_closed() {
+    let _ = env_logger::try_init();
+    let stream_id = b"half-close";
+    let message = b"echo me";
 
-        // Client writes, closes, then reads response
-        client_stream.write_all(&msg.0).await?;
-        client_stream.close().await?;
+    let (server_endpoint, client_endpoint) = futures_ringbuf::Endpoint::pair(4096, 4096);
+    let mut server = Connection::new(server_endpoint, Config::default());
+    let mut client = Connection::new(client_endpoint, Config::default());
 
-        assert!(client_stream.is_write_closed());
-        let mut buf = vec![0; msg_len];
-        client_stream.read_exact(&mut buf).await?;
+    let mut server_stream = server.get_stream(stream_id).unwrap();
+    let mut client_stream = client.get_stream(stream_id).unwrap();
 
-        assert_eq!(buf, msg.0);
-        assert_eq!(Some(0), client_stream.read(&mut buf).await.ok());
-        assert!(client_stream.is_closed());
+    let waker = std::task::Waker::noop();
+    let mut cx = std::task::Context::from_waker(waker);
 
-        server_task.await.unwrap()?;
+    // Client writes and closes (sends FIN)
+    assert!(pin!(&mut client_stream)
+        .poll_write(&mut cx, message)
+        .is_ready());
+    assert!(pin!(&mut client_stream).poll_close(&mut cx).is_ready());
 
-        Ok(())
+    // Poll to exchange frames
+    for _ in 0..20 {
+        let _ = server.poll(&mut cx);
+        let _ = client.poll(&mut cx);
    }
 
-    fn prop(msg: Msg) {
-        tokio::runtime::Builder::new_multi_thread()
-            .enable_all()
-            .build()
-            .unwrap()
-            .block_on(run_test(msg))
-            .unwrap();
+    // Server reads the data
+    let mut buf = [0u8; 7];
+    let result = pin!(&mut server_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(7))));
+    assert_eq!(&buf, message);
+
+    // Server reads EOF (FIN marker)
+    let result = pin!(&mut server_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(0))));
+
+    // Server echoes and closes
+    assert!(pin!(&mut server_stream)
+        .poll_write(&mut cx, message)
+        .is_ready());
+    assert!(pin!(&mut server_stream).poll_close(&mut cx).is_ready());
+
+    // Poll to exchange frames
+    for _ in 0..20 {
+        let _ = server.poll(&mut cx);
+        let _ = client.poll(&mut cx);
    }
 
-    QuickCheck::new().tests(7).quickcheck(prop as fn(_) -> _)
+    // Client reads echo
+    let result = pin!(&mut client_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(7))));
+    assert_eq!(&buf, message);
+
+    // Client reads EOF (FIN marker)
+    let result = pin!(&mut client_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(0))));
 }
 
 #[test]
@@ -335,11 +306,9 @@ fn prop_config_send_recv_single() {
        let (mut server, mut client) = connected_peers(cfg1, cfg2, None).await?;
 
-        // Create streams before spawning connections
-        let mut server_stream = server.new_stream(stream_id)?;
-        let client_stream = client.new_stream(stream_id)?;
+        let mut server_stream = server.get_stream(stream_id)?;
+        let client_stream = client.get_stream(stream_id)?;
 
-        // Spawn connection poll loops
        task::spawn(async move {
            future::poll_fn(|cx| server.poll(cx)).await.ok();
        });
@@ -347,7 +316,6 @@ fn prop_config_send_recv_single() {
            future::poll_fn(|cx| client.poll(cx)).await.ok();
        });
 
-        // Server echoes
        let server_task = task::spawn(async move {
            {
                let (mut r, mut w) = AsyncReadExt::split(&mut server_stream);
@@ -357,16 +325,12 @@ fn prop_config_send_recv_single() {
            Ok::<_, ConnectionError>(())
        });
 
-        // Client sends all messages
        send_on_single_stream(client_stream, msgs).await?;
-
        server_task.await.unwrap()?;
-
        Ok(())
    }
 
    fn prop(msgs: Vec<Msg>, TestConfig(cfg1): TestConfig, TestConfig(cfg2): TestConfig) {
-        // Use multi-threaded runtime so task::spawn works
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
@@ -380,147 +344,89 @@ fn prop_config_send_recv_single() {
        .quickcheck(prop as fn(_, _, _) -> _)
 }
 
-/// This test simulates two endpoints of a multiplexer connection which may be
-/// unable to write simultaneously but can make progress by reading.
-#[test]
-fn write_deadlock() {
+#[tokio::test(flavor = "multi_thread")]
+async fn write_deadlock() {
    let _ = env_logger::try_init();
-    let mut pool = LocalPool::new();
    let msg = vec![1u8; 1024 * 1024];
-    let capacity = 1024;
    let stream_id = b"deadlock-test";
 
-    let (server_endpoint, client_endpoint) = futures_ringbuf::Endpoint::pair(capacity, capacity);
+    let (server_endpoint, client_endpoint) = futures_ringbuf::Endpoint::pair(1024, 1024);
 
-    // Create and spawn a "server" that echoes every message back to the client.
    let mut server = Connection::new(server_endpoint, Config::default());
-    let server_stream = server.new_stream(stream_id).unwrap();
-    pool.spawner()
-        .spawn_obj(
-            async move {
-                let mut stream = server_stream;
-                let conn_task = async {
-                    loop {
-                        if future::poll_fn(|cx| server.poll(cx)).await.is_ok() {
-                            break;
-                        }
-                    }
-                };
-                let echo_task = async {
-                    {
-                        let (mut r, mut w) = AsyncReadExt::split(&mut stream);
-                        futures::io::copy(&mut r, &mut w).await.unwrap();
-                    }
-                    stream.close().await.unwrap();
-                };
-                futures::select_biased! {
-                    _ = echo_task.fuse() => {},
-                    _ = conn_task.fuse() => {},
-                }
-            }
-            .boxed()
-            .into(),
-        )
-        .unwrap();
-
-    // Create and spawn a "client"
+    let mut server_stream = server.get_stream(stream_id).unwrap();
    let mut client = Connection::new(client_endpoint, Config::default());
-    let stream = client.new_stream(stream_id).unwrap();
-
-    // Continuously advance the multiplexer connection of the client
-    pool.spawner()
-        .spawn_obj(
-            async move {
-                loop {
-                    if future::poll_fn(|cx| client.poll(cx)).await.is_ok() {
-                        break;
+    let client_stream = client.get_stream(stream_id).unwrap();
+
+    task::spawn(async move {
+        futures::select_biased! {
+            _ = async {
+                {
+                    let (mut r, mut w) = AsyncReadExt::split(&mut server_stream);
+                    futures::io::copy(&mut r, &mut w).await.unwrap();
                }
-            }
-            .boxed()
-            .into(),
+                server_stream.close().await.unwrap();
+            }.fuse() => {},
+            _ = async { future::poll_fn(|cx| server.poll(cx)).await.ok(); }.fuse() => {},
+        }
+    });
+
+    task::spawn(async move {
+        future::poll_fn(|cx| client.poll(cx)).await.ok();
+    });
+
+    timeout(Duration::from_secs(10), async {
+        let (mut reader, mut writer) = AsyncReadExt::split(client_stream);
+        let mut buf = vec![0; msg.len()];
+        let _ = join(
+            writer.write_all(&msg).map_err(|e| panic_any(e)),
+            reader.read_exact(&mut buf).map_err(|e| panic_any(e)),
        )
-        .unwrap();
-
-    // Send the message, expecting it to be echo'd.
-    pool.run_until(
-        pool.spawner()
-            .spawn_with_handle(
-                async move {
-                    let (mut reader, mut writer) = AsyncReadExt::split(stream);
-                    let mut b = vec![0; msg.len()];
-                    let _ = join(
-                        writer.write_all(msg.as_ref()).map_err(|e| panic_any(e)),
-                        reader.read_exact(&mut b[..]).map_err(|e| panic_any(e)),
-                    )
-                    .await;
-                    let mut stream = reader.reunite(writer).unwrap();
-                    stream.close().await.unwrap();
-                    log::debug!("C: Stream {:?} done.", stream.id());
-                    assert_eq!(b, msg);
-                }
-                .boxed(),
-            )
-            .unwrap(),
-    );
+        .await;
+        let mut stream = reader.reunite(writer).unwrap();
+        stream.close().await.unwrap();
+        assert_eq!(buf, msg);
+    })
+    .await
+    .expect("timeout");
 }
 
+/// Test that data written before dropping a stream handle is still delivered.
+/// Note: With the simplified protocol, dropping does NOT send FIN to remote.
 #[test]
-fn close_through_drop_of_stream_propagates_to_remote() {
+fn drop_delivers_written_data() {
    let _ = env_logger::try_init();
-    let mut pool = LocalPool::new();
    let stream_id = b"drop-test";
 
    let (server_endpoint, client_endpoint) = futures_ringbuf::Endpoint::pair(1024, 1024);
    let mut server = Connection::new(server_endpoint, Config::default());
    let mut client = Connection::new(client_endpoint, Config::default());
 
-    // Pre-register stream on server
-    let mut stream_server_side = server.new_stream(stream_id).unwrap();
+    let mut server_stream = server.get_stream(stream_id).unwrap();
+    let client_stream = client.get_stream(stream_id).unwrap();
 
-    // Spawn client, opening a stream, writing to the stream, dropping the stream
-    let mut client_stream = client.new_stream(stream_id).unwrap();
-    pool.spawner()
-        .spawn_obj(
-            async move {
-                client_stream.write_all(&[42]).await.unwrap();
-                drop(client_stream);
+    let waker = std::task::Waker::noop();
+    let mut cx = std::task::Context::from_waker(waker);
 
-                loop {
-                    if future::poll_fn(|cx| client.poll(cx)).await.is_ok() {
-                        break;
-                    }
-                }
-            }
-            .boxed()
-            .into(),
-        )
-        .unwrap();
-
-    // Spawn server connection state machine.
-    pool.spawner()
-        .spawn_obj(
-            async move {
-                loop {
-                    if future::poll_fn(|cx| server.poll(cx)).await.is_ok() {
-                        break;
-                    }
-                }
-            }
-            .boxed()
-            .into(),
-        )
-        .unwrap();
-
-    // Expect to eventually receive close on stream.
-    pool.run_until(async {
-        let mut buf = Vec::new();
-        stream_server_side.read_to_end(&mut buf).await?;
-        assert_eq!(buf, vec![42]);
-        Ok::<(), std::io::Error>(())
-    })
-    .unwrap();
+    // Client writes and drops (no close/FIN)
+    assert!(pin!(client_stream).poll_write(&mut cx, &[42]).is_ready());
+    // stream dropped here
+
+    // Poll to deliver data
+    for _ in 0..10 {
+        let _ = server.poll(&mut cx);
+        let _ = client.poll(&mut cx);
+    }
+
+    // Server should receive the data
+    let mut buf = [0u8; 1];
+    let result = pin!(&mut server_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(1))));
+    assert_eq!(buf[0], 42);
+
+    // Server does NOT get EOF (no FIN was sent)
+    let result = pin!(&mut server_stream).poll_read(&mut cx, &mut buf);
+    assert!(result.is_pending());
 }
 
 #[test]
@@ -536,23 +442,17 @@ fn close_sync() {
    let waker = std::task::Waker::noop();
    let mut cx = std::task::Context::from_waker(waker);
 
-    // Create streams on both sides with same ID
    let stream_id = b"test";
-    let client_stream = client.new_stream(stream_id).unwrap();
-    let server_stream = server.new_stream(stream_id).unwrap();
+    let client_stream = client.get_stream(stream_id).unwrap();
+    let server_stream = server.get_stream(stream_id).unwrap();
 
-    // Write from client (this sends StreamInit + Data)
    assert!(pin!(client_stream).poll_write(&mut cx, b"hello").is_ready());
-
-    // Client initiates close
    client.close();
 
-    // Poll client a bunch of times and ensure it doesn't finish closing yet.
    for _ in 0..10 {
        assert!(client.poll(&mut cx).is_pending());
    }
 
-    // Server polls to receive StreamInit and transition the stream
    let _ = server.poll(&mut cx);
    let _ = server.poll(&mut cx);
 
@@ -560,13 +460,161 @@ fn close_sync() {
    assert!(pin!(server_stream).poll_read(&mut cx, &mut buf).is_ready());
    assert_eq!(&buf, b"hello");
 
-    // Server polls more to receive GoAway
    while server.poll(&mut cx).is_pending() {}
 
-    // Now server closes
    server.close();
    let _ = server.poll(&mut cx);
 
-    // Client should now be able to finish closing
    while client.poll(&mut cx).is_pending() {}
 }
+
+/// Test that after dropping a stream handle, the stream can be reopened
+/// and any data buffered during the "closed" period can be read.
+#[test]
+fn stream_reuse_after_handle_drop() {
+    let _ = env_logger::try_init();
+    let stream_id = b"reuse-test";
+    let message = b"buffered data";
+
+    let (server_endpoint, client_endpoint) = futures_ringbuf::Endpoint::pair(4096, 4096);
+    let mut server = Connection::new(server_endpoint, Config::default());
+    let mut client = Connection::new(client_endpoint, Config::default());
+
+    let mut server_stream = server.get_stream(stream_id).unwrap();
+    let client_stream = client.get_stream(stream_id).unwrap();
+
+    let waker = std::task::Waker::noop();
+    let mut cx = std::task::Context::from_waker(waker);
+
+    // Drop client stream handle
+    drop(client_stream);
+
+    // Server writes data
+    assert!(pin!(&mut server_stream)
+        .poll_write(&mut cx, message)
+        .is_ready());
+    // Server sends FIN
+    assert!(pin!(&mut server_stream).poll_close(&mut cx).is_ready());
+
+    // Poll both connections to exchange frames
+    for _ in 0..20 {
+        let _ = server.poll(&mut cx);
+        let _ = client.poll(&mut cx);
+    }
+
+    // Reopen stream on client
+    let mut reopened = client.get_stream(stream_id).unwrap();
+
+    // Read buffered data; panic on anything other than a successful read
+    let mut buf = [0u8; 14];
+    match pin!(&mut reopened).poll_read(&mut cx, &mut buf) {
+        std::task::Poll::Ready(Ok(n)) => {
+            assert_eq!(n, message.len());
+            assert_eq!(&buf[..n], message);
+        }
+        other => panic!("expected buffered data, got {:?}", other),
+    }
+
+    // Read FIN marker (EOF)
+    let result = pin!(&mut reopened).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(0))));
+}
+
+/// Test that data can be sent after FIN (FIN is just an in-band marker).
+#[test]
+fn data_after_fin() {
+    let _ = env_logger::try_init();
+    let stream_id = b"data-after-fin";
+
+    let (server_endpoint, client_endpoint) = futures_ringbuf::Endpoint::pair(4096, 4096);
+    let mut server = Connection::new(server_endpoint, Config::default());
+    let mut client = Connection::new(client_endpoint, Config::default());
+
+    let mut server_stream = server.get_stream(stream_id).unwrap();
+    let mut client_stream = client.get_stream(stream_id).unwrap();
+
+    let waker = std::task::Waker::noop();
+    let mut cx = std::task::Context::from_waker(waker);
+
+    // Client sends data
+    assert!(pin!(&mut client_stream)
+        .poll_write(&mut cx, b"before")
+        .is_ready());
+
+    // Client sends FIN
+    assert!(pin!(&mut client_stream).poll_close(&mut cx).is_ready());
+
+    // Client sends more data AFTER FIN
+    assert!(pin!(&mut client_stream)
+        .poll_write(&mut cx, b"after")
+        .is_ready());
+
+    // Poll to exchange frames
+    for _ in 0..20 {
+        let _ = server.poll(&mut cx);
+        let _ = client.poll(&mut cx);
+    }
+
+    // Server reads "before"
+    let mut buf = [0u8; 6];
+    let result = pin!(&mut server_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(6))));
+    assert_eq!(&buf, b"before");
+
+    // Server reads EOF (FIN marker)
+    let result = pin!(&mut server_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(0))));
+
+    // Server reads "after" (data sent after FIN)
+    let mut buf = [0u8; 5];
+    let result = pin!(&mut server_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(5))));
+    assert_eq!(&buf, b"after");
+}
+
+/// Test that streams can be reused multiple times on the same connection.
+#[test]
+fn stream_reuse_same_connection() {
+    let _ = env_logger::try_init();
+    let stream_id = b"reuse-multi";
+
+    let (server_endpoint, client_endpoint) = futures_ringbuf::Endpoint::pair(4096, 4096);
+    let mut server = Connection::new(server_endpoint, Config::default());
+    let mut client = Connection::new(client_endpoint, Config::default());
+
+    let mut server_stream = server.get_stream(stream_id).unwrap();
+
+    let waker = std::task::Waker::noop();
+    let mut cx = std::task::Context::from_waker(waker);
+
+    // First use: create, write, drop (no FIN sent)
+    let stream = client.get_stream(stream_id).unwrap();
+    assert!(pin!(stream).poll_write(&mut cx, b"first").is_ready());
+    // stream dropped here
+
+    // Poll to deliver data
+    for _ in 0..10 {
+        let _ = server.poll(&mut cx);
+        let _ = client.poll(&mut cx);
+    }
+
+    // Reopen same stream ID and write more
+    let stream = client.get_stream(stream_id).unwrap();
+    assert!(pin!(stream).poll_write(&mut cx, b"second").is_ready());
+
+    // Poll to deliver data
+    for _ in 0..10 {
+        let _ = server.poll(&mut cx);
+        let _ = client.poll(&mut cx);
+    }
+
+    // Server should receive both messages
+    let mut buf = [0u8; 5];
+    let result = pin!(&mut server_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(5))));
+    assert_eq!(&buf, b"first");
+
+    let mut buf = [0u8; 6];
+    let result = pin!(&mut server_stream).poll_read(&mut cx, &mut buf);
+    assert!(matches!(result, std::task::Poll::Ready(Ok(6))));
+    assert_eq!(&buf, b"second");
+}
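The synchronous tests above drive both `Connection`s by hand with `std::task::Waker::noop()` (stable since Rust 1.85) instead of a runtime. That pattern is plain std Rust; a minimal self-contained sketch, independent of `tlsn_mux`:

```rust
use std::future::Future;
use std::pin::pin;
use std::task::{Context, Poll, Waker};

fn main() {
    // A no-op waker: poll results are checked synchronously in a loop,
    // so wake-up notifications are unnecessary.
    let waker = Waker::noop();
    let mut cx = Context::from_waker(waker);

    // An async block with no await points completes on its first poll,
    // just like the poll_write/poll_read assertions in the tests above.
    let mut fut = pin!(async { 40 + 2 });
    match fut.as_mut().poll(&mut cx) {
        Poll::Ready(v) => assert_eq!(v, 42),
        Poll::Pending => unreachable!("no await points, so the first poll must complete"),
    }
}
```

Because the waker never schedules anything, a `Pending` result is only meaningful if some other code (here, the bounded `for _ in 0..N` poll loops over both connections) makes progress before the next poll.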