diff --git a/edge-http/README.md b/edge-http/README.md index 721df17..e3a4103 100644 --- a/edge-http/README.md +++ b/edge-http/README.md @@ -104,6 +104,7 @@ async fn request<'b, const N: usize, T: TcpConnect>( ```rust use edge_http::io::server::{Connection, DefaultServer, Handler}; +use edge_http::io::Error; use edge_http::Method; use edge_nal::TcpBind; @@ -130,7 +131,7 @@ pub async fn run(server: &mut DefaultServer) -> Result<(), anyhow::Error> { .bind(addr.parse().unwrap()) .await?; - server.run(acceptor, HttpHandler, None).await?; + server.run(acceptor, HttpHandler, None, None).await?; Ok(()) } @@ -140,9 +141,8 @@ struct HttpHandler; impl<'b, T, const N: usize> Handler<'b, T, N> for HttpHandler where T: Read + Write, - T::Error: Send + Sync + std::error::Error + 'static, { - type Error = anyhow::Error; + type Error = Error; async fn handle(&self, conn: &mut Connection<'b, T, N>) -> Result<(), Self::Error> { let headers = conn.headers()?; diff --git a/edge-http/src/io.rs b/edge-http/src/io.rs index a6d15a1..07cabb3 100644 --- a/edge-http/src/io.rs +++ b/edge-http/src/io.rs @@ -28,7 +28,6 @@ pub enum Error { IncompleteHeaders, IncompleteBody, InvalidState, - Timeout, ConnectionClosed, HeadersMismatchError(HeadersMismatchError), WsUpgradeError(UpgradeError), @@ -87,7 +86,6 @@ where Self::IncompleteHeaders => write!(f, "HTTP headers section is incomplete"), Self::IncompleteBody => write!(f, "HTTP body is incomplete"), Self::InvalidState => write!(f, "Connection is not in requested state"), - Self::Timeout => write!(f, "Timeout"), Self::HeadersMismatchError(e) => write!(f, "Headers mismatch: {e}"), Self::WsUpgradeError(e) => write!(f, "WebSocket upgrade error: {e}"), Self::ConnectionClosed => write!(f, "Connection closed"), diff --git a/edge-http/src/io/client.rs b/edge-http/src/io/client.rs index 6578540..3775e1d 100644 --- a/edge-http/src/io/client.rs +++ b/edge-http/src/io/client.rs @@ -4,7 +4,7 @@ use core::str; use embedded_io_async::{ErrorType, Read, Write}; -use edge_nal::TcpConnect; +use edge_nal::{Close, TcpConnect, TcpShutdown}; use crate::{ ws::{upgrade_request_headers, MAX_BASE64_KEY_LEN, MAX_BASE64_KEY_RESPONSE_LEN, NONCE_LEN}, @@ -39,6 +39,11 @@ where { /// Create a new client connection. /// + /// Note that the connection does not have any built-in read/write timeouts: + /// - To add a timeout on each IO operation, wrap the `socket` type with the `edge_nal::WithTimeout` wrapper. + /// - To add a global request-response timeout, wrap your complete request-response processing + /// logic with the `edge_nal::with_timeout` function. + /// /// Parameters: /// - `buf`: A buffer to use for reading and writing data. /// - `socket`: The TCP stack to use for the connection. @@ -234,11 +239,17 @@ where let mut state = self.unbind(); match result { - Ok(true) | Err(_) => state.io = None, - _ => (), - }; + Ok(true) | Err(_) => { + let mut io = state.io.take().unwrap(); + *self = Self::Unbound(state); - *self = Self::Unbound(state); + io.close(Close::Both).await.map_err(Error::Io)?; + let _ = io.abort().await; + } + _ => { + *self = Self::Unbound(state); + } + }; result?; diff --git a/edge-http/src/io/server.rs b/edge-http/src/io/server.rs index 9476669..1b92a7b 100644 --- a/edge-http/src/io/server.rs +++ b/edge-http/src/io/server.rs @@ -1,11 +1,9 @@ use core::fmt::{self, Debug}; use core::mem::{self, MaybeUninit}; -use core::pin::pin; -use embassy_futures::select::Either; +use edge_nal::{with_timeout, Close, TcpShutdown, WithTimeout, WithTimeoutError}; use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embassy_sync::mutex::Mutex; -use embassy_time::{Duration, Timer}; use embedded_io_async::{ErrorType, Read, Write}; @@ -22,7 +20,8 @@ pub use embedded_svc_compat::*; pub const DEFAULT_HANDLER_TASKS_COUNT: usize = 4; pub const DEFAULT_BUF_SIZE: usize = 2048; -pub const DEFAULT_TIMEOUT_MS: u32 = 5000; +pub const DEFAULT_REQUEST_TIMEOUT_MS: u32 = 30 * 60 * 1000; // 30 minutes +pub const DEFAULT_IO_TIMEOUT_MS: u32 = 50 * 1000; // 50 seconds const COMPLETION_BUF_SIZE: usize = 64; @@ -41,30 +40,21 @@ where { /// Create a new connection state machine for an incoming request /// + /// Note that the connection does not have any built-in read/write timeouts: + /// - To add a timeout on each IO operation, wrap the `io` type with the `edge_nal::WithTimeout` wrapper. + /// - To add a global request-response timeout, wrap your complete request-response processing + /// logic with the `edge_nal::with_timeout` function. + /// /// Parameters: /// - `buf`: A buffer to store the request headers /// - `io`: A socket stream - /// - `timeout_ms`: An optional timeout in milliseconds to wait for a new incoming request pub async fn new( buf: &'b mut [u8], mut io: T, - timeout_ms: Option, ) -> Result, Error> { let mut request = RequestHeaders::new(); - let (buf, read_len) = { - let timeout_ms = timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS); - - let receive = pin!(request.receive(buf, &mut io, true)); - let timer = Timer::after(Duration::from_millis(timeout_ms as _)); - - let result = embassy_futures::select::select(receive, timer).await; - - match result { - Either::First(result) => result, - Either::Second(_) => Err(Error::Timeout), - }? - }; + let (buf, read_len) = request.receive(buf, &mut io, true).await?; let (connection_type, body_type) = request.resolve::()?; @@ -433,16 +423,30 @@ where /// /// The socket stream will be closed only in case of error, or until the client explicitly requests that /// either with a hard socket close, or with a `Connection: Close` header. +/// +/// Parameters: +/// - `io`: A socket stream +/// - `buf`: A work-area buffer used by the implementation +/// - `request_timeout_ms`: An optional timeout for a complete request-response processing, in milliseconds. +/// If not provided, a default timeout of 30 minutes is used. +/// - `handler`: An implementation of `Handler` to handle incoming requests pub async fn handle_connection( io: T, buf: &mut [u8], - timeout_ms: Option, + request_timeout_ms: Option, handler: H, ) where H: for<'b> Handler<'b, &'b mut T, N>, - T: Read + Write, + T: Read + Write + TcpShutdown, { - handle_task_connection(io, buf, timeout_ms, 0, TaskHandlerAdaptor::new(handler)).await + handle_task_connection( + io, + buf, + request_timeout_ms, + 0, + TaskHandlerAdaptor::new(handler), + ) + .await } /// A convenience function to handle multiple HTTP requests over a single socket stream, @@ -450,44 +454,63 @@ pub async fn handle_connection( /// /// The socket stream will be closed only in case of error, or until the client explicitly requests that /// either with a hard socket close, or with a `Connection: Close` header. +/// +/// Parameters: +/// - `io`: A socket stream +/// - `buf`: A work-area buffer used by the implementation +/// - `request_timeout_ms`: An optional timeout for a complete request-response processing, in milliseconds. +/// If not provided, a default timeout of 30 minutes is used. +/// - `task_id`: An identifier for the task, used for logging purposes +/// - `handler`: An implementation of `TaskHandler` to handle incoming requests pub async fn handle_task_connection( mut io: T, buf: &mut [u8], - timeout_ms: Option, + request_timeout_ms: Option, task_id: usize, handler: H, ) where H: for<'b> TaskHandler<'b, &'b mut T, N>, - T: Read + Write, + T: Read + Write + TcpShutdown, { - loop { + let close = loop { debug!("Handler task {task_id}: Waiting for new request"); - let result = - handle_task_request::(buf, &mut io, task_id, timeout_ms, &handler).await; + let result = with_timeout( + request_timeout_ms.unwrap_or(DEFAULT_REQUEST_TIMEOUT_MS), + handle_task_request::(buf, &mut io, task_id, &handler), + ) + .await; match result { - Err(HandleRequestError::Connection(Error::Timeout)) => { - info!("Handler task {task_id}: Connection closed due to timeout"); - break; + Err(WithTimeoutError::Timeout) => { + info!("Handler task {task_id}: Connection closed due to request timeout"); + break false; } - Err(HandleRequestError::Connection(Error::ConnectionClosed)) => { + Err(WithTimeoutError::IO(HandleRequestError::Connection(Error::ConnectionClosed))) => { debug!("Handler task {task_id}: Connection closed"); - break; + break false; } Err(e) => { warn!("Handler task {task_id}: Error when handling request: {e:?}"); - break; + break true; } Ok(needs_close) => { if needs_close { debug!("Handler task {task_id}: Request complete; closing connection"); - break; + break true; } else { debug!("Handler task {task_id}: Request complete"); } } } + }; + + if close { + if let Err(e) = io.close(Close::Both).await { + warn!("Handler task {task_id}: Error when closing the socket: {e:?}"); + } + } else { + let _ = io.abort().await; } } @@ -519,6 +542,19 @@ where } } +impl embedded_io_async::Error for HandleRequestError +where + C: Debug + embedded_io_async::Error, + E: Debug, +{ + fn kind(&self) -> embedded_io_async::ErrorKind { + match self { + Self::Connection(Error::Io(e)) => e.kind(), + _ => embedded_io_async::ErrorKind::Other, + } + } +} + #[cfg(feature = "std")] impl std::error::Error for HandleRequestError where @@ -529,33 +565,52 @@ where /// A convenience function to handle a single HTTP request over a socket stream, /// using the specified handler. +/// +/// Note that this function does not set any timeouts on the request-response processing +/// or on the IO operations. It is up that the caller to use the `with_timeout` function +/// and the `WithTimeout` struct from the `edge-nal` crate to wrap the future returned +/// by this function, or the socket stream, or both. +/// +/// Parameters: +/// - `buf`: A work-area buffer used by the implementation +/// - `io`: A socket stream +/// - `handler`: An implementation of `Handler` to handle incoming requests pub async fn handle_request<'b, const N: usize, H, T>( buf: &'b mut [u8], io: T, - timeout_ms: Option, handler: H, ) -> Result> where H: Handler<'b, T, N>, T: Read + Write, { - handle_task_request(buf, io, 0, timeout_ms, TaskHandlerAdaptor::new(handler)).await + handle_task_request(buf, io, 0, TaskHandlerAdaptor::new(handler)).await } /// A convenience function to handle a single HTTP request over a socket stream, /// using the specified task handler. +/// +/// Note that this function does not set any timeouts on the request-response processing +/// or on the IO operations. It is up that the caller to use the `with_timeout` function +/// and the `WithTimeout` struct from the `edge-nal` crate to wrap the future returned +/// by this function, or the socket stream, or both. +/// +/// Parameters: +/// - `buf`: A work-area buffer used by the implementation +/// - `io`: A socket stream +/// - `task_id`: An identifier for the task, used for logging purposes +/// - `handler`: An implementation of `TaskHandler` to handle incoming requests pub async fn handle_task_request<'b, const N: usize, H, T>( buf: &'b mut [u8], io: T, task_id: usize, - timeout_ms: Option, handler: H, ) -> Result> where H: TaskHandler<'b, T, N>, T: Read + Write, { - let mut connection = Connection::<_, N>::new(buf, io, timeout_ms).await?; + let mut connection = Connection::<_, N>::new(buf, io).await?; let result = handler.handle(task_id, &mut connection).await; @@ -595,17 +650,26 @@ impl Server { } /// Run the server with the specified acceptor and handler + /// + /// Parameters: + /// - `acceptor`: An implementation of `edge_nal::TcpAccept` to accept incoming connections + /// - `handler`: An implementation of `Handler` to handle incoming requests + /// - `request_timeout_ms`: An optional timeout for a complete request-response processing, in milliseconds. + /// If not provided, a default timeout of 30 minutes is used. + /// - `io_timeout_ms`: An optional timeout for each IO operation, in milliseconds. + /// If not provided, a default timeout of 50 seconds is used. #[inline(never)] #[cold] pub async fn run( &mut self, acceptor: A, handler: H, - timeout_ms: Option, + request_timeout_ms: Option, + io_timeout_ms: Option, ) -> Result<(), Error> where A: edge_nal::TcpAccept, - H: for<'b, 't> Handler<'b, &'b mut A::Socket<'t>, N>, + H: for<'b, 't> Handler<'b, &'b mut WithTimeout>, N>, { let handler = TaskHandlerAdaptor::new(handler); @@ -637,12 +701,15 @@ impl Server { acceptor.accept().await.map_err(Error::Io)?.1 }; + let io = + WithTimeout::new(io_timeout_ms.unwrap_or(DEFAULT_IO_TIMEOUT_MS), io); + debug!("Handler task {task_id}: Got connection request"); handle_task_connection::( io, unsafe { buf.as_mut() }.unwrap(), - timeout_ms, + request_timeout_ms, task_id, handler, ) @@ -661,17 +728,26 @@ impl Server { } /// Run the server with the specified acceptor and task handler + /// + /// Parameters: + /// - `acceptor`: An implementation of `edge_nal::TcpAccept` to accept incoming connections + /// - `handler`: An implementation of `TaskHandler` to handle incoming requests + /// - `request_timeout_ms`: An optional timeout for a complete request-response processing, in milliseconds. + /// If not provided, a default timeout of 30 minutes is used. + /// - `io_timeout_ms`: An optional timeout for each IO operation, in milliseconds. + /// If not provided, a default timeout of 50 seconds is used. #[inline(never)] #[cold] pub async fn run_with_task_id( &mut self, acceptor: A, handler: H, - timeout_ms: Option, + request_timeout_ms: Option, + io_timeout_ms: Option, ) -> Result<(), Error> where A: edge_nal::TcpAccept, - H: for<'b, 't> TaskHandler<'b, &'b mut A::Socket<'t>, N>, + H: for<'b, 't> TaskHandler<'b, &'b mut WithTimeout>, N>, { let mutex = Mutex::::new(()); let mut tasks = heapless::Vec::<_, P>::new(); @@ -699,12 +775,15 @@ impl Server { acceptor.accept().await.map_err(Error::Io)?.1 }; + let io = + WithTimeout::new(io_timeout_ms.unwrap_or(DEFAULT_IO_TIMEOUT_MS), io); + debug!("Handler task {task_id}: Got connection request"); handle_task_connection::( io, unsafe { buf.as_mut() }.unwrap(), - timeout_ms, + request_timeout_ms, task_id, handler, ) diff --git a/edge-nal-embassy/Cargo.toml b/edge-nal-embassy/Cargo.toml index 80c7f70..0a33735 100644 --- a/edge-nal-embassy/Cargo.toml +++ b/edge-nal-embassy/Cargo.toml @@ -18,5 +18,6 @@ categories = [ embedded-io-async = { workspace = true } edge-nal = { workspace = true } heapless = { workspace = true } -# Do not require these features and conditionalize the code instead +# TODO: Do not require these features and conditionalize the code instead embassy-net = { version = "0.4", features = ["tcp", "udp", "dns", "proto-ipv6", "medium-ethernet", "proto-ipv4", "igmp"] } +embassy-futures = { workspace = true } diff --git a/edge-nal-embassy/README.md b/edge-nal-embassy/README.md index add9714..3df5501 100644 --- a/edge-nal-embassy/README.md +++ b/edge-nal-embassy/README.md @@ -10,11 +10,13 @@ A bare-metal implementation of `edge-nal` based on the [embassy-net](https://cra ### TCP -All traits. +All traits except `Readable` which - while implemented - panics if called. ### UDP -* All traits except `UdpConnect` and `Multicast`. +* All traits except `UdpConnect`. +* `MulticastV6` - while implemented - panics if `join_v6` / `leave_v6` are called. +* `Readable` - while implemented - panics if called. ### Raw sockets diff --git a/edge-nal-embassy/src/tcp.rs b/edge-nal-embassy/src/tcp.rs index 4a0c486..8e86906 100644 --- a/edge-nal-embassy/src/tcp.rs +++ b/edge-nal-embassy/src/tcp.rs @@ -1,7 +1,10 @@ use core::net::SocketAddr; +use core::pin::pin; use core::ptr::NonNull; -use edge_nal::{Readable, TcpBind, TcpConnect, TcpSplit}; +use edge_nal::{Close, Readable, TcpBind, TcpConnect, TcpShutdown, TcpSplit}; + +use embassy_futures::join::join; use embassy_net::driver::Driver; use embassy_net::tcp::{AcceptError, ConnectError, Error, TcpReader, TcpWriter}; @@ -126,6 +129,46 @@ impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpSocket<'d, N socket_buffers, }) } + + async fn close(&mut self, what: Close) -> Result<(), TcpError> { + async fn discard_all_data(rx: &mut TcpReader<'_>) -> Result<(), TcpError> { + let mut buf = [0; 32]; + + while rx.read(&mut buf).await? > 0 {} + + Ok(()) + } + + if matches!(what, Close::Both | Close::Write) { + self.socket.close(); + } + + let (mut rx, mut tx) = self.socket.split(); + + match what { + Close::Read => discard_all_data(&mut rx).await?, + Close::Write => tx.flush().await?, + Close::Both => { + let mut flush = pin!(tx.flush()); + let mut read = pin!(discard_all_data(&mut rx)); + + match join(&mut flush, &mut read).await { + (Err(e), _) => Err(e)?, + (_, Err(e)) => Err(e)?, + _ => (), + } + } + } + + Ok(()) + } + + async fn abort(&mut self) -> Result<(), TcpError> { + self.socket.abort(); + self.socket.flush().await?; + + Ok(()) + } } impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> Drop @@ -175,6 +218,18 @@ impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> Readable } } +impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpShutdown + for TcpSocket<'d, N, TX_SZ, RX_SZ> +{ + async fn close(&mut self, what: Close) -> Result<(), Self::Error> { + TcpSocket::close(self, what).await + } + + async fn abort(&mut self) -> Result<(), Self::Error> { + TcpSocket::abort(self).await + } +} + /// Represents the read half of a split TCP socket /// Implements the `Read` trait from `embedded-io-async` pub struct TcpSocketRead<'a>(TcpReader<'a>); diff --git a/edge-nal-std/src/lib.rs b/edge-nal-std/src/lib.rs index 64db480..72b89b7 100644 --- a/edge-nal-std/src/lib.rs +++ b/edge-nal-std/src/lib.rs @@ -6,7 +6,7 @@ use core::ops::Deref; use core::pin::pin; use std::io; -use std::net::{self, TcpStream, ToSocketAddrs, UdpSocket as StdUdpSocket}; +use std::net::{self, Shutdown, TcpStream, ToSocketAddrs, UdpSocket as StdUdpSocket}; #[cfg(not(feature = "async-io-mini"))] use async_io::Async; @@ -18,8 +18,8 @@ use futures_lite::io::{AsyncReadExt, AsyncWriteExt}; use embedded_io_async::{ErrorType, Read, Write}; use edge_nal::{ - AddrType, Dns, MulticastV4, MulticastV6, Readable, TcpAccept, TcpBind, TcpConnect, TcpSplit, - UdpBind, UdpConnect, UdpReceive, UdpSend, UdpSplit, + AddrType, Dns, MulticastV4, MulticastV6, Readable, TcpAccept, TcpBind, TcpConnect, TcpShutdown, + TcpSplit, UdpBind, UdpConnect, UdpReceive, UdpSend, UdpSplit, }; #[cfg(all(unix, not(target_os = "espidf")))] @@ -101,7 +101,7 @@ impl TcpAccept for TcpAcceptor { match self.0.as_ref().accept() { Ok((socket, _)) => break Ok((socket.peer_addr()?, TcpSocket(Async::new(socket)?))), Err(err) if err.kind() == io::ErrorKind::WouldBlock => { - async_io::Timer::after(core::time::Duration::from_millis(5)).await; + async_io::Timer::after(core::time::Duration::from_millis(20)).await; } Err(err) => break Err(err), } @@ -199,6 +199,24 @@ impl TcpSplit for TcpSocket { } } +impl TcpShutdown for TcpSocket { + async fn close(&mut self, what: edge_nal::Close) -> Result<(), Self::Error> { + match what { + edge_nal::Close::Read => self.0.as_ref().shutdown(Shutdown::Read)?, + edge_nal::Close::Write => self.0.as_ref().shutdown(Shutdown::Write)?, + edge_nal::Close::Both => self.0.as_ref().shutdown(Shutdown::Both)?, + } + + Ok(()) + } + + async fn abort(&mut self) -> Result<(), Self::Error> { + // No-op, STD will abort the socket on drop anyway + + Ok(()) + } +} + impl UdpConnect for Stack { type Error = io::Error; diff --git a/edge-nal/Cargo.toml b/edge-nal/Cargo.toml index f9a864e..f46c9e0 100644 --- a/edge-nal/Cargo.toml +++ b/edge-nal/Cargo.toml @@ -16,3 +16,4 @@ categories = [ [dependencies] embedded-io-async = { workspace = true } +embassy-time = { workspace = true } diff --git a/edge-nal/README.md b/edge-nal/README.md index 0937605..8e962a1 100644 --- a/edge-nal/README.md +++ b/edge-nal/README.md @@ -12,6 +12,7 @@ Hosts a bunch of networking (UDP, TCP and raw ethernet) traits. * Factory traits for the creation of TCP server sockets - `TcpBind` and `TcpAccept`. `embedded-nal-async` only has `TcpConnect` * Splittable sockets with `TcpSplit` (can be optionally implemented by `TcpConnect` and `TcpAccept`) +* Socket shutdown with `TcpShutdown` ### UDP @@ -20,7 +21,7 @@ Hosts a bunch of networking (UDP, TCP and raw ethernet) traits. * Returning the local address of a UDP socket bind / connect operation is not supported, as not all platforms currently have this capability (i.e. the networking stack of Embassy) * "Unbound" UDP sockets are currently not supported, as not all platforms have these capabilities (i.e. the networking stack of Embassy). Also, I've yet to find a good use case for these. * Splittable sockets with `UdpSplit` -* `Multicast` trait for joining / leaving IPv4 and IPv6 multicast groups (can be optionally implemented by `UdpConnect` and `UdpBind`) +* `MulticastV4` and `MulticastV6` traits for joining / leaving IPv4 and IPv6 multicast groups (can be optionally implemented by `UdpConnect` and `UdpBind`) * `Readable` trait for waiting until a socket becomes readable ## Justification @@ -58,8 +59,8 @@ Namely: * Udp socket factory similar in spirit to STD's `std::net::UdpSocket::bind` method * [UdpConnect](src/stack/udp.rs) * Udp socket factory similar in spirit to STD's `std::net::UdpSocket::connect` method -* [Multicast](src/multicast.rs) - * Extra trait for UDP sockets allowing subscription to multicast groups +* [Multicastv4 and MulticastV6](src/multicast.rs) + * Extra traits for UDP sockets allowing subscription to multicast groups * [Readable](src/readable.rs) * Extra trait for UDP, TCP and raw sockets allowing one to wait until the socket becomes readable diff --git a/edge-nal/src/lib.rs b/edge-nal/src/lib.rs index ebdc537..2bb9279 100644 --- a/edge-nal/src/lib.rs +++ b/edge-nal/src/lib.rs @@ -4,6 +4,8 @@ pub use multicast::*; pub use raw::*; pub use readable::*; +pub use tcp::*; +pub use timeout::*; pub use udp::*; pub use stack::*; @@ -12,4 +14,6 @@ mod multicast; mod raw; mod readable; mod stack; +mod tcp; +mod timeout; mod udp; diff --git a/edge-nal/src/stack/tcp.rs b/edge-nal/src/stack/tcp.rs index 3ecf0f8..1294fba 100644 --- a/edge-nal/src/stack/tcp.rs +++ b/edge-nal/src/stack/tcp.rs @@ -4,7 +4,7 @@ use core::net::SocketAddr; use embedded_io_async::{Error, ErrorType, Read, Write}; -use crate::Readable; +use crate::{Readable, TcpShutdown}; /// This trait is implemented by TCP sockets that can be split into separate `send` and `receive` halves that can operate /// independently from each other (i.e., a full-duplex connection). @@ -46,6 +46,7 @@ pub trait TcpConnect { type Socket<'a>: Read + Write + Readable + + TcpShutdown where Self: 'a; @@ -80,6 +81,7 @@ pub trait TcpAccept { type Socket<'a>: Read + Write + Readable + + TcpShutdown where Self: 'a; diff --git a/edge-nal/src/tcp.rs b/edge-nal/src/tcp.rs new file mode 100644 index 0000000..294a997 --- /dev/null +++ b/edge-nal/src/tcp.rs @@ -0,0 +1,61 @@ +//! Trait for modeling TCP socket shutdown + +use embedded_io_async::ErrorType; + +/// Enum representing the different ways to close a TCP socket +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub enum Close { + /// Close the read half of the socket + Read, + /// Close the write half of the socket + Write, + /// Close both the read and write halves of the socket + Both, +} + +/// This trait is implemented by TCP sockets and models their shutdown functionality, +/// which is unique to the TCP protocol (UDP sockets do not have a shutdown procedure). +pub trait TcpShutdown: ErrorType { + /// Gracefully shutdown either or both the read and write halves of the socket. + /// + /// The write half is closed by sending a FIN packet to the peer and then waiting + /// until the FIN packet is ACKed. + /// + /// The read half is "closed" by reading from it until the peer indicates there is + /// no more data to read (i.e. it sends a FIN packet to the local socket). + /// Whether the other peer will send a FIN packet or not is not guaranteed, as that's + /// application protocol-specific. Usually, closing the write half means the peer will + /// notice and will send a FIN packet to the read half, thus "closing" it too. + /// + /// Note that - on certain platforms that don't have built-in timeouts - this method might never + /// complete if the peer is unreachable / misbehaving, so it has to be used with a + /// proper timeout in-place. + /// + /// Also note that calling this function multiple times may result in different behavior, + /// depending on the platform. + async fn close(&mut self, what: Close) -> Result<(), Self::Error>; + + /// Abort the connection, sending an RST packet to the peer + /// + /// This method will not wait forever, because the RST packet is not ACKed by the peer. + /// + /// Note that on certain platforms (STD for example) this method might be a no-op + /// as the connection there is automatically aborted when the socket is dropped. + /// + /// Also note that calling this function multiple times may result in different behavior, + /// depending on the platform. + async fn abort(&mut self) -> Result<(), Self::Error>; +} + +impl TcpShutdown for &mut T +where + T: TcpShutdown, +{ + async fn close(&mut self, what: Close) -> Result<(), Self::Error> { + (**self).close(what).await + } + + async fn abort(&mut self) -> Result<(), Self::Error> { + (**self).abort().await + } +} diff --git a/edge-nal/src/timeout.rs b/edge-nal/src/timeout.rs new file mode 100644 index 0000000..5f62494 --- /dev/null +++ b/edge-nal/src/timeout.rs @@ -0,0 +1,195 @@ +//! This module provides utility function and a decorator struct +//! for adding timeouts to IO types. +//! +//! Note that the presence of this module in the `edge-nal` crate +//! is a bit controversial, as it is a utility, while `edge-nal` is a +//! pure traits' crate otherwise. +//! +//! Therefore, the module might be moved to another location in future. + +use core::{ + fmt::{self, Display}, + future::Future, + net::SocketAddr, +}; + +use embassy_time::Duration; +use embedded_io_async::{ErrorKind, ErrorType, Read, Write}; + +use crate::{Readable, TcpConnect, TcpShutdown}; + +/// IO Error type for the `with_timeout` function and `WithTimeout` struct. +#[derive(Debug)] +pub enum WithTimeoutError { + /// An IO error occurred + IO(E), + /// The operation timed out + Timeout, +} + +impl From for WithTimeoutError { + fn from(e: E) -> Self { + Self::IO(e) + } +} + +impl fmt::Display for WithTimeoutError +where + E: Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::IO(e) => write!(f, "IO error: {}", e), + Self::Timeout => write!(f, "Operation timed out"), + } + } +} + +impl embedded_io_async::Error for WithTimeoutError +where + E: embedded_io_async::Error, +{ + fn kind(&self) -> ErrorKind { + match self { + Self::IO(e) => e.kind(), + Self::Timeout => ErrorKind::TimedOut, + } + } +} + +/// Run an IO future with a timeout. +/// +/// A future is an IO future if it resolves to a `Result`, where `E` +/// implements `embedded_io_async::Error`. +/// +/// If the future completes before the timeout, its output is returned. +/// Otherwise, on timeout, a timeout error is returned. +/// +/// Parameters: +/// - `timeout_ms`: The timeout duration in milliseconds +/// - `fut`: The future to run +pub async fn with_timeout(timeout_ms: u32, fut: F) -> Result> +where + F: Future>, + E: embedded_io_async::Error, +{ + map_result(embassy_time::with_timeout(Duration::from_millis(timeout_ms as _), fut).await) +} + +/// A type that wraps an IO stream type and adds a timeout to all operations. +/// +/// The operations decorated with a timeout are the ones offered via the following traits: +/// - `embedded_io_async::Read` +/// - `embedded_io_async::Write` +/// - `Readable` +/// - `TcpConnect` +/// - `TcpShutdown` +pub struct WithTimeout(T, u32); + +impl WithTimeout { + /// Create a new `WithTimeout` instance. + /// + /// Parameters: + /// - `timeout_ms`: The timeout duration in milliseconds + /// - `io`: The IO type to add a timeout to + pub const fn new(timeout_ms: u32, io: T) -> Self { + Self(io, timeout_ms) + } + + /// Get a reference to the inner IO type. + pub fn io(&self) -> &T { + &self.0 + } + + /// Get a mutable reference to the inner IO type. + pub fn io_mut(&mut self) -> &mut T { + &mut self.0 + } + + /// Get the IO type by destructuring the `WithTimeout` instance. + pub fn into_io(self) -> T { + self.0 + } +} + +impl ErrorType for WithTimeout +where + T: ErrorType, +{ + type Error = WithTimeoutError; +} + +impl Read for WithTimeout +where + T: Read, +{ + async fn read(&mut self, buf: &mut [u8]) -> Result { + with_timeout(self.1, self.0.read(buf)).await + } +} + +impl Write for WithTimeout +where + T: Write, +{ + async fn write(&mut self, buf: &[u8]) -> Result { + with_timeout(self.1, self.0.write(buf)).await + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + with_timeout(self.1, self.0.flush()).await + } +} + +impl TcpConnect for WithTimeout +where + T: TcpConnect, +{ + type Error = WithTimeoutError; + + type Socket<'a> + = WithTimeout> + where + Self: 'a; + + async fn connect(&self, remote: SocketAddr) -> Result, Self::Error> { + with_timeout(self.1, self.0.connect(remote)) + .await + .map(|s| WithTimeout::new(self.1, s)) + } +} + +impl Readable for WithTimeout +where + T: Readable, +{ + async fn readable(&mut self) -> Result<(), Self::Error> { + with_timeout(self.1, self.0.readable()).await + } +} + +impl TcpShutdown for WithTimeout +where + T: TcpShutdown, +{ + async fn close(&mut self, what: crate::Close) -> Result<(), Self::Error> { + with_timeout(self.1, self.0.close(what)).await + } + + async fn abort(&mut self) -> Result<(), Self::Error> { + with_timeout(self.1, self.0.abort()).await + } +} + +fn map_result( + result: Result, embassy_time::TimeoutError>, +) -> Result> +where + E: embedded_io_async::Error, +{ + match result { + Ok(Ok(t)) => Ok(t), + Ok(Err(e)) => Err(WithTimeoutError::IO(e)), + Err(_) => Err(WithTimeoutError::Timeout), + } +} diff --git a/edge-ws/README.md b/edge-ws/README.md index 681ed64..3423819 100644 --- a/edge-ws/README.md +++ b/edge-ws/README.md @@ -149,6 +149,7 @@ where ```rust use edge_http::io::server::{Connection, DefaultServer, Handler}; +use edge_http::io::Error; use edge_http::ws::MAX_BASE64_KEY_RESPONSE_LEN; use edge_http::Method; use edge_nal::TcpBind; @@ -177,19 +178,32 @@ pub async fn run(server: &mut DefaultServer) -> Result<(), anyhow::Error> { .bind(addr.parse().unwrap()) .await?; - server.run(acceptor, WsHandler, None).await?; + server + .run(acceptor, WsHandler, None, Some(30 * 60 * 1000)) + .await?; Ok(()) } +#[derive(Debug)] +enum WsHandlerError { + ConnectionError(C), + WsError(W), +} + +impl From for WsHandlerError { + fn from(e: C) -> Self { + Self::ConnectionError(e) + } +} + struct WsHandler; impl<'b, T, const N: usize> Handler<'b, T, N> for WsHandler where T: Read + Write, - T::Error: Send + Sync + std::error::Error + 'static, { - type Error = anyhow::Error; + type Error = WsHandlerError, edge_ws::Error>; async fn handle(&self, conn: &mut Connection<'b, T, N>) -> Result<(), Self::Error> { let headers = conn.headers()?; @@ -221,8 +235,13 @@ where let mut buf = [0_u8; 8192]; loop { - let mut header = FrameHeader::recv(&mut socket).await?; - let payload = header.recv_payload(&mut socket, &mut buf).await?; + let mut header = FrameHeader::recv(&mut socket) + .await + .map_err(WsHandlerError::WsError)?; + let payload = header + .recv_payload(&mut socket, &mut buf) + .await + .map_err(WsHandlerError::WsError)?; match header.frame_type { FrameType::Text(_) => { @@ -253,8 +272,14 @@ where info!("Echoing back as {header}"); - header.send(&mut socket).await?; - header.send_payload(&mut socket, payload).await?; + header + .send(&mut socket) + .await + .map_err(WsHandlerError::WsError)?; + header + .send_payload(&mut socket, payload) + .await + .map_err(WsHandlerError::WsError)?; } } diff --git a/examples/http_server.rs b/examples/http_server.rs index 2d0dd6c..5b01887 100644 --- a/examples/http_server.rs +++ b/examples/http_server.rs @@ -1,4 +1,5 @@ use edge_http::io::server::{Connection, DefaultServer, Handler}; +use edge_http::io::Error; use edge_http::Method; use edge_nal::TcpBind; @@ -25,7 +26,7 @@ pub async fn run(server: &mut DefaultServer) -> Result<(), anyhow::Error> { .bind(addr.parse().unwrap()) .await?; - server.run(acceptor, HttpHandler, None).await?; + server.run(acceptor, HttpHandler, None, None).await?; Ok(()) } @@ -35,9 +36,8 @@ struct HttpHandler; impl<'b, T, const N: usize> Handler<'b, T, N> for HttpHandler where T: Read + Write, - T::Error: Send + Sync + std::error::Error + 'static, { - type Error = anyhow::Error; + type Error = Error; async fn handle(&self, conn: &mut Connection<'b, T, N>) -> Result<(), Self::Error> { let headers = conn.headers()?; diff --git a/examples/ws_server.rs b/examples/ws_server.rs index 811ec8e..cdbd48a 100644 --- a/examples/ws_server.rs +++ b/examples/ws_server.rs @@ -1,4 +1,5 @@ use edge_http::io::server::{Connection, DefaultServer, Handler}; +use edge_http::io::Error; use edge_http::ws::MAX_BASE64_KEY_RESPONSE_LEN; use edge_http::Method; use edge_nal::TcpBind; @@ -27,19 +28,32 @@ pub async fn run(server: &mut DefaultServer) -> Result<(), anyhow::Error> { .bind(addr.parse().unwrap()) .await?; - server.run(acceptor, WsHandler, None).await?; + server + .run(acceptor, WsHandler, None, Some(30 * 60 * 1000)) + .await?; Ok(()) } +#[derive(Debug)] +enum WsHandlerError { + ConnectionError(C), + WsError(W), +} + +impl From for WsHandlerError { + fn from(e: C) -> Self { + Self::ConnectionError(e) + } +} + struct WsHandler; impl<'b, T, const N: usize> Handler<'b, T, N> for WsHandler where T: Read + Write, - T::Error: Send + Sync + std::error::Error + 'static, { - type Error = anyhow::Error; + type Error = WsHandlerError, edge_ws::Error>; async fn handle(&self, conn: &mut Connection<'b, T, N>) -> Result<(), Self::Error> { let headers = conn.headers()?; @@ -71,8 +85,13 @@ where let mut buf = [0_u8; 8192]; loop { - let mut header = FrameHeader::recv(&mut socket).await?; - let payload = header.recv_payload(&mut socket, &mut buf).await?; + let mut header = FrameHeader::recv(&mut socket) + .await + .map_err(WsHandlerError::WsError)?; + let payload = header + .recv_payload(&mut socket, &mut buf) + .await + .map_err(WsHandlerError::WsError)?; match header.frame_type { FrameType::Text(_) => { @@ -103,8 +122,14 @@ where info!("Echoing back as {header}"); - header.send(&mut socket).await?; - header.send_payload(&mut socket, payload).await?; + header + .send(&mut socket) + .await + .map_err(WsHandlerError::WsError)?; + header + .send_payload(&mut socket, payload) + .await + .map_err(WsHandlerError::WsError)?; } }