diff --git a/fuzz/fuzz_targets/packet.rs b/fuzz/fuzz_targets/packet.rs index a8320a87a6..ce68905742 100644 --- a/fuzz/fuzz_targets/packet.rs +++ b/fuzz/fuzz_targets/packet.rs @@ -5,15 +5,23 @@ extern crate proto; use libfuzzer_sys::fuzz_target; use proto::{ fuzzing::{PacketParams, PartialDecode}, - FixedLengthConnectionIdParser, DEFAULT_SUPPORTED_VERSIONS, + ConnectionIdParser, RandomConnectionIdGenerator, ZeroLengthConnectionIdParser, + DEFAULT_SUPPORTED_VERSIONS, }; fuzz_target!(|data: PacketParams| { let len = data.buf.len(); let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); + let cid_gen; if let Ok(decoded) = PartialDecode::new( data.buf, - &FixedLengthConnectionIdParser::new(data.local_cid_len), + match data.local_cid_len { + 0 => &ZeroLengthConnectionIdParser as &dyn ConnectionIdParser, + _ => { + cid_gen = RandomConnectionIdGenerator::new(data.local_cid_len); + &cid_gen as &dyn ConnectionIdParser + } + }, &supported_versions, data.grease_quic_bit, ) { diff --git a/quinn-proto/src/cid_generator.rs b/quinn-proto/src/cid_generator.rs index 6173e96f91..73fe14d409 100644 --- a/quinn-proto/src/cid_generator.rs +++ b/quinn-proto/src/cid_generator.rs @@ -1,12 +1,13 @@ use std::{hash::Hasher, time::Duration}; +use bytes::Buf; use rand::{Rng, RngCore}; use crate::shared::ConnectionId; -use crate::MAX_CID_SIZE; +use crate::{ConnectionIdParser, PacketDecodeError, MAX_CID_SIZE}; /// Generates connection IDs for incoming connections -pub trait ConnectionIdGenerator: Send + Sync { +pub trait ConnectionIdGenerator: Send + Sync + ConnectionIdParser { /// Generates a new CID /// /// Connection IDs MUST NOT contain any information that can be used by @@ -14,17 +15,17 @@ pub trait ConnectionIdGenerator: Send + Sync { /// issuer) to correlate them with other connection IDs for the same /// connection. They MUST have high entropy, e.g. due to encrypted data /// or cryptographic-grade random data. - fn generate_cid(&mut self) -> ConnectionId; + fn generate_cid(&self) -> ConnectionId; /// Quickly determine whether `cid` could have been generated by this generator /// - /// False positives are permitted, but increase the cost of handling invalid packets. + /// False positives are permitted, but increase the cost of handling invalid packets. The input + /// CID is guaranteed to have been obtained from a successful call to the generator's + /// implementation of [`ConnectionIdParser::parse`]. fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> { Ok(()) } - /// Returns the length of a CID for connections created by this generator - fn cid_len(&self) -> usize; /// Returns the lifetime of generated Connection IDs /// /// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant. @@ -60,6 +61,10 @@ impl RandomConnectionIdGenerator { /// The given length must be less than or equal to MAX_CID_SIZE. pub fn new(cid_len: usize) -> Self { debug_assert!(cid_len <= MAX_CID_SIZE); + assert!( + cid_len > 0, + "connection ID generators must produce non-empty IDs" + ); Self { cid_len, ..Self::default() @@ -73,19 +78,22 @@ impl RandomConnectionIdGenerator { } } +impl ConnectionIdParser for RandomConnectionIdGenerator { + fn parse(&self, buffer: &mut dyn Buf) -> Result { + (buffer.remaining() >= self.cid_len) + .then(|| ConnectionId::from_buf(buffer, self.cid_len)) + .ok_or(PacketDecodeError::InvalidHeader("packet too small")) + } +} + impl ConnectionIdGenerator for RandomConnectionIdGenerator { - fn generate_cid(&mut self) -> ConnectionId { + fn generate_cid(&self) -> ConnectionId { let mut bytes_arr = [0; MAX_CID_SIZE]; rand::thread_rng().fill_bytes(&mut bytes_arr[..self.cid_len]); ConnectionId::new(&bytes_arr[..self.cid_len]) } - /// Provide the length of dst_cid in short header packet - fn cid_len(&self) -> usize { - self.cid_len - } - fn cid_lifetime(&self) -> Option { self.lifetime } @@ -131,9 +139,17 @@ impl Default for HashedConnectionIdGenerator { } } +impl ConnectionIdParser for HashedConnectionIdGenerator { + fn parse(&self, buffer: &mut dyn Buf) -> Result { + (buffer.remaining() >= HASHED_CID_LEN) + .then(|| ConnectionId::from_buf(buffer, HASHED_CID_LEN)) + .ok_or(PacketDecodeError::InvalidHeader("packet too small")) + } +} + impl ConnectionIdGenerator for HashedConnectionIdGenerator { - fn generate_cid(&mut self) -> ConnectionId { - let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN]; + fn generate_cid(&self) -> ConnectionId { + let mut bytes_arr = [0; HASHED_CID_LEN]; rand::thread_rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]); let mut hasher = rustc_hash::FxHasher::default(); hasher.write_u64(self.key); @@ -154,10 +170,6 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator { } } - fn cid_len(&self) -> usize { - NONCE_LEN + SIGNATURE_LEN - } - fn cid_lifetime(&self) -> Option { self.lifetime } @@ -165,6 +177,32 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator { const NONCE_LEN: usize = 3; // Good for more than 16 million connections const SIGNATURE_LEN: usize = 8 - NONCE_LEN; // 8-byte total CID length +const HASHED_CID_LEN: usize = NONCE_LEN + SIGNATURE_LEN; + +/// HACK: Replace uses with `ZeroLengthConnectionIdParser` once [trait upcasting] is stable +/// +/// CID generators should produce nonempty CIDs. We should be able to use +/// `ZeroLengthConnectionIdParser` everywhere this would be needed, but that will require +/// construction of `&dyn ConnectionIdParser` from `&dyn ConnectionIdGenerator`. +/// +/// [trait upcasting]: https://github.com/rust-lang/rust/issues/65991 +pub(crate) struct ZeroLengthConnectionIdGenerator; + +impl ConnectionIdParser for ZeroLengthConnectionIdGenerator { + fn parse(&self, _: &mut dyn Buf) -> Result { + Ok(ConnectionId::new(&[])) + } +} + +impl ConnectionIdGenerator for ZeroLengthConnectionIdGenerator { + fn generate_cid(&self) -> ConnectionId { + unreachable!() + } + + fn cid_lifetime(&self) -> Option { + None + } +} #[cfg(test)] mod tests { @@ -173,7 +211,7 @@ mod tests { #[test] #[cfg(feature = "ring")] fn validate_keyed_cid() { - let mut generator = HashedConnectionIdGenerator::new(); + let generator = HashedConnectionIdGenerator::new(); let cid = generator.generate_cid(); generator.validate(&cid).unwrap(); } diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index 6d53fd5905..650d6e2cf3 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -616,11 +616,7 @@ impl Default for MtuDiscoveryConfig { pub struct EndpointConfig { pub(crate) reset_key: Arc, pub(crate) max_udp_payload_size: VarInt, - /// CID generator factory - /// - /// Create a cid generator for local cid in Endpoint struct - pub(crate) connection_id_generator_factory: - Arc Box + Send + Sync>, + pub(crate) connection_id_generator: Option>, pub(crate) supported_versions: Vec, pub(crate) grease_quic_bit: bool, /// Minimum interval between outgoing stateless reset packets @@ -630,12 +626,10 @@ pub struct EndpointConfig { impl EndpointConfig { /// Create a default config with a particular `reset_key` pub fn new(reset_key: Arc) -> Self { - let cid_factory = - || -> Box { Box::::default() }; Self { reset_key, max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers - connection_id_generator_factory: Arc::new(cid_factory), + connection_id_generator: Some(Arc::::default()), supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(), grease_quic_bit: true, min_reset_interval: Duration::from_millis(20), @@ -650,11 +644,11 @@ impl EndpointConfig { /// information in local connection IDs, e.g. to support stateless packet-level load balancers. /// /// Defaults to [`HashedConnectionIdGenerator`]. - pub fn cid_generator Box + Send + Sync + 'static>( + pub fn cid_generator( &mut self, - factory: F, + generator: Option>, ) -> &mut Self { - self.connection_id_generator_factory = Arc::new(factory); + self.connection_id_generator = generator; self } diff --git a/quinn-proto/src/connection/cid_state.rs b/quinn-proto/src/connection/cid_state.rs index abf577ae77..46267ddd74 100644 --- a/quinn-proto/src/connection/cid_state.rs +++ b/quinn-proto/src/connection/cid_state.rs @@ -21,19 +21,12 @@ pub(super) struct CidState { prev_retire_seq: u64, /// Sequence number to set in retire_prior_to field in NEW_CONNECTION_ID frame retire_seq: u64, - /// cid length used to decode short packet - cid_len: usize, //// cid lifetime cid_lifetime: Option, } impl CidState { - pub(crate) fn new( - cid_len: usize, - cid_lifetime: Option, - now: Instant, - issued: u64, - ) -> Self { + pub(crate) fn new(cid_lifetime: Option, now: Instant, issued: u64) -> Self { let mut active_seq = FxHashSet::default(); // Add sequence number of CIDs used in handshaking into tracking set for seq in 0..issued { @@ -45,7 +38,6 @@ impl CidState { active_seq, prev_retire_seq: 0, retire_seq: 0, - cid_len, cid_lifetime, }; // Track lifetime of CIDs used in handshaking @@ -158,11 +150,6 @@ impl CidState { sequence: u64, limit: u64, ) -> Result { - if self.cid_len == 0 { - return Err(TransportError::PROTOCOL_VIOLATION( - "RETIRE_CONNECTION_ID when CIDs aren't in use", - )); - } if sequence > self.issued { debug!( sequence, @@ -181,11 +168,6 @@ impl CidState { Ok(limit > self.active_seq.len() as u64) } - /// Length of local Connection IDs - pub(crate) fn cid_len(&self) -> usize { - self.cid_len - } - /// The value for `retire_prior_to` field in `NEW_CONNECTION_ID` frame pub(crate) fn retire_prior_to(&self) -> u64 { self.retire_seq diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index a0ed445747..69ebbba588 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -15,16 +15,15 @@ use thiserror::Error; use tracing::{debug, error, trace, trace_span, warn}; use crate::{ - cid_generator::ConnectionIdGenerator, + cid_generator::{ConnectionIdGenerator, ZeroLengthConnectionIdGenerator}, cid_queue::CidQueue, coding::BufMutExt, config::{ServerConfig, TransportConfig}, crypto::{self, KeyPair, Keys, PacketKey}, - frame, - frame::{Close, Datagram, FrameStruct}, + frame::{self, Close, Datagram, FrameStruct}, packet::{ - FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, LongType, Packet, - PacketNumber, PartialDecode, SpaceId, + Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, + SpaceId, }, range_set::ArrayRangeSet, shared::{ @@ -197,6 +196,7 @@ pub struct Connection { retry_token: Bytes, /// Identifies Data-space packet numbers to skip. Not used in earlier spaces. packet_number_filter: PacketNumberFilter, + cid_gen: Option>, // // Queued non-retransmittable 1-RTT data @@ -230,8 +230,8 @@ pub struct Connection { streams: StreamsState, /// Surplus remote CIDs for future use on new paths rem_cids: CidQueue, - // Attributes of CIDs generated by local peer - local_cid_state: CidState, + /// Attributes of CIDs generated by local peer, if in use + local_cid_state: Option, /// State of the unreliable datagram extension datagrams: DatagramState, /// Connection level statistics @@ -252,7 +252,7 @@ impl Connection { remote: SocketAddr, local_ip: Option, crypto: Box, - cid_gen: &dyn ConnectionIdGenerator, + cid_gen: Option>, now: Instant, version: u32, allow_mtud: bool, @@ -280,12 +280,13 @@ impl Connection { crypto, handshake_cid: loc_cid, rem_handshake_cid: rem_cid, - local_cid_state: CidState::new( - cid_gen.cid_len(), - cid_gen.cid_lifetime(), - now, - if pref_addr_cid.is_some() { 2 } else { 1 }, - ), + local_cid_state: cid_gen.as_ref().map(|gen| { + CidState::new( + gen.cid_lifetime(), + now, + if pref_addr_cid.is_some() { 2 } else { 1 }, + ) + }), path: PathData::new(remote, allow_mtud, None, now, path_validated, &config), allow_mtud, local_ip, @@ -329,6 +330,7 @@ impl Connection { }, #[cfg(not(test))] packet_number_filter: PacketNumberFilter::new(&mut rng), + cid_gen, path_responses: PathResponses::default(), close: false, @@ -1086,7 +1088,8 @@ impl Connection { } } NewIdentifiers(ids, now) => { - self.local_cid_state.new_cids(&ids, now); + let cid_state = self.local_cid_state.as_mut().unwrap(); + cid_state.new_cids(&ids, now); ids.into_iter().rev().for_each(|frame| { self.spaces[SpaceId::Data].pending.new_cids.push(frame); }); @@ -1096,7 +1099,9 @@ impl Connection { .get(Timer::PushNewCid) .map_or(true, |x| x <= now) { - self.reset_cid_retirement(); + if let Some(t) = cid_state.next_timeout() { + self.timers.set(Timer::PushNewCid, t); + } } } } @@ -1147,12 +1152,13 @@ impl Connection { } Timer::Pacing => trace!("pacing timer expired"), Timer::PushNewCid => { + let cid_state = self.local_cid_state.as_mut().unwrap(); // Update `retire_prior_to` field in NEW_CONNECTION_ID frame - let num_new_cid = self.local_cid_state.on_cid_timeout().into(); + let num_new_cid = cid_state.on_cid_timeout().into(); if !self.state.is_closed() { trace!( "push a new cid to peer RETIRE_PRIOR_TO field {}", - self.local_cid_state.retire_prior_to() + cid_state.retire_prior_to() ); self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, num_new_cid)); @@ -1858,12 +1864,6 @@ impl Connection { self.timers.set(Timer::KeepAlive, now + interval); } - fn reset_cid_retirement(&mut self) { - if let Some(t) = self.local_cid_state.next_timeout() { - self.timers.set(Timer::PushNewCid, t); - } - } - /// Handle the already-decrypted first packet from the client /// /// Decrypting the first packet in the `Endpoint` allows stateless packet handling to be more @@ -2101,7 +2101,10 @@ impl Connection { while let Some(data) = remaining { match PartialDecode::new( data, - &FixedLengthConnectionIdParser::new(self.local_cid_state.cid_len()), + self.cid_gen.as_ref().map_or( + &ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator, + |x| &**x, + ), &[self.version], self.endpoint_config.grease_quic_bit, ) { @@ -2754,8 +2757,12 @@ impl Connection { self.streams.received_stop_sending(id, error_code); } Frame::RetireConnectionId { sequence } => { - let allow_more_cids = self - .local_cid_state + let cid_state = self.local_cid_state.as_mut().ok_or_else(|| { + TransportError::PROTOCOL_VIOLATION( + "RETIRE_CONNECTION_ID when CIDs aren't in use", + ) + })?; + let allow_more_cids = cid_state .on_cid_retirement(sequence, self.peer_params.issue_cids_limit())?; self.endpoint_events .push_back(EndpointEventInner::RetireConnectionId( @@ -2997,7 +3004,7 @@ impl Connection { /// Issue an initial set of connection IDs to the peer upon connection fn issue_first_cids(&mut self, now: Instant) { - if self.local_cid_state.cid_len() == 0 { + if self.local_cid_state.is_none() { return; } @@ -3167,25 +3174,27 @@ impl Connection { } // NEW_CONNECTION_ID - while buf.len() + 44 < max_size { - let issued = match space.pending.new_cids.pop() { - Some(x) => x, - None => break, - }; - trace!( - sequence = issued.sequence, - id = %issued.id, - "NEW_CONNECTION_ID" - ); - frame::NewConnectionId { - sequence: issued.sequence, - retire_prior_to: self.local_cid_state.retire_prior_to(), - id: issued.id, - reset_token: issued.reset_token, + if let Some(cid_state) = self.local_cid_state.as_ref() { + while buf.len() + 44 < max_size { + let issued = match space.pending.new_cids.pop() { + Some(x) => x, + None => break, + }; + trace!( + sequence = issued.sequence, + id = %issued.id, + "NEW_CONNECTION_ID" + ); + frame::NewConnectionId { + sequence: issued.sequence, + retire_prior_to: cid_state.retire_prior_to(), + id: issued.id, + reset_token: issued.reset_token, + } + .encode(buf); + sent.retransmits.get_or_create().new_cids.push(issued); + self.stats.frame_tx.new_connection_id += 1; } - .encode(buf); - sent.retransmits.get_or_create().new_cids.push(issued); - self.stats.frame_tx.new_connection_id += 1; } // RETIRE_CONNECTION_ID @@ -3479,14 +3488,16 @@ impl Connection { #[cfg(test)] pub(crate) fn active_local_cid_seq(&self) -> (u64, u64) { - self.local_cid_state.active_seq() + self.local_cid_state + .as_ref() + .map_or((u64::MAX, u64::MIN), |state| state.active_seq()) } /// Instruct the peer to replace previously issued CIDs by sending a NEW_CONNECTION_ID frame /// with updated `retire_prior_to` field set to `v` #[cfg(test)] pub(crate) fn rotate_local_cid(&mut self, v: u64, now: Instant) { - let n = self.local_cid_state.assign_retire_seq(v); + let n = self.local_cid_state.as_mut().unwrap().assign_retire_seq(v); self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, n)); } diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index e4df5787db..d1e483dcf0 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -16,15 +16,17 @@ use thiserror::Error; use tracing::{debug, error, trace, warn}; use crate::{ - cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, + cid_generator::{ + ConnectionIdGenerator, RandomConnectionIdGenerator, ZeroLengthConnectionIdGenerator, + }, coding::BufMutExt, config::{ClientConfig, EndpointConfig, ServerConfig}, connection::{Connection, ConnectionError}, crypto::{self, Keys, UnsupportedVersion}, frame, packet::{ - FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, Packet, - PacketDecodeError, PacketNumber, PartialDecode, ProtectedInitialHeader, + Header, InitialHeader, InitialPacket, Packet, PacketDecodeError, PacketNumber, + PartialDecode, ProtectedInitialHeader, }, shared::{ ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint, @@ -44,7 +46,7 @@ pub struct Endpoint { rng: StdRng, index: ConnectionIndex, connections: Slab, - local_cid_generator: Box, + local_cid_generator: Option>, config: Arc, server_config: Option>, /// Whether the underlying UDP socket promises not to fragment packets @@ -72,7 +74,7 @@ impl Endpoint { rng: rng_seed.map_or(StdRng::from_entropy(), StdRng::from_seed), index: ConnectionIndex::default(), connections: Slab::new(), - local_cid_generator: (config.connection_id_generator_factory.as_ref())(), + local_cid_generator: config.connection_id_generator.clone(), config, server_config, allow_mtud, @@ -144,7 +146,10 @@ impl Endpoint { let datagram_len = data.len(); let (first_decode, remaining) = match PartialDecode::new( data, - &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()), + self.local_cid_generator.as_ref().map_or( + &ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator, + |x| &**x, + ), &self.config.supported_versions, self.config.grease_quic_bit, ) { @@ -302,8 +307,8 @@ impl Endpoint { if !first_decode.is_initial() && self .local_cid_generator - .validate(first_decode.dst_cid()) - .is_err() + .as_ref() + .map_or(false, |gen| gen.validate(first_decode.dst_cid()).is_err()) { debug!("dropping packet with invalid CID"); return None; @@ -385,9 +390,6 @@ impl Endpoint { remote: SocketAddr, server_name: &str, ) -> Result<(ConnectionHandle, Connection), ConnectError> { - if self.cids_exhausted() { - return Err(ConnectError::CidsExhausted); - } if remote.port() == 0 || remote.ip().is_unspecified() { return Err(ConnectError::InvalidRemoteAddress(remote)); } @@ -403,7 +405,7 @@ impl Endpoint { let params = TransportParameters::new( &config.transport, &self.config, - self.local_cid_generator.as_ref(), + self.local_cid_generator.is_some(), loc_cid, None, ); @@ -456,12 +458,11 @@ impl Endpoint { /// Generate a connection ID for `ch` fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId { loop { - let cid = self.local_cid_generator.generate_cid(); - if cid.len() == 0 { + let Some(cid_generator) = self.local_cid_generator.as_ref() else { // Zero-length CID; nothing to track - debug_assert_eq!(self.local_cid_generator.cid_len(), 0); - return cid; - } + return ConnectionId::EMPTY; + }; + let cid = cid_generator.generate_cid(); if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) { e.insert(ch); break cid; @@ -565,22 +566,6 @@ impl Endpoint { .. } = incoming.packet.header; - if self.cids_exhausted() { - debug!("refusing connection"); - self.index.remove_initial(incoming.orig_dst_cid); - return Err(AcceptError { - cause: ConnectionError::CidsExhausted, - response: Some(self.initial_close( - version, - incoming.addresses, - &incoming.crypto, - &src_cid, - TransportError::CONNECTION_REFUSED(""), - buf, - )), - }); - } - let server_config = server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone()); @@ -608,7 +593,7 @@ impl Endpoint { let mut params = TransportParameters::new( &server_config.transport, &self.config, - self.local_cid_generator.as_ref(), + self.local_cid_generator.is_some(), loc_cid, Some(&server_config), ); @@ -691,7 +676,7 @@ impl Endpoint { header: &ProtectedInitialHeader, ) -> Result<(), TransportError> { let config = &self.server_config.as_ref().unwrap(); - if self.cids_exhausted() || self.incoming_buffers.len() >= config.max_incoming { + if self.incoming_buffers.len() >= config.max_incoming { return Err(TransportError::CONNECTION_REFUSED("")); } @@ -699,10 +684,7 @@ impl Endpoint { // bytes. If this is a Retry packet, then the length must instead match our usual CID // length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll // also need to validate CID length for those after decoding the token. - if header.dst_cid.len() < 8 - && (!header.token_pos.is_empty() - && header.dst_cid.len() != self.local_cid_generator.cid_len()) - { + if header.dst_cid.len() < 8 && !header.token_pos.is_empty() { debug!( "rejecting connection due to invalid DCID length {}", header.dst_cid.len() @@ -749,7 +731,10 @@ impl Endpoint { // with established connections. In the unlikely event that a collision occurs // between two connections in the initial phase, both will fail fast and may be // retried by the application layer. - let loc_cid = self.local_cid_generator.generate_cid(); + let loc_cid = self + .local_cid_generator + .as_ref() + .map_or(ConnectionId::EMPTY, |gen| gen.generate_cid()); let token = RetryToken { orig_dst_cid: incoming.packet.header.dst_cid, @@ -833,7 +818,7 @@ impl Endpoint { addresses.remote, addresses.local_ip, tls, - self.local_cid_generator.as_ref(), + self.local_cid_generator.clone(), now, version, self.allow_mtud, @@ -879,7 +864,10 @@ impl Endpoint { // We don't need to worry about CID collisions in initial closes because the peer // shouldn't respond, and if it does, and the CID collides, we'll just drop the // unexpected response. - let local_id = self.local_cid_generator.generate_cid(); + let local_id = self + .local_cid_generator + .as_ref() + .map_or(ConnectionId::EMPTY, |gen| gen.generate_cid()); let number = PacketNumber::U8(0); let header = Header::Initial(InitialHeader { dst_cid: *remote_id, @@ -930,18 +918,6 @@ impl Endpoint { pub(crate) fn known_cids(&self) -> usize { self.index.connection_ids.len() } - - /// Whether we've used up 3/4 of the available CID space - /// - /// We leave some space unused so that `new_cid` can be relied upon to finish quickly. We don't - /// bother to check when CID longer than 4 bytes are used because 2^40 connections is a lot. - fn cids_exhausted(&self) -> bool { - self.local_cid_generator.cid_len() <= 4 - && self.local_cid_generator.cid_len() != 0 - && (2usize.pow(self.local_cid_generator.cid_len() as u32 * 8) - - self.index.connection_ids.len()) - < 2usize.pow(self.local_cid_generator.cid_len() as u32 * 8 - 2) - } } impl fmt::Debug for Endpoint { @@ -1229,11 +1205,6 @@ pub enum ConnectError { /// Indicates that a necessary component of the endpoint has been dropped or otherwise disabled. #[error("endpoint stopping")] EndpointStopping, - /// The connection could not be created because not enough of the CID space is available - /// - /// Try using longer connection IDs - #[error("CIDs exhausted")] - CidsExhausted, /// The given server name was malformed #[error("invalid server name: {0}")] InvalidServerName(String), diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index bb39101b78..ae5a4d2c7f 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -66,8 +66,8 @@ pub use crate::endpoint::{ mod packet; pub use packet::{ - ConnectionIdParser, FixedLengthConnectionIdParser, LongType, PacketDecodeError, PartialDecode, - ProtectedHeader, ProtectedInitialHeader, + ConnectionIdParser, LongType, PacketDecodeError, PartialDecode, ProtectedHeader, + ProtectedInitialHeader, ZeroLengthConnectionIdParser, }; mod shared; diff --git a/quinn-proto/src/packet.rs b/quinn-proto/src/packet.rs index be132a74f6..328988e0e3 100644 --- a/quinn-proto/src/packet.rs +++ b/quinn-proto/src/packet.rs @@ -767,32 +767,22 @@ impl PacketNumber { } } -/// A [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length -pub struct FixedLengthConnectionIdParser { - expected_len: usize, -} - -impl FixedLengthConnectionIdParser { - /// Create a new instance of `FixedLengthConnectionIdParser` - pub fn new(expected_len: usize) -> Self { - Self { expected_len } - } -} - -impl ConnectionIdParser for FixedLengthConnectionIdParser { - fn parse(&self, buffer: &mut dyn Buf) -> Result { - (buffer.remaining() >= self.expected_len) - .then(|| ConnectionId::from_buf(buffer, self.expected_len)) - .ok_or(PacketDecodeError::InvalidHeader("packet too small")) - } -} - /// Parse connection id in short header packet pub trait ConnectionIdParser { /// Parse a connection id from given buffer fn parse(&self, buf: &mut dyn Buf) -> Result; } +/// Trivial parser for zero-length connection IDs +pub struct ZeroLengthConnectionIdParser; + +impl ConnectionIdParser for ZeroLengthConnectionIdParser { + #[inline] + fn parse(&self, _: &mut dyn Buf) -> Result { + Ok(ConnectionId::new(&[])) + } +} + /// Long packet type including non-uniform cases #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum LongHeaderType { @@ -970,7 +960,7 @@ mod tests { let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); let decode = PartialDecode::new( buf.as_slice().into(), - &FixedLengthConnectionIdParser::new(0), + &ZeroLengthConnectionIdParser, &supported_versions, false, ) diff --git a/quinn-proto/src/shared.rs b/quinn-proto/src/shared.rs index 05ffe8caf0..f6dac36522 100644 --- a/quinn-proto/src/shared.rs +++ b/quinn-proto/src/shared.rs @@ -72,6 +72,12 @@ pub struct ConnectionId { } impl ConnectionId { + /// The zero-length connection ID + pub const EMPTY: Self = Self { + len: 0, + bytes: [0; MAX_CID_SIZE], + }; + /// Construct cid from byte array pub fn new(bytes: &[u8]) -> Self { debug_assert!(bytes.len() <= MAX_CID_SIZE); diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index 272bf38544..8b4cc01805 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -20,10 +20,8 @@ use tracing::info; use super::*; use crate::{ - cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, - crypto::rustls::QuicServerConfig, - frame::FrameStruct, - transport_parameters::TransportParameters, + cid_generator::RandomConnectionIdGenerator, crypto::rustls::QuicServerConfig, + frame::FrameStruct, transport_parameters::TransportParameters, }; mod util; use util::*; @@ -66,11 +64,9 @@ fn version_negotiate_client() { let server_addr = "[::2]:7890".parse().unwrap(); // Configure client to use empty CIDs so we can easily hardcode a server version negotiation // packet - let cid_generator_factory: fn() -> Box = - || Box::new(RandomConnectionIdGenerator::new(0)); let mut client = Endpoint::new( Arc::new(EndpointConfig { - connection_id_generator_factory: Arc::new(cid_generator_factory), + connection_id_generator: None, ..Default::default() }), None, @@ -185,7 +181,7 @@ fn server_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(move || Box::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Some(Arc::new(HashedConnectionIdGenerator::from_key(0)))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -215,7 +211,7 @@ fn client_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(move || Box::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Some(Arc::new(HashedConnectionIdGenerator::from_key(0)))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -244,7 +240,7 @@ fn stateless_reset_limit() { let _guard = subscribe(); let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 42); let mut endpoint_config = EndpointConfig::default(); - endpoint_config.cid_generator(move || Box::new(RandomConnectionIdGenerator::new(8))); + endpoint_config.cid_generator(Some(Arc::new(RandomConnectionIdGenerator::new(8)))); let endpoint_config = Arc::new(endpoint_config); let mut endpoint = Endpoint::new( endpoint_config.clone(), @@ -1470,11 +1466,9 @@ fn implicit_open() { #[test] fn zero_length_cid() { let _guard = subscribe(); - let cid_generator_factory: fn() -> Box = - || Box::new(RandomConnectionIdGenerator::new(0)); let mut pair = Pair::new( Arc::new(EndpointConfig { - connection_id_generator_factory: Arc::new(cid_generator_factory), + connection_id_generator: None, ..EndpointConfig::default() }), server_config(), @@ -1528,13 +1522,12 @@ fn cid_rotation() { let _guard = subscribe(); const CID_TIMEOUT: Duration = Duration::from_secs(2); - let cid_generator_factory: fn() -> Box = - || Box::new(*RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT)); - // Only test cid rotation on server side to have a clear output trace let server = Endpoint::new( Arc::new(EndpointConfig { - connection_id_generator_factory: Arc::new(cid_generator_factory), + connection_id_generator: Some(Arc::new( + *RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT), + )), ..EndpointConfig::default() }), Some(Arc::new(server_config())), diff --git a/quinn-proto/src/transport_parameters.rs b/quinn-proto/src/transport_parameters.rs index 3a1a443a2b..2dc0a916db 100644 --- a/quinn-proto/src/transport_parameters.rs +++ b/quinn-proto/src/transport_parameters.rs @@ -15,7 +15,6 @@ use bytes::{Buf, BufMut}; use thiserror::Error; use crate::{ - cid_generator::ConnectionIdGenerator, cid_queue::CidQueue, coding::{BufExt, BufMutExt, UnexpectedEnd}, config::{EndpointConfig, ServerConfig, TransportConfig}, @@ -132,7 +131,7 @@ impl TransportParameters { pub(crate) fn new( config: &TransportConfig, endpoint_config: &EndpointConfig, - cid_gen: &dyn ConnectionIdGenerator, + use_cids: bool, initial_src_cid: ConnectionId, server_config: Option<&ServerConfig>, ) -> Self { @@ -147,7 +146,7 @@ impl TransportParameters { max_udp_payload_size: endpoint_config.max_udp_payload_size, max_idle_timeout: config.max_idle_timeout.unwrap_or(VarInt(0)), disable_active_migration: server_config.map_or(false, |c| !c.migration), - active_connection_id_limit: if cid_gen.cid_len() == 0 { + active_connection_id_limit: if !use_cids { 2 // i.e. default, i.e. unsent } else { CidQueue::LEN as u32 diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index c9c7f768fa..13b84fea54 100755 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -10,7 +10,7 @@ use std::{ use crate::runtime::TokioRuntime; use bytes::Bytes; -use proto::{crypto::rustls::QuicClientConfig, RandomConnectionIdGenerator}; +use proto::crypto::rustls::QuicClientConfig; use rand::{rngs::StdRng, RngCore, SeedableRng}; use rustls::{ pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, @@ -817,9 +817,7 @@ async fn two_datagram_readers() { async fn multiple_conns_with_zero_length_cids() { let _guard = subscribe(); let mut factory = EndpointFactory::new(); - factory - .endpoint_config - .cid_generator(|| Box::new(RandomConnectionIdGenerator::new(0))); + factory.endpoint_config.cid_generator(None); let server = { let _guard = error_span!("server").entered(); factory.endpoint()