diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 76800296a27..bd01d146609 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -927,7 +927,7 @@ impl Builder { #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] pub fn build_local(&mut self, options: LocalOptions) -> io::Result { match &self.kind { - Kind::CurrentThread => self.build_current_thread_local_runtime(), + Kind::CurrentThread => self.build_current_thread_local_runtime(options), #[cfg(feature = "rt-multi-thread")] Kind::MultiThread => panic!("multi_thread is not supported for LocalRuntime"), } @@ -1439,11 +1439,16 @@ impl Builder { } #[cfg(tokio_unstable)] - fn build_current_thread_local_runtime(&mut self) -> io::Result { + fn build_current_thread_local_runtime( + &mut self, + opts: LocalOptions, + ) -> io::Result { use crate::runtime::local_runtime::LocalRuntimeScheduler; let tid = std::thread::current().id(); + self.before_park = opts.before_park; + self.after_unpark = opts.after_unpark; let (scheduler, handle, blocking_pool) = self.build_current_thread_runtime_components(Some(tid))?; diff --git a/tokio/src/runtime/local_runtime/options.rs b/tokio/src/runtime/local_runtime/options.rs index ef276e2c9be..7d0b1f0b531 100644 --- a/tokio/src/runtime/local_runtime/options.rs +++ b/tokio/src/runtime/local_runtime/options.rs @@ -1,18 +1,159 @@ use std::marker::PhantomData; +use crate::runtime::Callback; + /// [`LocalRuntime`]-only config options /// -/// Currently, there are no such options, but in the future, things like `!Send + !Sync` hooks may -/// be added. -/// /// Use `LocalOptions::default()` to create the default set of options. This type is used with /// [`Builder::build_local`]. /// +/// When using [`Builder::build_local`], this overrides any pre-configured options set on the +/// [`Builder`]. +/// /// [`Builder::build_local`]: crate::runtime::Builder::build_local /// [`LocalRuntime`]: crate::runtime::LocalRuntime -#[derive(Default, Debug)] +/// [`Builder`]: crate::runtime::Builder +#[derive(Default)] #[non_exhaustive] +#[allow(missing_debug_implementations)] pub struct LocalOptions { /// Marker used to make this !Send and !Sync. _phantom: PhantomData<*mut u8>, + + /// To run before the local runtime is parked. + pub(crate) before_park: Option, + + /// To run before the local runtime is spawned. + pub(crate) after_unpark: Option, +} + +impl std::fmt::Debug for LocalOptions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LocalOptions") + .field("before_park", &self.before_park.as_ref().map(|_| "...")) + .field("after_unpark", &self.after_unpark.as_ref().map(|_| "...")) + .finish() + } +} + +impl LocalOptions { + /// Executes function `f` just before the local runtime is parked (goes idle). + /// `f` is called within the Tokio context, so functions like [`tokio::spawn`](crate::spawn) + /// can be called, and may result in this thread being unparked immediately. + /// + /// This can be used to start work only when the executor is idle, or for bookkeeping + /// and monitoring purposes. + /// + /// This differs from the [`Builder::on_thread_park`] method in that it accepts a non Send + Sync + /// closure. + /// + /// Note: There can only be one park callback for a runtime; calling this function + /// more than once replaces the last callback defined, rather than adding to it. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime::{Builder, LocalOptions}; + /// # pub fn main() { + /// let (tx, rx) = std::sync::mpsc::channel(); + /// let mut opts = LocalOptions::default(); + /// opts.on_thread_park(move || match rx.recv() { + /// Ok(x) => println!("Received from channel: {}", x), + /// Err(e) => println!("Error receiving from channel: {}", e), + /// }); + /// + /// let runtime = Builder::new_current_thread() + /// .enable_time() + /// .build_local(opts) + /// .unwrap(); + /// + /// runtime.block_on(async { + /// tokio::task::spawn_local(async move { + /// tx.send(42).unwrap(); + /// }); + /// tokio::time::sleep(std::time::Duration::from_millis(1)).await; + /// }) + /// # } + /// ``` + /// + /// [`Builder`]: crate::runtime::Builder + /// [`Builder::on_thread_park`]: crate::runtime::Builder::on_thread_park + pub fn on_thread_park(&mut self, f: F) -> &mut Self + where + F: Fn() + 'static, + { + self.before_park = Some(std::sync::Arc::new(to_send_sync(f))); + self + } + + /// Executes function `f` just after the local runtime unparks (starts executing tasks). + /// + /// This is intended for bookkeeping and monitoring use cases; note that work + /// in this callback will increase latencies when the application has allowed one or + /// more runtime threads to go idle. + /// + /// This differs from the [`Builder::on_thread_unpark`] method in that it accepts a non Send + Sync + /// closure. + /// + /// Note: There can only be one unpark callback for a runtime; calling this function + /// more than once replaces the last callback defined, rather than adding to it. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime::{Builder, LocalOptions}; + /// # pub fn main() { + /// let (tx, rx) = std::sync::mpsc::channel(); + /// let mut opts = LocalOptions::default(); + /// opts.on_thread_unpark(move || match rx.recv() { + /// Ok(x) => println!("Received from channel: {}", x), + /// Err(e) => println!("Error receiving from channel: {}", e), + /// }); + /// + /// let runtime = Builder::new_current_thread() + /// .enable_time() + /// .build_local(opts) + /// .unwrap(); + /// + /// runtime.block_on(async { + /// tokio::task::spawn_local(async move { + /// tx.send(42).unwrap(); + /// }); + /// tokio::time::sleep(std::time::Duration::from_millis(1)).await; + /// }) + /// # } + /// ``` + /// + /// [`Builder`]: crate::runtime::Builder + /// [`Builder::on_thread_unpark`]: crate::runtime::Builder::on_thread_unpark + pub fn on_thread_unpark(&mut self, f: F) -> &mut Self + where + F: Fn() + 'static, + { + self.after_unpark = Some(std::sync::Arc::new(to_send_sync(f))); + self + } +} + +// A wrapper type to allow non-Send + Sync closures to be used in a Send + Sync context. +// This is specifically used for executing callbacks when using a `LocalRuntime`. +struct UnsafeSendSync(T); + +// SAFETY: This type is only used in a context where it is guaranteed that the closure will not be +// sent across threads. +unsafe impl Send for UnsafeSendSync {} +unsafe impl Sync for UnsafeSendSync {} + +impl UnsafeSendSync { + fn call(&self) { + (self.0)() + } +} + +fn to_send_sync(f: F) -> impl Fn() + Send + Sync +where + F: Fn(), +{ + let f = UnsafeSendSync(f); + move || f.call() } diff --git a/tokio/tests/rt_local.rs b/tokio/tests/rt_local.rs index 4eb88d48a4d..d7e8d7a40bc 100644 --- a/tokio/tests/rt_local.rs +++ b/tokio/tests/rt_local.rs @@ -6,7 +6,7 @@ use tokio::task::spawn_local; #[test] fn test_spawn_local_in_runtime() { - let rt = rt(); + let rt = rt(LocalOptions::default()); let res = rt.block_on(async move { let (tx, rx) = tokio::sync::oneshot::channel(); @@ -22,9 +22,43 @@ fn test_spawn_local_in_runtime() { assert_eq!(res, 5); } +#[test] +fn test_on_thread_park_unpark_in_runtime() { + let mut opts = LocalOptions::default(); + + // the refcell makes the below callbacks `!Send + !Sync` + let on_park_called = std::rc::Rc::new(std::cell::RefCell::new(false)); + let on_park_cc = on_park_called.clone(); + opts.on_thread_park(move || { + *on_park_cc.borrow_mut() = true; + }); + + let on_unpark_called = std::rc::Rc::new(std::cell::RefCell::new(false)); + let on_unpark_cc = on_unpark_called.clone(); + opts.on_thread_unpark(move || { + *on_unpark_cc.borrow_mut() = true; + }); + let rt = rt(opts); + + rt.block_on(async move { + let (tx, rx) = tokio::sync::oneshot::channel(); + + spawn_local(async { + tokio::task::yield_now().await; + tx.send(5).unwrap(); + }); + + // this ensures on_thread_park is called + rx.await.unwrap() + }); + + assert!(*on_park_called.borrow()); + assert!(*on_unpark_called.borrow()); +} + #[test] fn test_spawn_from_handle() { - let rt = rt(); + let rt = rt(LocalOptions::default()); let (tx, rx) = tokio::sync::oneshot::channel(); @@ -40,7 +74,7 @@ fn test_spawn_from_handle() { #[test] fn test_spawn_local_on_runtime_object() { - let rt = rt(); + let rt = rt(LocalOptions::default()); let (tx, rx) = tokio::sync::oneshot::channel(); @@ -56,7 +90,7 @@ fn test_spawn_local_on_runtime_object() { #[test] fn test_spawn_local_from_guard() { - let rt = rt(); + let rt = rt(LocalOptions::default()); let (tx, rx) = tokio::sync::oneshot::channel(); @@ -78,7 +112,7 @@ fn test_spawn_from_guard_other_thread() { let (tx, rx) = std::sync::mpsc::channel(); std::thread::spawn(move || { - let rt = rt(); + let rt = rt(LocalOptions::default()); let handle = rt.handle().clone(); tx.send(handle).unwrap(); @@ -98,7 +132,7 @@ fn test_spawn_local_from_guard_other_thread() { let (tx, rx) = std::sync::mpsc::channel(); std::thread::spawn(move || { - let rt = rt(); + let rt = rt(LocalOptions::default()); let handle = rt.handle().clone(); tx.send(handle).unwrap(); @@ -111,9 +145,9 @@ fn test_spawn_local_from_guard_other_thread() { spawn_local(async {}); } -fn rt() -> tokio::runtime::LocalRuntime { +fn rt(opts: LocalOptions) -> tokio::runtime::LocalRuntime { tokio::runtime::Builder::new_current_thread() .enable_all() - .build_local(LocalOptions::default()) + .build_local(opts) .unwrap() }