diff --git a/examples/multicast.rs b/examples/multicast.rs index ea89a2e93..a3bf58ac6 100644 --- a/examples/multicast.rs +++ b/examples/multicast.rs @@ -111,7 +111,7 @@ fn main() { } let socket = sockets.get_mut::(udp_handle); - if !socket.is_open() { + if !socket.is_bound() { socket.bind(MDNS_PORT).unwrap() } diff --git a/examples/server.rs b/examples/server.rs index 33d95c5d5..6f3ddded9 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -96,7 +96,7 @@ fn main() { // udp:6969: respond "hello" let socket = sockets.get_mut::(udp_handle); - if !socket.is_open() { + if !socket.is_bound() { socket.bind(6969).unwrap() } diff --git a/examples/sixlowpan.rs b/examples/sixlowpan.rs index 9d474e3fe..f903286c6 100644 --- a/examples/sixlowpan.rs +++ b/examples/sixlowpan.rs @@ -113,7 +113,7 @@ fn main() { // udp:6969: respond "hello" let socket = sockets.get_mut::(udp_handle); - if !socket.is_open() { + if !socket.is_bound() { socket.bind(6969).unwrap() } diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index a5ced968d..ea241235e 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -745,16 +745,12 @@ impl<'a> Socket<'a> { /// Start listening on the given endpoint. /// /// This function returns `Err(Error::Illegal)` if the socket was already open - /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)` - /// if the port in the given endpoint is zero. + /// (see [is_open](#method.is_open)). pub fn listen(&mut self, local_endpoint: T) -> Result<(), ListenError> where T: Into, { let local_endpoint = local_endpoint.into(); - if local_endpoint.port == 0 { - return Err(ListenError::Unaddressable); - } if self.is_open() { return Err(ListenError::InvalidState); @@ -1349,7 +1345,9 @@ impl<'a> Socket<'a> { Some(addr) => ip_repr.dst_addr() == addr, None => true, }; - addr_ok && repr.dst_port != 0 && repr.dst_port == self.listen_endpoint.port + addr_ok + && repr.dst_port != 0 + && (self.listen_endpoint.port == 0 || repr.dst_port == self.listen_endpoint.port) } } @@ -1868,7 +1866,10 @@ impl<'a> Socket<'a> { let assembler_was_empty = self.assembler.is_empty(); // Try adding payload octets to the assembler. - let Ok(contig_len) = self.assembler.add_then_remove_front(payload_offset, payload_len) else { + let Ok(contig_len) = self + .assembler + .add_then_remove_front(payload_offset, payload_len) + else { net_debug!( "assembler: too many holes to add {} octets at offset {}", payload_len, @@ -2895,9 +2896,9 @@ mod test { } #[test] - fn test_listen_validation() { + fn test_listen_any_port() { let mut s = socket(); - assert_eq!(s.listen(0), Err(ListenError::Unaddressable)); + assert_eq!(s.listen(0), Ok(())); } #[test] diff --git a/src/socket/udp.rs b/src/socket/udp.rs index 39172dc8e..c526972bd 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -27,6 +27,34 @@ impl> From for UdpMetadata { } } +/// Extended metadata for a sent or received UDP packet. +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ExtendedUdpMetadata { + pub local_endpoint: IpListenEndpoint, + pub remote_endpoint: IpEndpoint, + pub meta: PacketMeta, +} + +impl ExtendedUdpMetadata { + fn new(local_endpoint: IpListenEndpoint, meta: UdpMetadata) -> Self { + Self { + local_endpoint, + remote_endpoint: meta.endpoint, + meta: meta.meta, + } + } +} + +impl From for UdpMetadata { + fn from(value: ExtendedUdpMetadata) -> Self { + Self { + endpoint: value.remote_endpoint, + meta: value.meta, + } + } +} + impl core::fmt::Display for UdpMetadata { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { #[cfg(feature = "packetmeta-id")] @@ -37,11 +65,25 @@ impl core::fmt::Display for UdpMetadata { } } +impl core::fmt::Display for ExtendedUdpMetadata { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + #[cfg(feature = "packetmeta-id")] + return write!( + f, + "{}/{}, PacketID: {:?}", + self.local_endpoint, self.remote_endpoint, self.meta + ); + + #[cfg(not(feature = "packetmeta-id"))] + write!(f, "{}/{}", self.local_endpoint, self.remote_endpoint) + } +} + /// A UDP packet metadata. -pub type PacketMetadata = crate::storage::PacketMetadata; +pub type PacketMetadata = crate::storage::PacketMetadata; /// A UDP packet ring buffer. -pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, UdpMetadata>; +pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ExtendedUdpMetadata>; /// Error returned by [`Socket::bind`] #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -107,7 +149,7 @@ impl std::error::Error for RecvError {} /// packet buffers. #[derive(Debug)] pub struct Socket<'a> { - endpoint: IpListenEndpoint, + bound_endpoint: IpListenEndpoint, rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>, /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. @@ -122,7 +164,7 @@ impl<'a> Socket<'a> { /// Create an UDP socket with the given buffers. pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> { Socket { - endpoint: IpListenEndpoint::default(), + bound_endpoint: IpListenEndpoint::default(), rx_buffer, tx_buffer, hop_limit: None, @@ -170,8 +212,8 @@ impl<'a> Socket<'a> { /// Return the bound endpoint. #[inline] - pub fn endpoint(&self) -> IpListenEndpoint { - self.endpoint + pub fn bound_endpoint(&self) -> IpListenEndpoint { + self.bound_endpoint } /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. @@ -203,8 +245,8 @@ impl<'a> Socket<'a> { /// Bind the socket to the given endpoint. /// - /// This function returns `Err(Error::Illegal)` if the socket was open - /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)` + /// This function returns `Err(Error::Illegal)` if the socket was bound + /// (see [is_bound](#method.is_bound)), and `Err(Error::Unaddressable)` /// if the port in the given endpoint is zero. pub fn bind>(&mut self, endpoint: T) -> Result<(), BindError> { let endpoint = endpoint.into(); @@ -212,11 +254,11 @@ impl<'a> Socket<'a> { return Err(BindError::Unaddressable); } - if self.is_open() { + if self.is_bound() { return Err(BindError::InvalidState); } - self.endpoint = endpoint; + self.bound_endpoint = endpoint; #[cfg(feature = "async")] { @@ -230,7 +272,7 @@ impl<'a> Socket<'a> { /// Close the socket. pub fn close(&mut self) { // Clear the bound endpoint of the socket. - self.endpoint = IpListenEndpoint::default(); + self.bound_endpoint = IpListenEndpoint::default(); // Reset the RX and TX buffers of the socket. self.tx_buffer.reset(); @@ -245,8 +287,8 @@ impl<'a> Socket<'a> { /// Check whether the socket is open. #[inline] - pub fn is_open(&self) -> bool { - self.endpoint.port != 0 + pub fn is_bound(&self) -> bool { + self.bound_endpoint.port != 0 } /// Check whether the transmit buffer is full. @@ -292,19 +334,19 @@ impl<'a> Socket<'a> { /// `Err(Error::Unaddressable)` if local or remote port, or remote address are unspecified, /// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity /// to ever send this packet. - pub fn send( + pub fn send_from( &mut self, size: usize, - meta: impl Into, + meta: impl Into, ) -> Result<&mut [u8], SendError> { let meta = meta.into(); - if self.endpoint.port == 0 { + if meta.local_endpoint.port == 0 { return Err(SendError::Unaddressable); } - if meta.endpoint.addr.is_unspecified() { + if meta.remote_endpoint.addr.is_unspecified() { return Err(SendError::Unaddressable); } - if meta.endpoint.port == 0 { + if meta.remote_endpoint.port == 0 { return Err(SendError::Unaddressable); } @@ -315,35 +357,53 @@ impl<'a> Socket<'a> { net_trace!( "udp:{}:{}: buffer to send {} octets", - self.endpoint, - meta.endpoint, + meta.local_endpoint, + meta.remote_endpoint, size ); Ok(payload_buf) } + /// Enqueue a packet to be sent to a given remote endpoint, and return a pointer + /// to its payload. + /// + /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full, + /// `Err(Error::Unaddressable)` if local or remote port, or remote address are unspecified, + /// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity + /// to ever send this packet. + pub fn send( + &mut self, + size: usize, + meta: impl Into, + ) -> Result<&mut [u8], SendError> { + self.send_from( + size, + ExtendedUdpMetadata::new(self.bound_endpoint(), meta.into()), + ) + } + /// Enqueue a packet to be send to a given remote endpoint and pass the buffer /// to the provided closure. The closure then returns the size of the data written /// into the buffer. /// /// Also see [send](#method.send). - pub fn send_with( + pub fn send_from_with( &mut self, max_size: usize, - meta: impl Into, + meta: impl Into, f: F, ) -> Result where F: FnOnce(&mut [u8]) -> usize, { let meta = meta.into(); - if self.endpoint.port == 0 { + if meta.local_endpoint.port == 0 { return Err(SendError::Unaddressable); } - if meta.endpoint.addr.is_unspecified() { + if meta.remote_endpoint.addr.is_unspecified() { return Err(SendError::Unaddressable); } - if meta.endpoint.port == 0 { + if meta.remote_endpoint.port == 0 { return Err(SendError::Unaddressable); } @@ -354,13 +414,46 @@ impl<'a> Socket<'a> { net_trace!( "udp:{}:{}: buffer to send {} octets", - self.endpoint, - meta.endpoint, + meta.local_endpoint, + meta.remote_endpoint, size ); Ok(size) } + /// Enqueue a packet to be send to a given remote endpoint and pass the buffer + /// to the provided closure. The closure then returns the size of the data written + /// into the buffer. + /// + /// Also see [send](#method.send). + pub fn send_with( + &mut self, + max_size: usize, + meta: impl Into, + f: F, + ) -> Result + where + F: FnOnce(&mut [u8]) -> usize, + { + self.send_from_with( + max_size, + ExtendedUdpMetadata::new(self.bound_endpoint(), meta.into()), + f, + ) + } + + /// Enqueue a packet to be sent to a given remote endpoint, and fill it from a slice. + /// + /// See also [send](#method.send). + pub fn send_slice_from( + &mut self, + data: &[u8], + meta: impl Into, + ) -> Result<(), SendError> { + self.send_from(data.len(), meta)?.copy_from_slice(data); + Ok(()) + } + /// Enqueue a packet to be sent to a given remote endpoint, and fill it from a slice. /// /// See also [send](#method.send). @@ -377,48 +470,76 @@ impl<'a> Socket<'a> { /// as a pointer to the payload. /// /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty. - pub fn recv(&mut self) -> Result<(&[u8], UdpMetadata), RecvError> { - let (remote_endpoint, payload_buf) = - self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; + pub fn recv_to(&mut self) -> Result<(&[u8], ExtendedUdpMetadata), RecvError> { + let (meta, payload_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; net_trace!( "udp:{}:{}: receive {} buffered octets", - self.endpoint, - remote_endpoint.endpoint, + meta.local_endpoint, + meta.remote_endpoint, payload_buf.len() ); - Ok((payload_buf, remote_endpoint)) + Ok((payload_buf, meta)) + } + + /// Dequeue a packet received from a remote endpoint, and return the endpoint as well + /// as a pointer to the payload. + /// + /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty. + pub fn recv(&mut self) -> Result<(&[u8], UdpMetadata), RecvError> { + self.recv_to().map(|(buf, meta)| (buf, meta.into())) } /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, /// and return the amount of octets copied as well as the endpoint. /// /// See also [recv](#method.recv). - pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> { - let (buffer, endpoint) = self.recv().map_err(|_| RecvError::Exhausted)?; + pub fn recv_slice_to( + &mut self, + data: &mut [u8], + ) -> Result<(usize, ExtendedUdpMetadata), RecvError> { + let (buffer, endpoint) = self.recv_to().map_err(|_| RecvError::Exhausted)?; let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok((length, endpoint)) } + /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, + /// and return the amount of octets copied as well as the endpoint. + /// + /// See also [recv](#method.recv). + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> { + self.recv_slice_to(data) + .map(|(length, meta)| (length, meta.into())) + } + /// Peek at a packet received from a remote endpoint, and return the endpoint as well /// as a pointer to the payload without removing the packet from the receive buffer. /// This function otherwise behaves identically to [recv](#method.recv). /// /// It returns `Err(Error::Exhausted)` if the receive buffer is empty. - pub fn peek(&mut self) -> Result<(&[u8], &UdpMetadata), RecvError> { - let endpoint = self.endpoint; - self.rx_buffer.peek().map_err(|_| RecvError::Exhausted).map( - |(remote_endpoint, payload_buf)| { + pub fn peek_to(&mut self) -> Result<(&[u8], ExtendedUdpMetadata), RecvError> { + self.rx_buffer + .peek() + .map_err(|_| RecvError::Exhausted) + .map(|(meta, payload_buf)| { net_trace!( "udp:{}:{}: peek {} buffered octets", - endpoint, - remote_endpoint.endpoint, + meta.local_endpoint, + meta.remote_endpoint, payload_buf.len() ); - (payload_buf, remote_endpoint) - }, - ) + (payload_buf, *meta) + }) + } + + /// Peek at a packet received from a remote endpoint, and return the endpoint as well + /// as a pointer to the payload without removing the packet from the receive buffer. + /// This function otherwise behaves identically to [recv](#method.recv). + /// + /// It returns `Err(Error::Exhausted)` if the receive buffer is empty. + pub fn peek(&mut self) -> Result<(&[u8], UdpMetadata), RecvError> { + self.peek_to().map(|(buf, meta)| (buf, meta.into())) } /// Peek at a packet received from a remote endpoint, copy the payload into the given slice, @@ -427,19 +548,36 @@ impl<'a> Socket<'a> { /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). /// /// See also [peek](#method.peek). - pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &UdpMetadata), RecvError> { - let (buffer, endpoint) = self.peek()?; + pub fn peek_slice_to( + &mut self, + data: &mut [u8], + ) -> Result<(usize, ExtendedUdpMetadata), RecvError> { + let (buffer, endpoint) = self.peek_to()?; let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok((length, endpoint)) } + /// Peek at a packet received from a remote endpoint, copy the payload into the given slice, + /// and return the amount of octets copied as well as the endpoint without removing the + /// packet from the receive buffer. + /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). + /// + /// See also [peek](#method.peek). + pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> { + self.peek_slice_to(data) + .map(|(length, meta)| (length, meta.into())) + } + pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, repr: &UdpRepr) -> bool { - if self.endpoint.port != repr.dst_port { + if !self.is_bound() { + return true; + } + if self.bound_endpoint.port != repr.dst_port { return false; } - if self.endpoint.addr.is_some() - && self.endpoint.addr != Some(ip_repr.dst_addr()) + if self.bound_endpoint.addr.is_some() + && self.bound_endpoint.addr != Some(ip_repr.dst_addr()) && !cx.is_broadcast(&ip_repr.dst_addr()) && !ip_repr.dst_addr().is_multicast() { @@ -461,6 +599,10 @@ impl<'a> Socket<'a> { let size = payload.len(); + let local_endpoint = IpEndpoint { + addr: ip_repr.dst_addr(), + port: repr.dst_port, + }; let remote_endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port, @@ -468,13 +610,14 @@ impl<'a> Socket<'a> { net_trace!( "udp:{}:{}: receiving {} octets", - self.endpoint, + local_endpoint, remote_endpoint, size ); - let metadata = UdpMetadata { - endpoint: remote_endpoint, + let metadata = ExtendedUdpMetadata { + local_endpoint: local_endpoint.into(), + remote_endpoint, meta, }; @@ -482,7 +625,7 @@ impl<'a> Socket<'a> { Ok(buf) => buf.copy_from_slice(payload), Err(_) => net_trace!( "udp:{}:{}: buffer full, dropped incoming packet", - self.endpoint, + local_endpoint, remote_endpoint ), } @@ -495,19 +638,18 @@ impl<'a> Socket<'a> { where F: FnOnce(&mut Context, PacketMeta, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>, { - let endpoint = self.endpoint; let hop_limit = self.hop_limit.unwrap_or(64); let res = self.tx_buffer.dequeue_with(|packet_meta, payload_buf| { - let src_addr = match endpoint.addr { + let src_addr = match packet_meta.local_endpoint.addr { Some(addr) => addr, - None => match cx.get_source_address(packet_meta.endpoint.addr) { + None => match cx.get_source_address(packet_meta.remote_endpoint.addr) { Some(addr) => addr, None => { net_trace!( "udp:{}:{}: cannot find suitable source address, dropping.", - endpoint, - packet_meta.endpoint + packet_meta.local_endpoint, + packet_meta.remote_endpoint ); return Ok(()); } @@ -516,18 +658,18 @@ impl<'a> Socket<'a> { net_trace!( "udp:{}:{}: sending {} octets", - endpoint, - packet_meta.endpoint, + packet_meta.local_endpoint, + packet_meta.remote_endpoint, payload_buf.len() ); let repr = UdpRepr { - src_port: endpoint.port, - dst_port: packet_meta.endpoint.port, + src_port: packet_meta.local_endpoint.port, + dst_port: packet_meta.remote_endpoint.port, }; let ip_repr = IpRepr::new( src_addr, - packet_meta.endpoint.addr, + packet_meta.remote_endpoint.addr, IpProtocol::Udp, repr.header_len() + payload_buf.len(), hop_limit, @@ -794,7 +936,7 @@ mod test { &REMOTE_UDP_REPR, PAYLOAD, ); - assert_eq!(socket.peek(), Ok((&b"abcdef"[..], &REMOTE_END.into(),))); + assert_eq!(socket.peek(), Ok((&b"abcdef"[..], REMOTE_END.into(),))); assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END.into(),))); assert_eq!(socket.peek(), Err(RecvError::Exhausted)); } @@ -841,7 +983,7 @@ mod test { let mut slice = [0; 4]; assert_eq!( socket.peek_slice(&mut slice[..]), - Ok((4, &REMOTE_END.into())) + Ok((4, REMOTE_END.into())) ); assert_eq!(&slice, b"abcd"); assert_eq!( @@ -943,8 +1085,8 @@ mod test { let mut socket = socket(recv_buffer, buffer(0)); assert_eq!(socket.bind(LOCAL_PORT), Ok(())); - assert!(socket.is_open()); + assert!(socket.is_bound()); socket.close(); - assert!(!socket.is_open()); + assert!(!socket.is_bound()); } }