Skip to content

Commit e5faacc

Browse files
committed
feat: Thread-local queue push take 3
This commit attempts to re-introduce the thread-local optimization. It stores the local queues in a multiplex hash map keyed by the thread ID that it started in. It also sets it up so the thread can be woken up by a unique runner ID. cc #64 Signed-off-by: John Nunley <[email protected]>
1 parent ef512cb commit e5faacc

File tree

1 file changed

+136
-29
lines changed

1 file changed

+136
-29
lines changed

src/lib.rs

Lines changed: 136 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838
html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
3939
)]
4040

41+
use std::collections::HashMap;
4142
use std::fmt;
4243
use std::marker::PhantomData;
4344
use std::panic::{RefUnwindSafe, UnwindSafe};
4445
use std::rc::Rc;
45-
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
46+
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
4647
use std::sync::{Arc, Mutex, RwLock, TryLockError};
4748
use std::task::{Poll, Waker};
49+
use std::thread::{self, ThreadId};
4850

4951
use async_task::{Builder, Runnable};
5052
use concurrent_queue::ConcurrentQueue;
@@ -369,8 +371,32 @@ impl<'a> Executor<'a> {
369371
fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
370372
let state = self.state_as_arc();
371373

372-
// TODO: If possible, push into the current local queue and notify the ticker.
373-
move |runnable| {
374+
move |mut runnable| {
375+
// If possible, push into the current local queue and notify the ticker.
376+
if let Some(local_queue) = state
377+
.local_queues
378+
.read()
379+
.unwrap()
380+
.get(&thread::current().id())
381+
.and_then(|list| list.first())
382+
{
383+
match local_queue.queue.push(runnable) {
384+
Ok(()) => {
385+
if let Some(waker) = state
386+
.sleepers
387+
.lock()
388+
.unwrap()
389+
.notify_runner(local_queue.runner_id)
390+
{
391+
waker.wake();
392+
}
393+
return;
394+
}
395+
396+
Err(r) => runnable = r.into_inner(),
397+
}
398+
}
399+
374400
state.queue.push(runnable).unwrap();
375401
state.notify();
376402
}
@@ -687,7 +713,9 @@ struct State {
687713
queue: ConcurrentQueue<Runnable>,
688714

689715
/// Local queues created by runners.
690-
local_queues: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
716+
///
717+
/// These are keyed by the thread that the runner originated in.
718+
local_queues: RwLock<HashMap<ThreadId, Vec<Arc<LocalQueue>>>>,
691719

692720
/// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
693721
notified: AtomicBool,
@@ -704,7 +732,7 @@ impl State {
704732
fn new() -> State {
705733
State {
706734
queue: ConcurrentQueue::unbounded(),
707-
local_queues: RwLock::new(Vec::new()),
735+
local_queues: RwLock::new(HashMap::new()),
708736
notified: AtomicBool::new(true),
709737
sleepers: Mutex::new(Sleepers {
710738
count: 0,
@@ -739,36 +767,57 @@ struct Sleepers {
739767
/// IDs and wakers of sleeping unnotified tickers.
740768
///
741769
/// A sleeping ticker is notified when its waker is missing from this list.
742-
wakers: Vec<(usize, Waker)>,
770+
wakers: Vec<Sleeper>,
743771

744772
/// Reclaimed IDs.
745773
free_ids: Vec<usize>,
746774
}
747775

776+
/// A single sleeping ticker.
777+
struct Sleeper {
778+
/// ID of the sleeping ticker.
779+
id: usize,
780+
781+
/// Waker associated with this ticker.
782+
waker: Waker,
783+
784+
/// Specific runner ID for targeted wakeups.
785+
runner: Option<usize>,
786+
}
787+
748788
impl Sleepers {
749789
/// Inserts a new sleeping ticker.
750-
fn insert(&mut self, waker: &Waker) -> usize {
790+
fn insert(&mut self, waker: &Waker, runner: Option<usize>) -> usize {
751791
let id = match self.free_ids.pop() {
752792
Some(id) => id,
753793
None => self.count + 1,
754794
};
755795
self.count += 1;
756-
self.wakers.push((id, waker.clone()));
796+
self.wakers.push(Sleeper {
797+
id,
798+
waker: waker.clone(),
799+
runner,
800+
});
757801
id
758802
}
759803

760804
/// Re-inserts a sleeping ticker's waker if it was notified.
761805
///
762806
/// Returns `true` if the ticker was notified.
763-
fn update(&mut self, id: usize, waker: &Waker) -> bool {
807+
fn update(&mut self, id: usize, waker: &Waker, runner: Option<usize>) -> bool {
764808
for item in &mut self.wakers {
765-
if item.0 == id {
766-
item.1.clone_from(waker);
809+
if item.id == id {
810+
debug_assert_eq!(item.runner, runner);
811+
item.waker.clone_from(waker);
767812
return false;
768813
}
769814
}
770815

771-
self.wakers.push((id, waker.clone()));
816+
self.wakers.push(Sleeper {
817+
id,
818+
waker: waker.clone(),
819+
runner,
820+
});
772821
true
773822
}
774823

@@ -780,7 +829,7 @@ impl Sleepers {
780829
self.free_ids.push(id);
781830

782831
for i in (0..self.wakers.len()).rev() {
783-
if self.wakers[i].0 == id {
832+
if self.wakers[i].id == id {
784833
self.wakers.remove(i);
785834
return false;
786835
}
@@ -798,7 +847,20 @@ impl Sleepers {
798847
/// If a ticker was notified already or there are no tickers, `None` will be returned.
799848
fn notify(&mut self) -> Option<Waker> {
800849
if self.wakers.len() == self.count {
801-
self.wakers.pop().map(|item| item.1)
850+
self.wakers.pop().map(|item| item.waker)
851+
} else {
852+
None
853+
}
854+
}
855+
856+
/// Notify a specific waker that was previously sleeping.
857+
fn notify_runner(&mut self, runner: usize) -> Option<Waker> {
858+
if let Some(posn) = self
859+
.wakers
860+
.iter()
861+
.position(|sleeper| sleeper.runner == Some(runner))
862+
{
863+
Some(self.wakers.swap_remove(posn).waker)
802864
} else {
803865
None
804866
}
@@ -817,12 +879,28 @@ struct Ticker<'a> {
817879
/// 2a) Sleeping and unnotified.
818880
/// 2b) Sleeping and notified.
819881
sleeping: usize,
882+
883+
/// Unique runner ID, if this is a runner.
884+
runner: Option<usize>,
820885
}
821886

822887
impl Ticker<'_> {
823888
/// Creates a ticker.
824889
fn new(state: &State) -> Ticker<'_> {
825-
Ticker { state, sleeping: 0 }
890+
Ticker {
891+
state,
892+
sleeping: 0,
893+
runner: None,
894+
}
895+
}
896+
897+
/// Creates a ticker for a runner.
898+
fn for_runner(state: &State, runner: usize) -> Ticker<'_> {
899+
Ticker {
900+
state,
901+
sleeping: 0,
902+
runner: Some(runner),
903+
}
826904
}
827905

828906
/// Moves the ticker into sleeping and unnotified state.
@@ -834,12 +912,12 @@ impl Ticker<'_> {
834912
match self.sleeping {
835913
// Move to sleeping state.
836914
0 => {
837-
self.sleeping = sleepers.insert(waker);
915+
self.sleeping = sleepers.insert(waker, self.runner);
838916
}
839917

840918
// Already sleeping, check if notified.
841919
id => {
842-
if !sleepers.update(id, waker) {
920+
if !sleepers.update(id, waker, self.runner) {
843921
return false;
844922
}
845923
}
@@ -929,8 +1007,11 @@ struct Runner<'a> {
9291007
/// Inner ticker.
9301008
ticker: Ticker<'a>,
9311009

1010+
/// The ID of the thread we originated from.
1011+
origin_id: ThreadId,
1012+
9321013
/// The local queue.
933-
local: Arc<ConcurrentQueue<Runnable>>,
1014+
local: Arc<LocalQueue>,
9341015

9351016
/// Bumped every time a runnable task is found.
9361017
ticks: usize,
@@ -939,16 +1020,26 @@ struct Runner<'a> {
9391020
impl Runner<'_> {
9401021
/// Creates a runner and registers it in the executor state.
9411022
fn new(state: &State) -> Runner<'_> {
1023+
static ID_GENERATOR: AtomicUsize = AtomicUsize::new(0);
1024+
let runner_id = ID_GENERATOR.fetch_add(1, Ordering::SeqCst);
1025+
1026+
let origin_id = thread::current().id();
9421027
let runner = Runner {
9431028
state,
944-
ticker: Ticker::new(state),
945-
local: Arc::new(ConcurrentQueue::bounded(512)),
1029+
ticker: Ticker::for_runner(state, runner_id),
1030+
local: Arc::new(LocalQueue {
1031+
queue: ConcurrentQueue::bounded(512),
1032+
runner_id,
1033+
}),
9461034
ticks: 0,
1035+
origin_id,
9471036
};
9481037
state
9491038
.local_queues
9501039
.write()
9511040
.unwrap()
1041+
.entry(origin_id)
1042+
.or_default()
9521043
.push(runner.local.clone());
9531044
runner
9541045
}
@@ -959,13 +1050,13 @@ impl Runner<'_> {
9591050
.ticker
9601051
.runnable_with(|| {
9611052
// Try the local queue.
962-
if let Ok(r) = self.local.pop() {
1053+
if let Ok(r) = self.local.queue.pop() {
9631054
return Some(r);
9641055
}
9651056

9661057
// Try stealing from the global queue.
9671058
if let Ok(r) = self.state.queue.pop() {
968-
steal(&self.state.queue, &self.local);
1059+
steal(&self.state.queue, &self.local.queue);
9691060
return Some(r);
9701061
}
9711062

@@ -977,7 +1068,8 @@ impl Runner<'_> {
9771068
let start = rng.usize(..n);
9781069
let iter = local_queues
9791070
.iter()
980-
.chain(local_queues.iter())
1071+
.flat_map(|(_, list)| list)
1072+
.chain(local_queues.iter().flat_map(|(_, list)| list))
9811073
.skip(start)
9821074
.take(n);
9831075

@@ -986,8 +1078,8 @@ impl Runner<'_> {
9861078

9871079
// Try stealing from each local queue in the list.
9881080
for local in iter {
989-
steal(local, &self.local);
990-
if let Ok(r) = self.local.pop() {
1081+
steal(&local.queue, &self.local.queue);
1082+
if let Ok(r) = self.local.queue.pop() {
9911083
return Some(r);
9921084
}
9931085
}
@@ -1001,7 +1093,7 @@ impl Runner<'_> {
10011093

10021094
if self.ticks % 64 == 0 {
10031095
// Steal tasks from the global queue to ensure fair task scheduling.
1004-
steal(&self.state.queue, &self.local);
1096+
steal(&self.state.queue, &self.local.queue);
10051097
}
10061098

10071099
runnable
@@ -1015,15 +1107,26 @@ impl Drop for Runner<'_> {
10151107
.local_queues
10161108
.write()
10171109
.unwrap()
1110+
.get_mut(&self.origin_id)
1111+
.unwrap()
10181112
.retain(|local| !Arc::ptr_eq(local, &self.local));
10191113

10201114
// Re-schedule remaining tasks in the local queue.
1021-
while let Ok(r) = self.local.pop() {
1115+
while let Ok(r) = self.local.queue.pop() {
10221116
r.schedule();
10231117
}
10241118
}
10251119
}
10261120

1121+
/// Data associated with a local queue.
1122+
struct LocalQueue {
1123+
/// Concurrent queue of active tasks.
1124+
queue: ConcurrentQueue<Runnable>,
1125+
1126+
/// Unique ID associated with this runner.
1127+
runner_id: usize,
1128+
}
1129+
10271130
/// Steals some items from one queue into another.
10281131
fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
10291132
// Half of `src`'s length rounded up.
@@ -1082,14 +1185,18 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
10821185
}
10831186

10841187
/// Debug wrapper for the local runners.
1085-
struct LocalRunners<'a>(&'a RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>);
1188+
struct LocalRunners<'a>(&'a RwLock<HashMap<ThreadId, Vec<Arc<LocalQueue>>>>);
10861189

10871190
impl fmt::Debug for LocalRunners<'_> {
10881191
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
10891192
match self.0.try_read() {
10901193
Ok(lock) => f
10911194
.debug_list()
1092-
.entries(lock.iter().map(|queue| queue.len()))
1195+
.entries(
1196+
lock.iter()
1197+
.flat_map(|(_, list)| list)
1198+
.map(|queue| queue.queue.len()),
1199+
)
10931200
.finish(),
10941201
Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
10951202
Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),

0 commit comments

Comments
 (0)