diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 9aae69ab98f..01d5dcba780 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -3,7 +3,10 @@ use crate::runtime::handle::Handle; use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback}; #[cfg(tokio_unstable)] -use crate::runtime::{metrics::HistogramConfiguration, LocalOptions, LocalRuntime, TaskMeta}; +use crate::runtime::{ + metrics::HistogramConfiguration, LocalOptions, LocalRuntime, TaskMeta, TaskSpawnCallback, + UserData, +}; use crate::util::rand::{RngSeed, RngSeedGenerator}; use crate::runtime::blocking::BlockingPool; @@ -89,6 +92,9 @@ pub struct Builder { pub(super) after_unpark: Option, /// To run before each task is spawned. + #[cfg(tokio_unstable)] + pub(super) before_spawn: Option, + #[cfg(not(tokio_unstable))] pub(super) before_spawn: Option, /// To run before each poll @@ -731,8 +737,19 @@ impl Builder { /// Executes function `f` just before a task is spawned. /// /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. + /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback + /// being invoked immediately. + /// + /// `f` must return an `Option<&'static (dyn Any + Send + Sync)>`. The `Send + Sync` + /// traits are optional when not using feature `rt-multi-thread`. A value returned + /// by this callback is attached to the task and can be retrieved using + /// [`TaskMeta::get_data`] in subsequent calls to other hooks for this task such as + /// [`on_before_task_poll`](crate::runtime::Builder::on_before_task_poll), + /// [`on_after_task_poll`](crate::runtime::Builder::on_after_task_poll), and + /// [`on_task_terminate`](crate::runtime::Builder::on_task_terminate). + /// + /// The `crate::task::Builder::data` method can also be used to attach data to + /// a specific task when spawning it. /// /// This can be used for bookkeeping or monitoring purposes. /// @@ -755,6 +772,7 @@ impl Builder { /// let runtime = runtime::Builder::new_current_thread() /// .on_task_spawn(|_| { /// println!("spawning task"); + /// None /// }) /// .build() /// .unwrap(); @@ -768,11 +786,53 @@ impl Builder { /// }) /// # } /// ``` + /// + /// ``` + /// # use tokio::runtime; + /// # use std::sync::atomic::{AtomicUsize, Ordering}; + /// # pub fn main() { + /// struct YieldingTaskMetadata { + /// pub yield_count: AtomicUsize, + /// } + /// let runtime = runtime::Builder::new_current_thread() + /// .on_task_spawn(|meta| { + /// println!("spawning task {}", meta.id()); + /// let meta = Box::new(YieldingTaskMetadata { yield_count: AtomicUsize::new(0) }); + /// Some(Box::leak(meta) as &(dyn std::any::Any + Send + Sync)) + /// }) + /// .on_after_task_poll(|meta| { + /// if let Some(data) = meta.get_data().and_then(|data| data.downcast_ref::()) { + /// println!("task {} yield count: {}", meta.id(), data.yield_count.fetch_add(1, Ordering::Relaxed)); + /// } + /// }) + /// .on_task_terminate(|meta| { + /// match meta.get_data().and_then(|data| data.downcast_ref::()) { + /// Some(data) => { + /// let yield_count = data.yield_count.load(Ordering::Relaxed); + /// println!("task {} total yield count: {}", meta.id(), yield_count); + /// assert!(yield_count == 64); + /// }, + /// None => panic!("task has missing or incorrect user data"), + /// } + /// }) + /// .build() + /// .unwrap(); + /// + /// runtime.block_on(async { + /// let _ = tokio::task::spawn(async { + /// for _ in 0..64 { + /// println!("yielding"); + /// tokio::task::yield_now().await; + /// } + /// }).await.unwrap(); + /// }) + /// # } + /// ``` #[cfg(all(not(loom), tokio_unstable))] #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] pub fn on_task_spawn(&mut self, f: F) -> &mut Self where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, + F: Fn(&TaskMeta<'_>) -> UserData + Send + Sync + 'static, { self.before_spawn = Some(std::sync::Arc::new(f)); self diff --git a/tokio/src/runtime/config.rs b/tokio/src/runtime/config.rs index b79df96e1e2..9a6348015b0 100644 --- a/tokio/src/runtime/config.rs +++ b/tokio/src/runtime/config.rs @@ -2,6 +2,8 @@ any(not(all(tokio_unstable, feature = "full")), target_family = "wasm"), allow(dead_code) )] +#[cfg(tokio_unstable)] +use crate::runtime::TaskSpawnCallback; use crate::runtime::{Callback, TaskCallback}; use crate::util::RngSeedGenerator; @@ -19,6 +21,9 @@ pub(crate) struct Config { pub(crate) after_unpark: Option, /// To run before each task is spawned. + #[cfg(tokio_unstable)] + pub(crate) before_spawn: Option, + #[cfg(not(tokio_unstable))] pub(crate) before_spawn: Option, /// To run after each task is terminated. diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index ae58ce6da86..71f5cc5c053 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -436,6 +436,10 @@ cfg_rt! { mod task_hooks; pub(crate) use task_hooks::{TaskHooks, TaskCallback}; + #[cfg(tokio_unstable)] + pub(crate) use task_hooks::{TaskSpawnCallback, UserData}; + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) use task_hooks::UserDataValue; cfg_unstable! { pub use task_hooks::TaskMeta; } diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index a05dbb96412..9c3bdecdaf3 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -5,6 +5,8 @@ use crate::runtime::scheduler::{self, Defer, Inject}; use crate::runtime::task::{ self, JoinHandle, OwnedTasks, Schedule, SpawnLocation, Task, TaskHarnessScheduleHooks, }; +#[cfg(tokio_unstable)] +use crate::runtime::UserData; use crate::runtime::{ blocking, context, Config, MetricsBatch, SchedulerMetrics, TaskHooks, TaskMeta, WorkerMetrics, }; @@ -456,13 +458,49 @@ impl Handle { F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (handle, notified) = me.shared.owned.bind(future, me.clone(), id, spawned_at); + Self::spawn_with_user_data( + me, + future, + id, + spawned_at, + #[cfg(tokio_unstable)] + None, + ) + } - me.task_hooks.spawn(&TaskMeta { + #[track_caller] + pub(crate) fn spawn_with_user_data( + me: &Arc, + future: F, + id: crate::runtime::task::Id, + spawned_at: SpawnLocation, + #[cfg(tokio_unstable)] user_data: UserData, + ) -> JoinHandle + where + F: crate::future::Future + Send + 'static, + F::Output: Send + 'static, + { + let task_meta = TaskMeta { id, spawned_at, + #[cfg(tokio_unstable)] + user_data, _phantom: Default::default(), - }); + }; + + #[cfg(not(tokio_unstable))] + { + me.task_hooks.spawn(&task_meta); + } + + let (handle, notified) = me.shared.owned.bind( + future, + me.clone(), + id, + spawned_at, + #[cfg(tokio_unstable)] + me.task_hooks.spawn(&task_meta), + ); if let Some(notified) = notified { me.schedule(notified); @@ -488,16 +526,27 @@ impl Handle { F: crate::future::Future + 'static, F::Output: 'static, { - let (handle, notified) = me - .shared - .owned - .bind_local(future, me.clone(), id, spawned_at); - - me.task_hooks.spawn(&TaskMeta { + let task_meta = TaskMeta { id, spawned_at, + #[cfg(tokio_unstable)] + user_data: None, _phantom: Default::default(), - }); + }; + + #[cfg(not(tokio_unstable))] + { + me.task_hooks.spawn(&task_meta); + } + + let (handle, notified) = me.shared.owned.bind_local( + future, + me.clone(), + id, + spawned_at, + #[cfg(tokio_unstable)] + me.task_hooks.spawn(&task_meta), + ); if let Some(notified) = notified { me.schedule(notified); diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs index ecd56aeee10..7f529230e1b 100644 --- a/tokio/src/runtime/scheduler/mod.rs +++ b/tokio/src/runtime/scheduler/mod.rs @@ -68,6 +68,8 @@ impl Handle { cfg_rt! { use crate::future::Future; use crate::loom::sync::Arc; + #[cfg(tokio_unstable)] + use crate::runtime::UserData; use crate::runtime::{blocking, task::{Id, SpawnLocation}}; use crate::runtime::context; use crate::task::JoinHandle; @@ -130,6 +132,20 @@ cfg_rt! { } } + #[cfg(tokio_unstable)] + pub(crate) fn spawn_with_user_data(&self, future: F, id: Id, spawned_at: SpawnLocation, user_data: UserData) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + match self { + Handle::CurrentThread(h) => current_thread::Handle::spawn_with_user_data(h, future, id, spawned_at, user_data), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(h) => multi_thread::Handle::spawn_with_user_data(h, future, id, spawned_at, user_data), + } + } + /// Spawn a local task /// /// # Safety diff --git a/tokio/src/runtime/scheduler/multi_thread/handle.rs b/tokio/src/runtime/scheduler/multi_thread/handle.rs index 9acfcb270d6..d26d873abfb 100644 --- a/tokio/src/runtime/scheduler/multi_thread/handle.rs +++ b/tokio/src/runtime/scheduler/multi_thread/handle.rs @@ -2,6 +2,8 @@ use crate::future::Future; use crate::loom::sync::Arc; use crate::runtime::scheduler::multi_thread::worker; use crate::runtime::task::{Notified, Task, TaskHarnessScheduleHooks}; +#[cfg(tokio_unstable)] +use crate::runtime::task_hooks::UserData; use crate::runtime::{ blocking, driver, task::{self, JoinHandle, SpawnLocation}, @@ -47,7 +49,30 @@ impl Handle { F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - Self::bind_new_task(me, future, id, spawned_at) + Self::bind_new_task( + me, + future, + id, + spawned_at, + #[cfg(tokio_unstable)] + None, + ) + } + + /// Spawns a future with user data onto the thread pool + #[cfg(tokio_unstable)] + pub(crate) fn spawn_with_user_data( + me: &Arc, + future: F, + id: task::Id, + spawned_at: SpawnLocation, + user_data: UserData, + ) -> JoinHandle + where + F: crate::future::Future + Send + 'static, + F::Output: Send + 'static, + { + Self::bind_new_task(me, future, id, spawned_at, user_data) } pub(crate) fn shutdown(&self) { @@ -60,18 +85,33 @@ impl Handle { future: T, id: task::Id, spawned_at: SpawnLocation, + #[cfg(tokio_unstable)] user_data: UserData, ) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, { - let (handle, notified) = me.shared.owned.bind(future, me.clone(), id, spawned_at); - - me.task_hooks.spawn(&TaskMeta { + let task_meta = TaskMeta { id, spawned_at, + #[cfg(tokio_unstable)] + user_data, _phantom: Default::default(), - }); + }; + + #[cfg(not(tokio_unstable))] + { + me.task_hooks.spawn(&task_meta); + } + + let (handle, notified) = me.shared.owned.bind( + future, + me.clone(), + id, + spawned_at, + #[cfg(tokio_unstable)] + me.task_hooks.spawn(&task_meta), + ); me.schedule_option_task_without_yield(notified); diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index e91e8be4025..9a7e7a6369e 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -15,6 +15,8 @@ use crate::runtime::context; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks}; +#[cfg(tokio_unstable)] +use crate::runtime::task_hooks::UserData; use crate::util::linked_list; use std::num::NonZeroU64; @@ -182,6 +184,10 @@ pub(crate) struct Header { /// The tracing ID for this instrumented task. #[cfg(all(tokio_unstable, feature = "tracing"))] pub(super) tracing_id: Option, + + /// Custom user defined metadata for this task for use in hooks. + #[cfg(tokio_unstable)] + pub(super) user_data: UserData, } unsafe impl Send for Header {} @@ -223,12 +229,14 @@ impl Cell { state: State, task_id: Id, #[cfg(tokio_unstable)] spawned_at: &'static Location<'static>, + #[cfg(tokio_unstable)] user_data: UserData, ) -> Box> { // Separated into a non-generic function to reduce LLVM codegen fn new_header( state: State, vtable: &'static Vtable, #[cfg(all(tokio_unstable, feature = "tracing"))] tracing_id: Option, + #[cfg(tokio_unstable)] user_data: UserData, ) -> Header { Header { state, @@ -237,6 +245,8 @@ impl Cell { owner_id: UnsafeCell::new(None), #[cfg(all(tokio_unstable, feature = "tracing"))] tracing_id, + #[cfg(tokio_unstable)] + user_data, } } @@ -250,6 +260,8 @@ impl Cell { vtable, #[cfg(all(tokio_unstable, feature = "tracing"))] tracing_id, + #[cfg(tokio_unstable)] + user_data, ), core: Core { scheduler, @@ -515,6 +527,16 @@ impl Header { *ptr } + /// Gets the user data from the task header. + /// + /// # Safety + /// + /// The provided raw pointer must point at the header of a task. + #[cfg(tokio_unstable)] + pub(super) unsafe fn get_user_data(me: NonNull
) -> UserData { + me.as_ref().user_data + } + /// Gets the tracing id of the task containing this `Header`. /// /// # Safety diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 6f20d66efc6..4a1a210a29e 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -374,6 +374,8 @@ where f(&TaskMeta { id: self.core().task_id, spawned_at: self.core().spawned_at.into(), + #[cfg(tokio_unstable)] + user_data: self.header().user_data, _phantom: Default::default(), }) })); diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 908ce07ecf6..9a707adcc61 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -13,6 +13,8 @@ use crate::util::linked_list::{Link, LinkedList}; use crate::util::sharded_list; use crate::loom::sync::atomic::{AtomicBool, Ordering}; +#[cfg(tokio_unstable)] +use crate::runtime::task_hooks::UserData; use std::marker::PhantomData; use std::num::NonZeroU64; @@ -92,13 +94,21 @@ impl OwnedTasks { scheduler: S, id: super::Id, spawned_at: SpawnLocation, + #[cfg(tokio_unstable)] user_data: UserData, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static, { - let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at); + let (task, notified, join) = super::new_task( + task, + scheduler, + id, + spawned_at, + #[cfg(tokio_unstable)] + user_data, + ); let notified = unsafe { self.bind_inner(task, notified) }; (join, notified) } @@ -113,13 +123,21 @@ impl OwnedTasks { scheduler: S, id: super::Id, spawned_at: SpawnLocation, + #[cfg(tokio_unstable)] user_data: UserData, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at); + let (task, notified, join) = super::new_task( + task, + scheduler, + id, + spawned_at, + #[cfg(tokio_unstable)] + user_data, + ); let notified = unsafe { self.bind_inner(task, notified) }; (join, notified) } @@ -261,13 +279,21 @@ impl LocalOwnedTasks { scheduler: S, id: super::Id, spawned_at: SpawnLocation, + #[cfg(tokio_unstable)] user_data: UserData, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at); + let (task, notified, join) = super::new_task( + task, + scheduler, + id, + spawned_at, + #[cfg(tokio_unstable)] + user_data, + ); unsafe { // safety: We just created the task, so we have exclusive access diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 093b6f6caad..67313e0a820 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -232,6 +232,9 @@ use std::panic::Location; use std::ptr::NonNull; use std::{fmt, mem}; +#[cfg(tokio_unstable)] +use crate::runtime::task_hooks::UserData; + /// An owned handle to the task, tracked by ref count. #[repr(transparent)] pub(crate) struct Task { @@ -330,6 +333,7 @@ cfg_rt! { scheduler: S, id: Id, spawned_at: SpawnLocation, + #[cfg(tokio_unstable)] user_data: UserData, ) -> (Task, Notified, JoinHandle) where S: Schedule, @@ -341,6 +345,8 @@ cfg_rt! { scheduler, id, spawned_at, + #[cfg(tokio_unstable)] + user_data, ); let task = Task { raw, @@ -375,6 +381,8 @@ cfg_rt! { scheduler, id, spawned_at, + #[cfg(tokio_unstable)] + None, ); // This transfers the ref-count of task and notified into an UnownedTask. @@ -437,6 +445,12 @@ impl Task { unsafe { Header::get_spawn_location(self.raw.header_ptr()) } } + #[cfg(tokio_unstable)] + pub(crate) fn get_user_data(&self) -> UserData { + // Safety: The header pointer is valid. + unsafe { Header::get_user_data(self.raw.header_ptr()) } + } + // Explicit `'task` and `'meta` lifetimes are necessary here, as otherwise, // the compiler infers the lifetimes to be the same, and considers the task // to be borrowed for the lifetime of the returned `TaskMeta`. @@ -445,6 +459,8 @@ impl Task { crate::runtime::TaskMeta { id: self.id(), spawned_at: self.spawned_at().into(), + #[cfg(tokio_unstable)] + user_data: self.get_user_data(), _phantom: PhantomData, } } diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index e9a37802203..2695204f992 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -2,6 +2,8 @@ use crate::future::Future; use crate::runtime::task::core::{Core, Trailer}; use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State}; #[cfg(tokio_unstable)] +use crate::runtime::UserData; +#[cfg(tokio_unstable)] use std::panic::Location; use std::ptr::NonNull; use std::task::{Poll, Waker}; @@ -204,6 +206,7 @@ impl RawTask { scheduler: S, id: Id, _spawned_at: super::SpawnLocation, + #[cfg(tokio_unstable)] user_data: UserData, ) -> RawTask where T: Future, @@ -216,6 +219,8 @@ impl RawTask { id, #[cfg(tokio_unstable)] _spawned_at.0, + #[cfg(tokio_unstable)] + user_data, )); let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) }; diff --git a/tokio/src/runtime/task_hooks.rs b/tokio/src/runtime/task_hooks.rs index 6df3837b527..3a9d26e1177 100644 --- a/tokio/src/runtime/task_hooks.rs +++ b/tokio/src/runtime/task_hooks.rs @@ -1,7 +1,18 @@ use super::Config; +#[cfg(tokio_unstable)] +use std::any::Any; use std::marker::PhantomData; impl TaskHooks { + #[cfg(tokio_unstable)] + pub(crate) fn spawn(&self, meta: &TaskMeta<'_>) -> UserData { + match self.task_spawn_callback.as_ref() { + Some(f) => f(meta), + None => None, + } + } + + #[cfg(not(tokio_unstable))] pub(crate) fn spawn(&self, meta: &TaskMeta<'_>) { if let Some(f) = self.task_spawn_callback.as_ref() { f(meta) @@ -39,6 +50,9 @@ impl TaskHooks { #[derive(Clone)] pub(crate) struct TaskHooks { + #[cfg(tokio_unstable)] + pub(crate) task_spawn_callback: Option, + #[cfg(not(tokio_unstable))] pub(crate) task_spawn_callback: Option, pub(crate) task_terminate_callback: Option, #[cfg(tokio_unstable)] @@ -62,6 +76,9 @@ pub struct TaskMeta<'a> { /// The location where the task was spawned. #[cfg_attr(not(tokio_unstable), allow(unreachable_pub, dead_code))] pub(crate) spawned_at: crate::runtime::task::SpawnLocation, + /// Optional user-defined metadata for the task. + #[cfg(tokio_unstable)] + pub(crate) user_data: UserData, pub(crate) _phantom: PhantomData<&'a ()>, } @@ -77,7 +94,31 @@ impl<'a> TaskMeta<'a> { pub fn spawned_at(&self) -> &'static std::panic::Location<'static> { self.spawned_at.0 } + + /// Return the user-defined metadata for this task if it is set and of the + /// correct type. + #[cfg(tokio_unstable)] + pub fn get_data(&self) -> UserData { + self.user_data + } } /// Runs on specific task-related events pub(crate) type TaskCallback = std::sync::Arc) + Send + Sync>; + +/// Runs on task-spawn events, and can optionally return user-defined metadata +/// to attach to the task, which are accessible in subsequent hooks. +#[cfg(tokio_unstable)] +pub(crate) type TaskSpawnCallback = std::sync::Arc) -> UserData + Send + Sync>; + +/// User data that can be attached to a task when spawning. +/// +/// This type alias provides a cleaner interface for the user data parameter +/// used throughout the task spawning system when the `tokio_unstable` feature +/// is enabled. +#[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] +pub(crate) type UserDataValue = &'static (dyn Any + Send + Sync); +#[cfg(all(tokio_unstable, not(feature = "rt-multi-thread")))] +pub(crate) type UserDataValue = &'static dyn Any; +#[cfg(tokio_unstable)] +pub(crate) type UserData = Option; diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index 7a10ac4a3b8..65012f3d584 100644 --- a/tokio/src/runtime/tests/task.rs +++ b/tokio/src/runtime/tests/task.rs @@ -401,10 +401,14 @@ impl Runtime { T: 'static + Send + Future, T::Output: 'static + Send, { - let (handle, notified) = - self.0 - .owned - .bind(future, self.clone(), Id::next(), SpawnLocation::capture()); + let (handle, notified) = self.0.owned.bind( + future, + self.clone(), + Id::next(), + SpawnLocation::capture(), + #[cfg(tokio_unstable)] + None, + ); if let Some(notified) = notified { self.schedule(notified); diff --git a/tokio/src/task/builder.rs b/tokio/src/task/builder.rs index 467a700646e..55ceec4a467 100644 --- a/tokio/src/task/builder.rs +++ b/tokio/src/task/builder.rs @@ -1,4 +1,6 @@ #![allow(unreachable_pub)] +#[cfg(tokio_unstable)] +use crate::runtime::{UserData, UserDataValue}; use crate::{ runtime::{Handle, BOX_FUTURE_THRESHOLD}, task::{JoinHandle, LocalSet}, @@ -62,6 +64,8 @@ use std::{future::Future, io, mem}; #[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))] pub struct Builder<'a> { name: Option<&'a str>, + #[cfg(tokio_unstable)] + user_data: UserData, } impl<'a> Builder<'a> { @@ -72,7 +76,20 @@ impl<'a> Builder<'a> { /// Assigns a name to the task which will be spawned. pub fn name(&self, name: &'a str) -> Self { - Self { name: Some(name) } + Self { + name: Some(name), + #[cfg(tokio_unstable)] + user_data: self.user_data, + } + } + + /// Assigns user data to the task which will be spawned. + #[cfg(tokio_unstable)] + pub fn data(&self, data: UserDataValue) -> Self { + Self { + name: self.name, + user_data: Some(data), + } } /// Spawns a task with this builder's settings on the current runtime. @@ -91,9 +108,19 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - super::spawn::spawn_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + super::spawn::spawn_inner_with_user_data( + Box::pin(future), + SpawnMeta::new(self.name, fut_size), + #[cfg(tokio_unstable)] + self.user_data, + ) } else { - super::spawn::spawn_inner(future, SpawnMeta::new(self.name, fut_size)) + super::spawn::spawn_inner_with_user_data( + future, + SpawnMeta::new(self.name, fut_size), + #[cfg(tokio_unstable)] + self.user_data, + ) }) } diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 021e6277534..fa1539b09f0 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -1050,6 +1050,8 @@ impl Context { self.shared.clone(), id, SpawnLocation::capture(), + #[cfg(tokio_unstable)] + None, ) }; diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs index 8ed288034d9..22088ee496f 100644 --- a/tokio/src/task/spawn.rs +++ b/tokio/src/task/spawn.rs @@ -1,3 +1,5 @@ +#[cfg(tokio_unstable)] +use crate::runtime::UserData; use crate::runtime::BOX_FUTURE_THRESHOLD; use crate::task::JoinHandle; use crate::util::trace::SpawnMeta; @@ -181,6 +183,20 @@ cfg_rt! { #[track_caller] pub(super) fn spawn_inner(future: T, meta: SpawnMeta<'_>) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + spawn_inner_with_user_data(future, meta, #[cfg(tokio_unstable)] None) + } + + #[track_caller] + pub(super) fn spawn_inner_with_user_data( + future: T, + meta: SpawnMeta<'_>, + #[cfg(tokio_unstable)] + user_data: UserData, + ) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, @@ -202,6 +218,13 @@ cfg_rt! { let id = task::Id::next(); let task = crate::util::trace::task(future, "task", meta, id.as_u64()); + #[cfg(tokio_unstable)] + match context::with_current(|handle| handle.spawn_with_user_data(task, id, meta.spawned_at, #[cfg(tokio_unstable)] user_data)) { + Ok(join_handle) => join_handle, + Err(e) => panic!("{}", e), + } + + #[cfg(not(tokio_unstable))] match context::with_current(|handle| handle.spawn(task, id, meta.spawned_at)) { Ok(join_handle) => join_handle, Err(e) => panic!("{}", e), diff --git a/tokio/tests/task_hooks.rs b/tokio/tests/task_hooks.rs index 42bb3fd946c..f64784c11ee 100644 --- a/tokio/tests/task_hooks.rs +++ b/tokio/tests/task_hooks.rs @@ -1,6 +1,8 @@ #![warn(rust_2018_idioms)] #![cfg(all(feature = "full", tokio_unstable, target_has_atomic = "64"))] +#[cfg(feature = "rt-multi-thread")] +use std::any::Any; use std::collections::HashSet; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; @@ -23,6 +25,7 @@ fn spawn_task_hook_fires() { ids2.lock().unwrap().insert(data.id()); count2.fetch_add(1, Ordering::SeqCst); + None }) .build() .unwrap(); @@ -85,11 +88,11 @@ fn task_hook_spawn_location_current_thread() { "(current_thread) on_task_spawn", &spawns, )) - .on_before_task_poll(mk_spawn_location_hook( + .on_before_task_poll(mk_poll_location_hook( "(current_thread) on_before_task_poll", &poll_starts, )) - .on_after_task_poll(mk_spawn_location_hook( + .on_after_task_poll(mk_poll_location_hook( "(current_thread) on_after_task_poll", &poll_ends, )) @@ -136,11 +139,11 @@ fn task_hook_spawn_location_multi_thread() { "(multi_thread) on_task_spawn", &spawns, )) - .on_before_task_poll(mk_spawn_location_hook( + .on_before_task_poll(mk_poll_location_hook( "(multi_thread) on_before_task_poll", &poll_starts, )) - .on_after_task_poll(mk_spawn_location_hook( + .on_after_task_poll(mk_poll_location_hook( "(multi_thread) on_after_task_poll", &poll_ends, )) @@ -174,9 +177,34 @@ fn task_hook_spawn_location_multi_thread() { assert_eq!(poll_starts, poll_ends.fetch_add(0, Ordering::SeqCst)); } +#[cfg(feature = "rt-multi-thread")] +type UserData = Option<&'static (dyn Any + Send + Sync)>; +#[cfg(not(feature = "rt-multi-thread"))] +type UserData = Option<&'static dyn Any>; + fn mk_spawn_location_hook( event: &'static str, count: &Arc, +) -> impl Fn(&tokio::runtime::TaskMeta<'_>) -> UserData { + let count = Arc::clone(count); + move |data| { + eprintln!("{event} ({:?}): {:?}", data.id(), data.spawned_at()); + // Assert that the spawn location is in this file. + // Don't make assertions about line number/column here, as these + // may change as new code is added to the test file... + assert_eq!( + data.spawned_at().file(), + file!(), + "incorrect spawn location in {event} hook", + ); + count.fetch_add(1, Ordering::SeqCst); + None + } +} + +fn mk_poll_location_hook( + event: &'static str, + count: &Arc, ) -> impl Fn(&tokio::runtime::TaskMeta<'_>) { let count = Arc::clone(count); move |data| {