diff --git a/tokio/src/macros/debug_check_std_blocking.rs b/tokio/src/macros/debug_check_std_blocking.rs new file mode 100644 index 00000000000..50ed9e135dc --- /dev/null +++ b/tokio/src/macros/debug_check_std_blocking.rs @@ -0,0 +1,88 @@ +//! Debug assertions that from_std_assume_nonblocking is not used on blocking sockets. +//! +//! These are displayed warnings in debug mode, panics in test mode +//! (so that nothing slips through in the tokio test suite), and no-ops in release mode. + +#[cfg(all(debug_assertions, not(test)))] +/// Debug assertions that from_std_assume_nonblocking is not used on blocking sockets. +/// +/// These are displayed warnings in debug mode, panics in test mode +/// (so that nothing slips through in the tokio test suite), and no-ops in release mode. +macro_rules! debug_check_non_blocking { + ($std_socket: expr, $method: expr, $fallback_method: expr) => {{ + // Make sure the provided item is in non-blocking mode, otherwise warn. + static HAS_WARNED_BLOCKING: std::sync::atomic::AtomicBool = + std::sync::atomic::AtomicBool::new(false); + match socket2::SockRef::from(&$std_socket).nonblocking() { + Ok(true) => {} + Ok(false) => { + if !HAS_WARNED_BLOCKING.swap(true, std::sync::atomic::Ordering::Relaxed) { + println!(concat!( + "WARNING: `", + $method, + "` was called on a socket that is \ + not in non-blocking mode. This is unexpected, and may cause the \ + thread to block indefinitely. Use `", + $fallback_method, + "` instead." + )); + } + } + Err(io_error) => { + if !HAS_WARNED_BLOCKING.swap(true, std::sync::atomic::Ordering::Relaxed) { + println!( + concat!( + "WARNING: `", + $method, + "` was called on a socket which we \ + could not determine whether was in non-blocking mode: {}" + ), + io_error + ); + } + } + } + }}; +} + +#[cfg(test)] +/// Debug assertions that from_std_assume_nonblocking is not used on blocking sockets. +/// +/// These are displayed warnings in debug mode, panics in test mode +/// (so that nothing slips through in the tokio test suite), and no-ops in release mode. +macro_rules! debug_check_non_blocking { + ($std_socket: expr, $method: expr, $fallback_method: expr) => {{ + match socket2::SockRef::from(&$std_socket).nonblocking() { + Ok(true) => {} + Ok(false) => { + panic!(concat!( + $method, + "` was called on a socket that is \ + not in non-blocking mode. This is unexpected, and may cause the \ + thread to block indefinitely. Use `", + $fallback_method, + "` instead." + )) + } + Err(io_error) => { + panic!( + concat!( + $method, + "` was called on a socket which we \ + could not determine whether was in non-blocking mode: {}" + ), + io_error + ); + } + } + }}; +} + +#[cfg(not(debug_assertions))] +/// Debug assertions that from_std_assume_nonblocking is not used on blocking sockets. +/// +/// These are displayed warnings in debug mode, panics in test mode +/// (so that nothing slips through in the tokio test suite), and no-ops in release mode. +macro_rules! debug_check_non_blocking { + ($($tts:tt)+) => {}; +} diff --git a/tokio/src/macros/mod.rs b/tokio/src/macros/mod.rs index 82f42dbff35..979ec1f736e 100644 --- a/tokio/src/macros/mod.rs +++ b/tokio/src/macros/mod.rs @@ -37,3 +37,6 @@ cfg_macros! { // Includes re-exports needed to implement macros #[doc(hidden)] pub mod support; + +#[macro_use] +mod debug_check_std_blocking; diff --git a/tokio/src/net/tcp/listener.rs b/tokio/src/net/tcp/listener.rs index 28da34afb29..ac5e9cc4f77 100644 --- a/tokio/src/net/tcp/listener.rs +++ b/tokio/src/net/tcp/listener.rs @@ -203,13 +203,54 @@ impl TcpListener { /// /// # Notes /// - /// The caller is responsible for ensuring that the listener is in - /// non-blocking mode. Otherwise all I/O operations on the listener + /// This sets the socket to non-blocking mode if it isn't already non-blocking. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::error::Error; + /// use tokio::net::TcpListener; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// let std_listener = std::net::TcpListener::bind("127.0.0.1:0")?; + /// let listener = TcpListener::from_std_set_nonblocking(std_listener)?; + /// Ok(()) + /// } + /// ``` + /// + /// # Panics + /// + /// This function panics if it is not called from within a runtime with + /// IO enabled. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + #[track_caller] + pub fn from_std_set_nonblocking(listener: net::TcpListener) -> io::Result { + listener.set_nonblocking(true)?; + Self::from_std_assume_nonblocking(listener) + } + + /// Creates new `TcpListener` from a `std::net::TcpListener` without + /// checking that it's non-blocking. + /// + /// This function is intended to be used to wrap a TCP listener from the + /// standard library in the Tokio equivalent. However, it does not check + /// that it's already in non-blocking mode. + /// + /// Instead, the caller is responsible for ensuring that the listener is in + /// non-blocking mode, otherwise all I/O operations on the listener /// will block the thread, which will cause unexpected behavior. /// Non-blocking mode can be set using [`set_nonblocking`]. /// /// [`set_nonblocking`]: std::net::TcpListener::set_nonblocking /// + /// It may be preferrable to use + /// [`from_std_set_nonblocking`](TcpListener::from_std_set_nonblocking), + /// which sets `nonblocking`. + /// /// # Examples /// /// ```rust,no_run @@ -220,7 +261,7 @@ impl TcpListener { /// async fn main() -> Result<(), Box> { /// let std_listener = std::net::TcpListener::bind("127.0.0.1:0")?; /// std_listener.set_nonblocking(true)?; - /// let listener = TcpListener::from_std(std_listener)?; + /// let listener = TcpListener::from_std_assume_nonblocking(std_listener)?; /// Ok(()) /// } /// ``` @@ -234,12 +275,34 @@ impl TcpListener { /// from a future driven by a tokio runtime, otherwise runtime can be set /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. #[track_caller] - pub fn from_std(listener: net::TcpListener) -> io::Result { + pub fn from_std_assume_nonblocking(listener: net::TcpListener) -> io::Result { + debug_check_non_blocking!( + listener, + "TcpListener::from_std", + "TcpListener::from_std_set_nonblocking" + ); let io = mio::net::TcpListener::from_std(listener); let io = PollEvented::new(io)?; Ok(TcpListener { io }) } + #[doc(hidden)] + /// Creates new `TcpListener` from a `std::net::TcpListener`. + /// + /// This function is doc-hidden because it's easy to misuse, + /// and naming doesn't warn enough about it. + /// + /// It may be preferrable to use + /// [`from_std_set_nonblocking`](TcpListener::from_std_set_nonblocking), + /// which sets `nonblocking`. + /// + /// This function however has the same behavior as + /// [`TcpListener::from_std_assume_nonblocking`]. + #[track_caller] + pub fn from_std(listener: net::TcpListener) -> io::Result { + Self::from_std_assume_nonblocking(listener) + } + /// Turns a [`tokio::net::TcpListener`] into a [`std::net::TcpListener`]. /// /// The returned [`std::net::TcpListener`] will have nonblocking mode set as @@ -384,9 +447,22 @@ impl TryFrom for TcpListener { /// Consumes stream, returning the tokio I/O object. /// /// This is equivalent to - /// [`TcpListener::from_std(stream)`](TcpListener::from_std). + /// [`TcpListener::from_std_assume_nonblocking(stream)`](TcpListener::from_std_assume_nonblocking). + /// + /// # Notes + /// + /// The caller is responsible for ensuring that the listener is in + /// non-blocking mode. Otherwise all I/O operations on the listener + /// will block the thread, which will cause unexpected behavior. + /// Non-blocking mode can be set using [`set_nonblocking`]. + /// + /// [`set_nonblocking`]: std::net::TcpListener::set_nonblocking + /// + /// It may be preferrable to use + /// [`from_std_set_nonblocking`](TcpListener::from_std_set_nonblocking), + /// which sets `nonblocking`. fn try_from(stream: net::TcpListener) -> Result { - Self::from_std(stream) + Self::from_std_assume_nonblocking(stream) } } diff --git a/tokio/tests/io_driver_drop.rs b/tokio/tests/io_driver_drop.rs index c3182637916..7e1fc4cc181 100644 --- a/tokio/tests/io_driver_drop.rs +++ b/tokio/tests/io_driver_drop.rs @@ -12,7 +12,7 @@ fn tcp_doesnt_block() { let listener = { let _enter = rt.enter(); let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - TcpListener::from_std(listener).unwrap() + TcpListener::from_std(listener).unwrap() // Given the name of this test it looks like it hasn't fulfilled its purpose when going from mio 0.6 to mio 0.7 -> TODO investigate why }; drop(rt); diff --git a/tokio/tests/net_panic.rs b/tokio/tests/net_panic.rs index 79807618973..b2954b350d7 100644 --- a/tokio/tests/net_panic.rs +++ b/tokio/tests/net_panic.rs @@ -40,7 +40,7 @@ fn tcp_listener_from_std_panic_caller() -> Result<(), Box> { let panic_location_file = test_panic(|| { let rt = runtime_without_io(); rt.block_on(async { - let _ = TcpListener::from_std(std_listener); + let _ = TcpListener::from_std_assume_nonblocking(std_listener); }); }); diff --git a/tokio/tests/no_rt.rs b/tokio/tests/no_rt.rs index 89c7ce0aa57..306132f3f13 100644 --- a/tokio/tests/no_rt.rs +++ b/tokio/tests/no_rt.rs @@ -37,5 +37,7 @@ async fn timeout_value() { expected = "there is no reactor running, must be called from the context of a Tokio 1.x runtime" )] fn io_panics_when_no_tokio_context() { - let _ = tokio::net::TcpListener::from_std(std::net::TcpListener::bind("127.0.0.1:0").unwrap()); + let _ = tokio::net::TcpListener::from_std_assume_nonblocking( + std::net::TcpListener::bind("127.0.0.1:0").unwrap(), + ); }