|
20 | 20 |
|
21 | 21 | #![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] |
22 | 22 |
|
| 23 | +use std::cell::RefCell; |
23 | 24 | use std::fmt; |
24 | 25 | use std::future::Future; |
25 | 26 | use std::marker::PhantomData; |
@@ -229,29 +230,56 @@ impl<'a> Executor<'a> { |
229 | 230 | let runner = Runner::new(self.state()); |
230 | 231 | let mut rng = fastrand::Rng::new(); |
231 | 232 |
|
232 | | - // A future that runs tasks forever. |
233 | | - let run_forever = async { |
234 | | - loop { |
235 | | - for _ in 0..200 { |
236 | | - let runnable = runner.runnable(&mut rng).await; |
237 | | - runnable.run(); |
238 | | - } |
239 | | - future::yield_now().await; |
240 | | - } |
241 | | - }; |
| 233 | + // Set the local queue while we're running. |
| 234 | + LocalQueue::set(self.state(), &runner.local, { |
| 235 | + let runner = &runner; |
| 236 | + async move { |
| 237 | + // A future that runs tasks forever. |
| 238 | + let run_forever = async { |
| 239 | + loop { |
| 240 | + for _ in 0..200 { |
| 241 | + let runnable = runner.runnable(&mut rng).await; |
| 242 | + runnable.run(); |
| 243 | + } |
| 244 | + future::yield_now().await; |
| 245 | + } |
| 246 | + }; |
242 | 247 |
|
243 | | - // Run `future` and `run_forever` concurrently until `future` completes. |
244 | | - future.or(run_forever).await |
| 248 | + // Run `future` and `run_forever` concurrently until `future` completes. |
| 249 | + future.or(run_forever).await |
| 250 | + } |
| 251 | + }) |
| 252 | + .await |
245 | 253 | } |
246 | 254 |
|
247 | 255 | /// Returns a function that schedules a runnable task when it gets woken up. |
248 | 256 | fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static { |
249 | 257 | let state = self.state().clone(); |
250 | 258 |
|
251 | | - // TODO(stjepang): If possible, push into the current local queue and notify the ticker. |
| 259 | + // If possible, push into the current local queue and notify the ticker. |
252 | 260 | move |runnable| { |
253 | | - state.queue.push(runnable).unwrap(); |
254 | | - state.notify(); |
| 261 | + let mut runnable = Some(runnable); |
| 262 | + |
| 263 | + // Try to push into the local queue. |
| 264 | + LocalQueue::with(|local_queue| { |
| 265 | + // Make sure that we don't accidentally push to an executor that isn't ours. |
| 266 | + if !std::ptr::eq(local_queue.state, &*state) { |
| 267 | + return; |
| 268 | + } |
| 269 | + |
| 270 | + if let Err(e) = local_queue.queue.push(runnable.take().unwrap()) { |
| 271 | + runnable = Some(e.into_inner()); |
| 272 | + return; |
| 273 | + } |
| 274 | + |
| 275 | + local_queue.waker.wake_by_ref(); |
| 276 | + }); |
| 277 | + |
| 278 | + // If the local queue push failed, just push to the global queue. |
| 279 | + if let Some(runnable) = runnable { |
| 280 | + state.queue.push(runnable).unwrap(); |
| 281 | + state.notify(); |
| 282 | + } |
255 | 283 | } |
256 | 284 | } |
257 | 285 |
|
@@ -819,6 +847,97 @@ impl Drop for Runner<'_> { |
819 | 847 | } |
820 | 848 | } |
821 | 849 |
|
| 850 | +/// The state of the currently running local queue. |
| 851 | +struct LocalQueue { |
| 852 | + /// The pointer to the state of the executor. |
| 853 | + /// |
| 854 | + /// Used to make sure we don't push runnables to the wrong executor. |
| 855 | + state: *const State, |
| 856 | + |
| 857 | + /// The concurrent queue. |
| 858 | + queue: Arc<ConcurrentQueue<Runnable>>, |
| 859 | + |
| 860 | + /// The waker for the runnable. |
| 861 | + waker: Waker, |
| 862 | +} |
| 863 | + |
| 864 | +impl LocalQueue { |
| 865 | + /// Run a function with the current local queue. |
| 866 | + fn with<R>(f: impl FnOnce(&LocalQueue) -> R) -> Option<R> { |
| 867 | + std::thread_local! { |
| 868 | + /// The current local queue. |
| 869 | + static LOCAL_QUEUE: RefCell<Option<LocalQueue>> = RefCell::new(None); |
| 870 | + } |
| 871 | + |
| 872 | + impl LocalQueue { |
| 873 | + /// Run a function with a set local queue. |
| 874 | + async fn set<F>( |
| 875 | + state: &State, |
| 876 | + queue: &Arc<ConcurrentQueue<Runnable>>, |
| 877 | + fut: F, |
| 878 | + ) -> F::Output |
| 879 | + where |
| 880 | + F: Future, |
| 881 | + { |
| 882 | + // Store the local queue and the current waker. |
| 883 | + let mut old = with_waker(|waker| { |
| 884 | + LOCAL_QUEUE.with(move |slot| { |
| 885 | + slot.borrow_mut().replace(LocalQueue { |
| 886 | + state: state as *const State, |
| 887 | + queue: queue.clone(), |
| 888 | + waker: waker.clone(), |
| 889 | + }) |
| 890 | + }) |
| 891 | + }) |
| 892 | + .await; |
| 893 | + |
| 894 | + // Restore the old local queue on drop. |
| 895 | + let _guard = CallOnDrop(move || { |
| 896 | + let old = old.take(); |
| 897 | + let _ = LOCAL_QUEUE.try_with(move |slot| { |
| 898 | + *slot.borrow_mut() = old; |
| 899 | + }); |
| 900 | + }); |
| 901 | + |
| 902 | + // Pin the future. |
| 903 | + futures_lite::pin!(fut); |
| 904 | + |
| 905 | + // Run it such that the waker is updated every time it's polled. |
| 906 | + future::poll_fn(move |cx| { |
| 907 | + LOCAL_QUEUE |
| 908 | + .try_with({ |
| 909 | + let waker = cx.waker(); |
| 910 | + move |slot| { |
| 911 | + let mut slot = slot.borrow_mut(); |
| 912 | + let qaw = slot.as_mut().expect("missing local queue"); |
| 913 | + |
| 914 | + // If we've been replaced, just ignore the slot. |
| 915 | + if !Arc::ptr_eq(&qaw.queue, queue) { |
| 916 | + return; |
| 917 | + } |
| 918 | + |
| 919 | + // Update the waker, if it has changed. |
| 920 | + if !qaw.waker.will_wake(waker) { |
| 921 | + qaw.waker = waker.clone(); |
| 922 | + } |
| 923 | + } |
| 924 | + }) |
| 925 | + .ok(); |
| 926 | + |
| 927 | + // Poll the future. |
| 928 | + fut.as_mut().poll(cx) |
| 929 | + }) |
| 930 | + .await |
| 931 | + } |
| 932 | + } |
| 933 | + |
| 934 | + LOCAL_QUEUE |
| 935 | + .try_with(|local_queue| local_queue.borrow().as_ref().map(f)) |
| 936 | + .ok() |
| 937 | + .flatten() |
| 938 | + } |
| 939 | +} |
| 940 | + |
822 | 941 | /// Steals some items from one queue into another. |
823 | 942 | fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) { |
824 | 943 | // Half of `src`'s length rounded up. |
@@ -911,10 +1030,19 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_ |
911 | 1030 | } |
912 | 1031 |
|
913 | 1032 | /// Runs a closure when dropped. |
914 | | -struct CallOnDrop<F: Fn()>(F); |
| 1033 | +struct CallOnDrop<F: FnMut()>(F); |
915 | 1034 |
|
916 | | -impl<F: Fn()> Drop for CallOnDrop<F> { |
| 1035 | +impl<F: FnMut()> Drop for CallOnDrop<F> { |
917 | 1036 | fn drop(&mut self) { |
918 | 1037 | (self.0)(); |
919 | 1038 | } |
920 | 1039 | } |
| 1040 | + |
| 1041 | +/// Run a closure with the current waker. |
| 1042 | +fn with_waker<F: FnOnce(&Waker) -> R, R>(f: F) -> impl Future<Output = R> { |
| 1043 | + let mut f = Some(f); |
| 1044 | + future::poll_fn(move |cx| { |
| 1045 | + let f = f.take().unwrap(); |
| 1046 | + Poll::Ready(f(cx.waker())) |
| 1047 | + }) |
| 1048 | +} |
0 commit comments