diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 4798a12111..e4bffe36f3 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -4,7 +4,7 @@ use std::{ convert::TryFrom, fmt, io, mem, net::{IpAddr, SocketAddr}, - sync::Arc, + sync::{mpsc, Arc}, time::{Duration, Instant}, }; @@ -24,14 +24,11 @@ use crate::{ frame::{Close, Datagram, FrameStruct}, packet::{Header, LongType, Packet, PartialDecode, SpaceId}, range_set::ArrayRangeSet, - shared::{ - ConnectionEvent, ConnectionEventInner, ConnectionId, EcnCodepoint, EndpointEvent, - EndpointEventInner, - }, + shared::{ConnectionEvent, ConnectionId, EcnCodepoint, EndpointEvent}, token::ResetToken, transport_parameters::TransportParameters, - Dir, EndpointConfig, Frame, Side, StreamId, Transmit, TransportError, TransportErrorCode, - VarInt, MAX_STREAM_COUNT, MIN_INITIAL_SIZE, TIMER_GRANULARITY, + ConnectionHandle, Dir, EndpointConfig, Frame, Side, StreamId, Transmit, TransportError, + TransportErrorCode, VarInt, MAX_STREAM_COUNT, MIN_INITIAL_SIZE, TIMER_GRANULARITY, }; mod ack_frequency; @@ -131,6 +128,7 @@ pub struct Connection { server_config: Option>, config: Arc, rng: StdRng, + connection_events: mpsc::Receiver, crypto: Box, /// The CID we initially chose, for use during the handshake handshake_cid: ConnectionId, @@ -162,7 +160,7 @@ pub struct Connection { /// Total number of outgoing packets that have been deemed lost lost_packets: u64, events: VecDeque, - endpoint_events: VecDeque, + endpoint_events: EndpointEvents, /// Whether the spin bit is in use for this connection spin_enabled: bool, /// Outgoing spin bit state @@ -253,6 +251,8 @@ impl Connection { version: u32, allow_mtud: bool, rng_seed: [u8; 32], + endpoint_events: EndpointEvents, + connection_events: mpsc::Receiver, ) -> Self { let side = if server_config.is_some() { Side::Server @@ -273,6 +273,7 @@ impl Connection { let mut this = Self { endpoint_config, server_config, + connection_events, crypto, handshake_cid: loc_cid, rem_handshake_cid: rem_cid, @@ -314,7 +315,7 @@ impl Connection { retry_src_cid: None, lost_packets: 0, events: VecDeque::new(), - endpoint_events: VecDeque::new(), + endpoint_events, spin_enabled: config.allow_spin && rng.gen_ratio(7, 8), spin: false, spaces: [initial_space, PacketSpace::new(now), PacketSpace::new(now)], @@ -407,10 +408,10 @@ impl Connection { None } - /// Return endpoint-facing events + /// Whether [`Endpoint::handle_events`] must be called in the immediate future #[must_use] - pub fn poll_endpoint_events(&mut self) -> Option { - self.endpoint_events.pop_front().map(EndpointEvent) + pub fn poll_endpoint_events(&mut self) -> bool { + mem::take(&mut self.endpoint_events.dirty) } /// Provide control over streams @@ -951,14 +952,20 @@ impl Connection { SendableFrames::empty() } - /// Process `ConnectionEvent`s generated by the associated `Endpoint` + /// Process events from the associated [`Endpoint`](crate::Endpoint) /// /// Will execute protocol logic upon receipt of a connection event, in turn preparing signals - /// (including application `Event`s, `EndpointEvent`s and outgoing datagrams) that should be - /// extracted through the relevant methods. - pub fn handle_event(&mut self, event: ConnectionEvent) { - use self::ConnectionEventInner::*; - match event.0 { + /// (including application `Event`s, endpoint events, and outgoing datagrams) that should be + /// checked through the relevant methods. + pub fn handle_events(&mut self, now: Instant) { + while let Ok(event) = self.connection_events.try_recv() { + self.handle_event(event, now); + } + } + + fn handle_event(&mut self, event: ConnectionEvent, now: Instant) { + use self::ConnectionEvent::*; + match event { Datagram { now, remote, @@ -1001,7 +1008,7 @@ impl Connection { self.set_loss_detection_timer(now); } } - NewIdentifiers(ids, now) => { + NewIdentifiers(ids) => { self.local_cid_state.new_cids(&ids, now); ids.into_iter().rev().for_each(|frame| { self.spaces[SpaceId::Data].pending.new_cids.push(frame); @@ -1037,7 +1044,7 @@ impl Connection { match timer { Timer::Close => { self.state = State::Drained; - self.endpoint_events.push_back(EndpointEventInner::Drained); + self.endpoint_events.push(EndpointEvent::Drained); } Timer::Idle => { self.kill(ConnectionError::TimedOut); @@ -1071,7 +1078,7 @@ impl Connection { self.local_cid_state.retire_prior_to() ); self.endpoint_events - .push_back(EndpointEventInner::NeedIdentifiers(now, num_new_cid)); + .push(EndpointEvent::NeedIdentifiers(num_new_cid)); } } Timer::MaxAckDelay => { @@ -2168,7 +2175,7 @@ impl Connection { } } if !was_drained && self.state.is_drained() { - self.endpoint_events.push_back(EndpointEventInner::Drained); + self.endpoint_events.push(EndpointEvent::Drained); // Close timer may have been started previously, e.g. if we sent a close and got a // stateless reset in response self.timers.stop(Timer::Close); @@ -2351,10 +2358,10 @@ impl Connection { } if let Some(token) = params.stateless_reset_token { self.endpoint_events - .push_back(EndpointEventInner::ResetToken(self.path.remote, token)); + .push(EndpointEvent::ResetToken(self.path.remote, token)); } self.handle_peer_params(params)?; - self.issue_first_cids(now); + self.issue_first_cids(); } else { // Server-only self.spaces[SpaceId::Data].pending.handshake_done = true; @@ -2401,7 +2408,7 @@ impl Connection { reason: "transport parameters missing".into(), })?; self.handle_peer_params(params)?; - self.issue_first_cids(now); + self.issue_first_cids(); self.init_0rtt(); } Ok(()) @@ -2661,11 +2668,7 @@ impl Connection { .local_cid_state .on_cid_retirement(sequence, self.peer_params.issue_cids_limit())?; self.endpoint_events - .push_back(EndpointEventInner::RetireConnectionId( - now, - sequence, - allow_more_cids, - )); + .push(EndpointEvent::RetireConnectionId(sequence, allow_more_cids)); } Frame::NewConnectionId(frame) => { trace!( @@ -2881,23 +2884,19 @@ impl Connection { fn set_reset_token(&mut self, reset_token: ResetToken) { self.endpoint_events - .push_back(EndpointEventInner::ResetToken( - self.path.remote, - reset_token, - )); + .push(EndpointEvent::ResetToken(self.path.remote, reset_token)); self.peer_params.stateless_reset_token = Some(reset_token); } /// Issue an initial set of connection IDs to the peer upon connection - fn issue_first_cids(&mut self, now: Instant) { + fn issue_first_cids(&mut self) { if self.local_cid_state.cid_len() == 0 { return; } // Subtract 1 to account for the CID we supplied while handshaking let n = self.peer_params.issue_cids_limit() - 1; - self.endpoint_events - .push_back(EndpointEventInner::NeedIdentifiers(now, n)); + self.endpoint_events.push(EndpointEvent::NeedIdentifiers(n)); } fn populate_packet( @@ -3284,22 +3283,9 @@ impl Connection { /// Decodes a packet, returning its decrypted payload, so it can be inspected in tests #[cfg(test)] - pub(crate) fn decode_packet(&self, event: &ConnectionEvent) -> Option> { - let (first_decode, remaining) = match &event.0 { - ConnectionEventInner::Datagram { - first_decode, - remaining, - .. - } => (first_decode, remaining), - _ => return None, - }; - - if remaining.is_some() { - panic!("Packets should never be coalesced in tests"); - } - + pub(crate) fn decode_packet(&self, packet: PartialDecode) -> Option> { let decrypted_header = packet_crypto::unprotect_header( - first_decode.clone(), + packet.clone(), &self.spaces, self.zero_rtt_crypto.as_ref(), self.peer_params.stateless_reset_token, @@ -3372,10 +3358,9 @@ impl Connection { /// Instruct the peer to replace previously issued CIDs by sending a NEW_CONNECTION_ID frame /// with updated `retire_prior_to` field set to `v` #[cfg(test)] - pub(crate) fn rotate_local_cid(&mut self, v: u64, now: Instant) { + pub(crate) fn rotate_local_cid(&mut self, v: u64) { let n = self.local_cid_state.assign_retire_seq(v); - self.endpoint_events - .push_back(EndpointEventInner::NeedIdentifiers(now, n)); + self.endpoint_events.push(EndpointEvent::NeedIdentifiers(n)); } /// Check the current active remote CID sequence @@ -3416,7 +3401,7 @@ impl Connection { self.close_common(); self.error = Some(reason); self.state = State::Drained; - self.endpoint_events.push_back(EndpointEventInner::Drained); + self.endpoint_events.push(EndpointEvent::Drained); } /// Storage size required for the largest packet known to be supported by the current path @@ -3651,3 +3636,28 @@ impl SentFrames { && self.retransmits.is_empty(streams) } } + +pub(crate) struct EndpointEvents { + ch: ConnectionHandle, + sender: mpsc::Sender<(ConnectionHandle, EndpointEvent)>, + dirty: bool, +} + +impl EndpointEvents { + pub(crate) fn new( + ch: ConnectionHandle, + sender: mpsc::Sender<(ConnectionHandle, EndpointEvent)>, + ) -> Self { + Self { + ch, + sender, + dirty: false, + } + } + + fn push(&mut self, event: EndpointEvent) { + // If the endpoint has gone away, assume the caller is winding down regardless. + _ = self.sender.send((self.ch, event)); + self.dirty = true; + } +} diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 3d61881d15..d3ecc51fdb 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -4,7 +4,7 @@ use std::{ fmt, iter, net::{IpAddr, SocketAddr}, ops::{Index, IndexMut}, - sync::Arc, + sync::{mpsc, Arc}, time::{Instant, SystemTime}, }; @@ -19,14 +19,11 @@ use crate::{ cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, coding::BufMutExt, config::{ClientConfig, EndpointConfig, ServerConfig}, - connection::{Connection, ConnectionError}, + connection::{Connection, ConnectionError, EndpointEvents}, crypto::{self, Keys, UnsupportedVersion}, frame, packet::{Header, Packet, PacketDecodeError, PacketNumber, PartialDecode}, - shared::{ - ConnectionEvent, ConnectionEventInner, ConnectionId, EcnCodepoint, EndpointEvent, - EndpointEventInner, IssuedCid, - }, + shared::{ConnectionEvent, ConnectionId, EcnCodepoint, EndpointEvent, IssuedCid}, transport_parameters::TransportParameters, ResetToken, RetryToken, Side, Transmit, TransportConfig, TransportError, INITIAL_MTU, MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, @@ -45,6 +42,8 @@ pub struct Endpoint { server_config: Option>, /// Whether the underlying UDP socket promises not to fragment packets allow_mtud: bool, + event_send: mpsc::Sender<(ConnectionHandle, EndpointEvent)>, + event_recv: mpsc::Receiver<(ConnectionHandle, EndpointEvent)>, } impl Endpoint { @@ -59,6 +58,7 @@ impl Endpoint { allow_mtud: bool, rng_seed: Option<[u8; 32]>, ) -> Self { + let (event_send, event_recv) = mpsc::channel(); Self { rng: rng_seed.map_or(StdRng::from_entropy(), StdRng::from_seed), index: ConnectionIndex::default(), @@ -67,6 +67,8 @@ impl Endpoint { config, server_config, allow_mtud, + event_send, + event_recv, } } @@ -75,18 +77,25 @@ impl Endpoint { self.server_config = server_config; } - /// Process `EndpointEvent`s emitted from related `Connection`s + /// Process events from [`Connection`]s that have returned `true` from [`Connection::poll_endpoint_events`] /// - /// In turn, processing this event may return a `ConnectionEvent` for the same `Connection`. - pub fn handle_event( - &mut self, - ch: ConnectionHandle, - event: EndpointEvent, - ) -> Option { - use EndpointEventInner::*; - match event.0 { - NeedIdentifiers(now, n) => { - return Some(self.send_new_identifiers(now, ch, n)); + /// May return the [`ConnectionHandle`] of a [`Connection`] for which + /// [`Connection::handle_events`] must be called. Call until `None` is returned. + pub fn handle_events(&mut self) -> Option { + while let Ok((ch, event)) = self.event_recv.try_recv() { + if self.handle_event(ch, event) { + return Some(ch); + } + } + None + } + + fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) -> bool { + use EndpointEvent::*; + match event { + NeedIdentifiers(n) => { + self.send_new_identifiers(ch, n); + return true; } ResetToken(remote, token) => { if let Some(old) = self.connections[ch].reset_token.replace((remote, token)) { @@ -96,12 +105,13 @@ impl Endpoint { warn!("duplicate reset token"); } } - RetireConnectionId(now, seq, allow_more_cids) => { + RetireConnectionId(seq, allow_more_cids) => { if let Some(cid) = self.connections[ch].loc_cids.remove(&seq) { trace!("peer retired CID {}: {}", seq, cid); self.index.retire(&cid); if allow_more_cids { - return Some(self.send_new_identifiers(now, ch, 1)); + self.send_new_identifiers(ch, 1); + return true; } } } @@ -116,7 +126,27 @@ impl Endpoint { } } } - None + false + } + + #[cfg(test)] + pub(crate) fn decode_packet( + &self, + datagram: BytesMut, + ) -> Result { + PartialDecode::new( + datagram, + self.local_cid_generator.cid_len(), + &self.config.supported_versions, + self.config.grease_quic_bit, + ) + .map(|(packet, rest)| { + assert!( + rest.is_none(), + "capturing decoded coalesced packets in tests is unimplemented" + ); + packet + }) } /// Process an incoming UDP datagram @@ -183,16 +213,16 @@ impl Endpoint { let addresses = FourTuple { remote, local_ip }; if let Some(ch) = self.index.get(&addresses, &first_decode) { - return Some(DatagramEvent::ConnectionEvent( - ch, - ConnectionEvent(ConnectionEventInner::Datagram { + _ = self.connections[ch.0] + .events + .send(ConnectionEvent::Datagram { now, remote: addresses.remote, ecn, first_decode, remaining, - }), - )); + }); + return Some(DatagramEvent::ConnectionEvent(ch)); } // @@ -362,12 +392,7 @@ impl Endpoint { Ok((ch, conn)) } - fn send_new_identifiers( - &mut self, - now: Instant, - ch: ConnectionHandle, - num: u64, - ) -> ConnectionEvent { + fn send_new_identifiers(&mut self, ch: ConnectionHandle, num: u64) { let mut ids = vec![]; for _ in 0..num { let id = self.new_cid(ch); @@ -381,7 +406,9 @@ impl Endpoint { reset_token: ResetToken::new(&*self.config.reset_key, &id), }); } - ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now)) + _ = self.connections[ch] + .events + .send(ConnectionEvent::NewIdentifiers(ids)); } /// Generate a connection ID for `ch` @@ -569,7 +596,7 @@ impl Endpoint { } Err(e) => { debug!("handshake failed: {}", e); - self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained)); + self.handle_event(ch, EndpointEvent::Drained); match e { ConnectionError::TransportError(e) => Some(DatagramEvent::Response( self.initial_close(version, addresses, crypto, &src_cid, e, buf), @@ -595,6 +622,7 @@ impl Endpoint { ) -> Connection { let mut rng_seed = [0; 32]; self.rng.fill_bytes(&mut rng_seed); + let (send, recv) = mpsc::channel(); let conn = Connection::new( self.config.clone(), server_config, @@ -610,6 +638,8 @@ impl Endpoint { version, self.allow_mtud, rng_seed, + EndpointEvents::new(ch, self.event_send.clone()), + recv, ); let id = self.connections.insert(ConnectionMeta { @@ -618,6 +648,7 @@ impl Endpoint { loc_cids: iter::once((0, loc_cid)).collect(), addresses, reset_token: None, + events: send, }); debug_assert_eq!(id, ch.0, "connection handle allocation out of sync"); @@ -827,6 +858,7 @@ pub(crate) struct ConnectionMeta { /// Reset token provided by the peer for the CID we're currently sending to, and the address /// being sent to reset_token: Option<(SocketAddr, ResetToken)>, + events: mpsc::Sender, } /// Internal identifier for a `Connection` currently associated with an endpoint @@ -855,8 +887,8 @@ impl IndexMut for Slab { /// Event resulting from processing a single datagram #[allow(clippy::large_enum_variant)] // Not passed around extensively pub enum DatagramEvent { - /// The datagram is redirected to its `Connection` - ConnectionEvent(ConnectionHandle, ConnectionEvent), + /// [`Connection::handle_events`] must be called on the associated [`Connection`] + ConnectionEvent(ConnectionHandle), /// The datagram has resulted in starting a new `Connection` NewConnection(ConnectionHandle, Connection), /// Response generated directly by the endpoint diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index d271390f02..d823baed6e 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -64,7 +64,7 @@ mod endpoint; pub use crate::endpoint::{ConnectError, ConnectionHandle, DatagramEvent, Endpoint}; mod shared; -pub use crate::shared::{ConnectionEvent, ConnectionId, EcnCodepoint, EndpointEvent}; +pub use crate::shared::{ConnectionId, EcnCodepoint}; mod transport_error; pub use crate::transport_error::{Code as TransportErrorCode, Error as TransportError}; diff --git a/quinn-proto/src/shared.rs b/quinn-proto/src/shared.rs index a8bd274b28..dfbd306a32 100644 --- a/quinn-proto/src/shared.rs +++ b/quinn-proto/src/shared.rs @@ -4,12 +4,8 @@ use bytes::{Buf, BufMut, BytesMut}; use crate::{coding::BufExt, packet::PartialDecode, ResetToken, MAX_CID_SIZE}; -/// Events sent from an Endpoint to a Connection #[derive(Debug)] -pub struct ConnectionEvent(pub(crate) ConnectionEventInner); - -#[derive(Debug)] -pub(crate) enum ConnectionEventInner { +pub(crate) enum ConnectionEvent { /// A datagram has been received for the Connection Datagram { now: Instant, @@ -19,41 +15,20 @@ pub(crate) enum ConnectionEventInner { remaining: Option, }, /// New connection identifiers have been issued for the Connection - NewIdentifiers(Vec, Instant), -} - -/// Events sent from a Connection to an Endpoint -#[derive(Debug)] -pub struct EndpointEvent(pub(crate) EndpointEventInner); - -impl EndpointEvent { - /// Construct an event that indicating that a `Connection` will no longer emit events - /// - /// Useful for notifying an `Endpoint` that a `Connection` has been destroyed outside of the - /// usual state machine flow, e.g. when being dropped by the user. - pub fn drained() -> Self { - Self(EndpointEventInner::Drained) - } - - /// Determine whether this is the last event a `Connection` will emit - /// - /// Useful for determining when connection-related event loop state can be freed. - pub fn is_drained(&self) -> bool { - self.0 == EndpointEventInner::Drained - } + NewIdentifiers(Vec), } #[derive(Clone, Debug, Eq, PartialEq)] -pub(crate) enum EndpointEventInner { +pub(crate) enum EndpointEvent { /// The connection has been drained Drained, /// The reset token and/or address eligible for generating resets has been updated ResetToken(SocketAddr, ResetToken), /// The connection needs connection identifiers - NeedIdentifiers(Instant, u64), + NeedIdentifiers(u64), /// Stop routing connection ID for this sequence number to the connection /// When `bool == true`, a new connection ID will be issued to peer - RetireConnectionId(Instant, u64, bool), + RetireConnectionId(u64, bool), } /// Protocol-level identifier for a connection. diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index c41f3592d0..2c549093b5 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -77,7 +77,7 @@ fn version_negotiate_client() { .unwrap(); let now = Instant::now(); let mut buf = BytesMut::with_capacity(client.config().get_max_udp_payload_size() as usize); - let opt_event = client.handle( + client.handle( now, server_addr, None, @@ -90,9 +90,7 @@ fn version_negotiate_client() { .into(), &mut buf, ); - if let Some(DatagramEvent::ConnectionEvent(_, event)) = opt_event { - client_ch.handle_event(event); - } + client_ch.handle_events(now); assert_matches!( client_ch.poll(), Some(Event::ConnectionLost { @@ -1406,8 +1404,7 @@ fn cid_retirement() { let (client_ch, server_ch) = pair.connect(); // Server retires current active remote CIDs - pair.server_conn_mut(server_ch) - .rotate_local_cid(1, Instant::now()); + pair.server_conn_mut(server_ch).rotate_local_cid(1); pair.drive(); // Any unexpected behavior may trigger TransportError::CONNECTION_ID_LIMIT_ERROR assert!(!pair.client_conn_mut(client_ch).is_closed()); @@ -1423,7 +1420,7 @@ fn cid_retirement() { pair.client_conn_mut(client_ch).ping(); // Server retires all valid remote CIDs pair.server_conn_mut(server_ch) - .rotate_local_cid(next_retire_prior_to, Instant::now()); + .rotate_local_cid(next_retire_prior_to); pair.drive(); assert!(!pair.client_conn_mut(client_ch).is_closed()); assert!(!pair.server_conn_mut(server_ch).is_closed()); diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index cd7bb11239..3ec6e4a1c2 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -1,6 +1,6 @@ use std::{ cmp, - collections::{HashMap, VecDeque}, + collections::{HashMap, HashSet, VecDeque}, env, io::{self, Write}, mem, @@ -289,7 +289,7 @@ pub(super) struct TestEndpoint { pub(super) inbound: VecDeque<(Instant, Option, BytesMut)>, accepted: Option, pub(super) connections: HashMap, - conn_events: HashMap>, + conn_events: HashSet, pub(super) captured_packets: Vec>, pub(super) capture_inbound_packets: bool, } @@ -315,7 +315,7 @@ impl TestEndpoint { inbound: VecDeque::new(), accepted: None, connections: HashMap::default(), - conn_events: HashMap::default(), + conn_events: HashSet::default(), captured_packets: Vec::new(), capture_inbound_packets: false, } @@ -335,22 +335,24 @@ impl TestEndpoint { while self.inbound.front().map_or(false, |x| x.0 <= now) { let (recv_time, ecn, packet) = self.inbound.pop_front().unwrap(); - if let Some(event) = self - .endpoint - .handle(recv_time, remote, None, ecn, packet, &mut buf) + if let Some(event) = + self.endpoint + .handle(recv_time, remote, None, ecn, packet.clone(), &mut buf) { match event { DatagramEvent::NewConnection(ch, conn) => { self.connections.insert(ch, conn); self.accepted = Some(ch); } - DatagramEvent::ConnectionEvent(ch, event) => { + DatagramEvent::ConnectionEvent(ch) => { if self.capture_inbound_packets { - let packet = self.connections[&ch].decode_packet(&event); + let packet = self + .decode_packet(packet) + .ok() + .and_then(|x| self.connections[&ch].decode_packet(x)); self.captured_packets.extend(packet); } - - self.conn_events.entry(ch).or_default().push_back(event); + self.conn_events.insert(ch); } DatagramEvent::Response(transmit) => { let size = transmit.size; @@ -362,22 +364,18 @@ impl TestEndpoint { } loop { - let mut endpoint_events: Vec<(ConnectionHandle, EndpointEvent)> = vec![]; + let mut endpoint_events = false; for (ch, conn) in self.connections.iter_mut() { if self.timeout.map_or(false, |x| x <= now) { self.timeout = None; conn.handle_timeout(now); } - for (_, mut events) in self.conn_events.drain() { - for event in events.drain(..) { - conn.handle_event(event); - } + if self.conn_events.remove(ch) { + conn.handle_events(now); } - while let Some(event) = conn.poll_endpoint_events() { - endpoint_events.push((*ch, event)); - } + endpoint_events |= conn.poll_endpoint_events(); while let Some(transmit) = conn.poll_transmit(now, MAX_DATAGRAMS, &mut buf) { let size = transmit.size; self.outbound @@ -386,15 +384,13 @@ impl TestEndpoint { self.timeout = conn.poll_timeout(); } - if endpoint_events.is_empty() { + if !endpoint_events { break; } - for (ch, event) in endpoint_events { - if let Some(event) = self.handle_event(ch, event) { - if let Some(conn) = self.connections.get_mut(&ch) { - conn.handle_event(event); - } + while let Some(ch) = self.handle_events() { + if let Some(conn) = self.connections.get_mut(&ch) { + conn.handle_events(now); } } } diff --git a/quinn/Cargo.toml b/quinn/Cargo.toml index b8bfb20738..b171f12666 100644 --- a/quinn/Cargo.toml +++ b/quinn/Cargo.toml @@ -35,6 +35,7 @@ maintenance = { status = "experimental" } [dependencies] async-io = { version = "2.0", optional = true } async-std = { version = "1.11", optional = true } +atomic-waker = "1.1.2" bytes = "1" # Enables futures::io::{AsyncRead, AsyncWrite} support for streams futures-io = { version = "0.3.19", optional = true } diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 5bed610cb5..527d6c1549 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -10,6 +10,7 @@ use std::{ }; use crate::runtime::{AsyncTimer, AsyncUdpSocket, Runtime}; +use atomic_waker::AtomicWaker; use bytes::{Bytes, BytesMut}; use pin_project_lite::pin_project; use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId}; @@ -40,6 +41,7 @@ impl Connecting { handle: ConnectionHandle, conn: proto::Connection, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, + endpoint_driver: Arc, conn_events: mpsc::UnboundedReceiver, socket: Arc, runtime: Arc, @@ -50,6 +52,7 @@ impl Connecting { handle, conn, endpoint_events, + endpoint_driver, conn_events, on_handshake_data_send, on_connected_send, @@ -233,7 +236,7 @@ impl Future for ConnectionDriver { // If a timer expires, there might be more to transmit. When we transmit something, we // might need to reset a timer. Hence, we must loop until neither happens. keep_going |= conn.drive_timer(cx); - conn.forward_endpoint_events(); + conn.forward_endpoint_events(&self.0.shared); conn.forward_app_events(&self.0.shared); if !conn.inner.is_drained() { @@ -759,6 +762,7 @@ impl ConnectionRef { handle: ConnectionHandle, conn: proto::Connection, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, + endpoint_driver: Arc, conn_events: mpsc::UnboundedReceiver, on_handshake_data: oneshot::Sender<()>, on_connected: oneshot::Sender, @@ -786,7 +790,13 @@ impl ConnectionRef { socket, runtime, }), - shared: Shared::default(), + shared: Shared { + endpoint_driver, + stream_budget_available: Default::default(), + stream_incoming: Default::default(), + datagrams: Default::default(), + closed: Default::default(), + }, })) } @@ -831,7 +841,7 @@ pub(crate) struct ConnectionInner { pub(crate) shared: Shared, } -#[derive(Debug, Default)] +#[derive(Debug)] pub(crate) struct Shared { /// Notified when new streams may be locally initiated due to an increase in stream ID flow /// control budget @@ -840,6 +850,7 @@ pub(crate) struct Shared { stream_incoming: [Notify; 2], datagrams: Notify, closed: Notify, + endpoint_driver: Arc, } pub(crate) struct State { @@ -898,12 +909,16 @@ impl State { false } - fn forward_endpoint_events(&mut self) { - while let Some(event) = self.inner.poll_endpoint_events() { + fn forward_endpoint_events(&mut self, shared: &Shared) { + if !self.inner.poll_endpoint_events() { + return; + } + shared.endpoint_driver.wake(); + if self.inner.is_drained() { // If the endpoint driver is gone, noop. let _ = self .endpoint_events - .send((self.handle, EndpointEvent::Proto(event))); + .send((self.handle, EndpointEvent::Drained)); } } @@ -913,13 +928,14 @@ impl State { shared: &Shared, cx: &mut Context, ) -> Result<(), ConnectionError> { + let now = Instant::now(); loop { match self.conn_events.poll_recv(cx) { Poll::Ready(Some(ConnectionEvent::Ping)) => { self.inner.ping(); } - Poll::Ready(Some(ConnectionEvent::Proto(event))) => { - self.inner.handle_event(event); + Poll::Ready(Some(ConnectionEvent::Proto)) => { + self.inner.handle_events(now); } Poll::Ready(Some(ConnectionEvent::Close { reason, error_code })) => { self.close(error_code, reason, shared); @@ -1117,10 +1133,9 @@ impl Drop for State { fn drop(&mut self) { if !self.inner.is_drained() { // Ensure the endpoint can tidy up - let _ = self.endpoint_events.send(( - self.handle, - EndpointEvent::Proto(proto::EndpointEvent::drained()), - )); + let _ = self + .endpoint_events + .send((self.handle, EndpointEvent::Drained)); } } } diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 3a6260bee8..d85970dbe1 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -8,11 +8,12 @@ use std::{ pin::Pin, str, sync::{Arc, Mutex}, - task::{Context, Poll, Waker}, + task::{Context, Poll}, time::Instant, }; use crate::runtime::{default_runtime, AsyncUdpSocket, Runtime}; +use atomic_waker::AtomicWaker; use bytes::{Bytes, BytesMut}; use pin_project_lite::pin_project; use proto::{ @@ -197,9 +198,13 @@ impl Endpoint { .connect(Instant::now(), config, addr, server_name)?; let socket = endpoint.socket.clone(); - Ok(endpoint - .connections - .insert(ch, conn, socket, self.runtime.clone())) + Ok(endpoint.connections.insert( + ch, + conn, + socket, + self.runtime.clone(), + self.inner.shared.driver.clone(), + )) } /// Switch to a new UDP socket @@ -319,14 +324,12 @@ impl Future for EndpointDriver { #[allow(unused_mut)] // MSRV fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + self.0.shared.driver.register(cx.waker()); let mut endpoint = self.0.state.lock().unwrap(); - if endpoint.driver.is_none() { - endpoint.driver = Some(cx.waker().clone()); - } let now = Instant::now(); let mut keep_going = false; - keep_going |= endpoint.drive_recv(cx, now)?; + keep_going |= endpoint.drive_recv(cx, now, &self.0.shared)?; keep_going |= endpoint.handle_events(cx, &self.0.shared); keep_going |= endpoint.drive_send(cx)?; @@ -372,7 +375,6 @@ pub(crate) struct State { inner: proto::Endpoint, outgoing: VecDeque, incoming: VecDeque, - driver: Option, ipv6: bool, connections: ConnectionSet, events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>, @@ -387,14 +389,20 @@ pub(crate) struct State { transmit_queue_contents_len: usize, } -#[derive(Debug)] +#[derive(Debug, Default)] pub(crate) struct Shared { incoming: Notify, idle: Notify, + driver: Arc, } impl State { - fn drive_recv<'a>(&'a mut self, cx: &mut Context, now: Instant) -> Result { + fn drive_recv<'a>( + &'a mut self, + cx: &mut Context, + now: Instant, + shared: &Shared, + ) -> Result { self.recv_limiter.start_cycle(); let mut metas = [RecvMeta::default(); BATCH_SIZE]; let mut iovs = MaybeUninit::<[IoSliceMut<'a>; BATCH_SIZE]>::uninit(); @@ -432,17 +440,18 @@ impl State { conn, self.socket.clone(), self.runtime.clone(), + shared.driver.clone(), ); self.incoming.push_back(conn); } - Some(DatagramEvent::ConnectionEvent(handle, event)) => { + Some(DatagramEvent::ConnectionEvent(handle)) => { // Ignoring errors from dropped connections that haven't yet been cleaned up let _ = self .connections .senders .get_mut(&handle) .unwrap() - .send(ConnectionEvent::Proto(event)); + .send(ConnectionEvent::Proto); } Some(DatagramEvent::Response(transmit)) => { // Limiting the memory usage for items queued in the outgoing queue from endpoint @@ -527,24 +536,14 @@ impl State { fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool { use EndpointEvent::*; + let mut keep_going = true; for _ in 0..IO_LOOP_BOUND { match self.events.poll_recv(cx) { Poll::Ready(Some((ch, event))) => match event { - Proto(e) => { - if e.is_drained() { - self.connections.senders.remove(&ch); - if self.connections.is_empty() { - shared.idle.notify_waiters(); - } - } - if let Some(event) = self.inner.handle_event(ch, e) { - // Ignoring errors from dropped connections that haven't yet been cleaned up - let _ = self - .connections - .senders - .get_mut(&ch) - .unwrap() - .send(ConnectionEvent::Proto(event)); + Drained => { + self.connections.senders.remove(&ch); + if self.connections.is_empty() { + shared.idle.notify_waiters(); } } Transmit(t, buf) => { @@ -557,12 +556,27 @@ impl State { }, Poll::Ready(None) => unreachable!("EndpointInner owns one sender"), Poll::Pending => { - return false; + keep_going = false; } } } - true + let mut n = 0; + while let Some(ch) = self.inner.handle_events() { + // Ignoring errors from dropped connections that haven't yet been cleaned up + let _ = self + .connections + .senders + .get_mut(&ch) + .unwrap() + .send(ConnectionEvent::Proto); + n += 1; + if n > IO_LOOP_BOUND { + return true; + } + } + + keep_going } } @@ -612,6 +626,7 @@ impl ConnectionSet { conn: proto::Connection, socket: Arc, runtime: Arc, + driver: Arc, ) -> Connecting { let (send, recv) = mpsc::unbounded_channel(); if let Some((error_code, ref reason)) = self.close { @@ -622,7 +637,15 @@ impl ConnectionSet { .unwrap(); } self.senders.insert(handle, send); - Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime) + Connecting::new( + handle, + conn, + self.sender.clone(), + driver, + recv, + socket, + runtime, + ) } fn is_empty(&self) -> bool { @@ -691,10 +714,7 @@ impl EndpointRef { ]; let (sender, events) = mpsc::unbounded_channel(); Self(Arc::new(EndpointInner { - shared: Shared { - incoming: Notify::new(), - idle: Notify::new(), - }, + shared: Shared::default(), state: Mutex::new(State { socket, inner, @@ -702,7 +722,6 @@ impl EndpointRef { events, outgoing: VecDeque::new(), incoming: VecDeque::new(), - driver: None, connections: ConnectionSet { senders: FxHashMap::default(), sender, @@ -735,9 +754,7 @@ impl Drop for EndpointRef { if x == 0 { // If the driver is about to be on its own, ensure it can shut down if the last // connection is gone. - if let Some(task) = endpoint.driver.take() { - task.wake(); - } + self.0.shared.driver.wake(); } } } diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 05221548ae..0d63c70ff2 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -92,13 +92,13 @@ enum ConnectionEvent { error_code: VarInt, reason: bytes::Bytes, }, - Proto(proto::ConnectionEvent), + Proto, Ping, } #[derive(Debug)] enum EndpointEvent { - Proto(proto::EndpointEvent), + Drained, Transmit(proto::Transmit, Bytes), }