     html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
 )]
 
+use std::collections::HashMap;
 use std::fmt;
 use std::marker::PhantomData;
 use std::panic::{RefUnwindSafe, UnwindSafe};
 use std::rc::Rc;
-use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
 use std::sync::{Arc, Mutex, RwLock, TryLockError};
 use std::task::{Poll, Waker};
+use std::thread::{self, ThreadId};
 
 use async_task::{Builder, Runnable};
 use concurrent_queue::ConcurrentQueue;
@@ -369,8 +371,32 @@ impl<'a> Executor<'a> {
     fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
         let state = self.state_as_arc();
 
-        // TODO: If possible, push into the current local queue and notify the ticker.
-        move |runnable| {
+        move |mut runnable| {
+            // If possible, push into the current local queue and notify the ticker.
+            if let Some(local_queue) = state
+                .local_queues
+                .read()
+                .unwrap()
+                .get(&thread::current().id())
+                .and_then(|list| list.first())
+            {
+                match local_queue.queue.push(runnable) {
+                    Ok(()) => {
+                        if let Some(waker) = state
+                            .sleepers
+                            .lock()
+                            .unwrap()
+                            .notify_runner(local_queue.runner_id)
+                        {
+                            waker.wake();
+                        }
+                        return;
+                    }
+
+                    Err(r) => runnable = r.into_inner(),
+                }
+            }
+
             state.queue.push(runnable).unwrap();
             state.notify();
         }
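
The `schedule` closure above now prefers a local queue owned by the scheduling thread and only falls back to the global queue (plus a `notify` wakeup) when no such queue exists or it is full. A minimal standalone sketch of that fast-path/slow-path split, using a plain `String` as a stand-in for `Runnable` and hypothetical statics instead of the executor's `State`:

use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use std::thread::{self, ThreadId};

// Stand-in for `Runnable`; these statics are illustrative, not the crate's API.
type Task = String;

static GLOBAL: OnceLock<Mutex<Vec<Task>>> = OnceLock::new();
static LOCALS: OnceLock<Mutex<HashMap<ThreadId, Vec<Task>>>> = OnceLock::new();

fn schedule(task: Task) {
    // Fast path: if the current thread owns a local queue, push there.
    let locals = LOCALS.get_or_init(Default::default);
    if let Some(queue) = locals.lock().unwrap().get_mut(&thread::current().id()) {
        queue.push(task);
        return;
    }

    // Slow path: fall back to the shared global queue.
    GLOBAL.get_or_init(Default::default).lock().unwrap().push(task);
}

fn main() {
    // This thread has no registered local queue, so the task lands in the global queue.
    schedule("tick".to_string());
    assert_eq!(GLOBAL.get().unwrap().lock().unwrap().len(), 1);
}

The real closure also has to recover the task when the bounded local queue rejects it, which is why `runnable` is rebound from `r.into_inner()` before falling through to the global queue.
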
@@ -687,7 +713,9 @@ struct State {
     queue: ConcurrentQueue<Runnable>,
 
     /// Local queues created by runners.
-    local_queues: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
+    ///
+    /// These are keyed by the thread that the runner originated in.
+    local_queues: RwLock<HashMap<ThreadId, Vec<Arc<LocalQueue>>>>,
 
     /// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
     notified: AtomicBool,
@@ -704,7 +732,7 @@ impl State {
     fn new() -> State {
         State {
             queue: ConcurrentQueue::unbounded(),
-            local_queues: RwLock::new(Vec::new()),
+            local_queues: RwLock::new(HashMap::new()),
             notified: AtomicBool::new(true),
             sleepers: Mutex::new(Sleepers {
                 count: 0,
@@ -739,36 +767,57 @@ struct Sleepers {
     /// IDs and wakers of sleeping unnotified tickers.
     ///
     /// A sleeping ticker is notified when its waker is missing from this list.
-    wakers: Vec<(usize, Waker)>,
+    wakers: Vec<Sleeper>,
 
     /// Reclaimed IDs.
     free_ids: Vec<usize>,
 }
 
+/// A single sleeping ticker.
+struct Sleeper {
+    /// ID of the sleeping ticker.
+    id: usize,
+
+    /// Waker associated with this ticker.
+    waker: Waker,
+
+    /// Specific runner ID for targeted wakeups.
+    runner: Option<usize>,
+}
+
 impl Sleepers {
     /// Inserts a new sleeping ticker.
-    fn insert(&mut self, waker: &Waker) -> usize {
+    fn insert(&mut self, waker: &Waker, runner: Option<usize>) -> usize {
         let id = match self.free_ids.pop() {
             Some(id) => id,
             None => self.count + 1,
         };
         self.count += 1;
-        self.wakers.push((id, waker.clone()));
+        self.wakers.push(Sleeper {
+            id,
+            waker: waker.clone(),
+            runner,
+        });
         id
     }
 
     /// Re-inserts a sleeping ticker's waker if it was notified.
     ///
     /// Returns `true` if the ticker was notified.
-    fn update(&mut self, id: usize, waker: &Waker) -> bool {
+    fn update(&mut self, id: usize, waker: &Waker, runner: Option<usize>) -> bool {
         for item in &mut self.wakers {
-            if item.0 == id {
-                item.1.clone_from(waker);
+            if item.id == id {
+                debug_assert_eq!(item.runner, runner);
+                item.waker.clone_from(waker);
                 return false;
             }
         }
 
-        self.wakers.push((id, waker.clone()));
+        self.wakers.push(Sleeper {
+            id,
+            waker: waker.clone(),
+            runner,
+        });
         true
     }
 
@@ -780,7 +829,7 @@ impl Sleepers {
         self.free_ids.push(id);
 
         for i in (0..self.wakers.len()).rev() {
-            if self.wakers[i].0 == id {
+            if self.wakers[i].id == id {
                 self.wakers.remove(i);
                 return false;
             }
@@ -798,7 +847,20 @@ impl Sleepers {
     /// If a ticker was notified already or there are no tickers, `None` will be returned.
     fn notify(&mut self) -> Option<Waker> {
         if self.wakers.len() == self.count {
-            self.wakers.pop().map(|item| item.1)
+            self.wakers.pop().map(|item| item.waker)
+        } else {
+            None
+        }
+    }
+
+    /// Notify a specific waker that was previously sleeping.
+    fn notify_runner(&mut self, runner: usize) -> Option<Waker> {
+        if let Some(posn) = self
+            .wakers
+            .iter()
+            .position(|sleeper| sleeper.runner == Some(runner))
+        {
+            Some(self.wakers.swap_remove(posn).waker)
         } else {
             None
         }
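
`notify_runner` complements the existing `notify` by waking only the sleeper registered by a particular runner, so a task pushed onto that runner's local queue is picked up by the thread that owns it. A rough illustration of the same find-and-`swap_remove` lookup on a plain `Vec`, with a simplified `Sleeper` that omits the `Waker` (names here are illustrative, not the crate's API):

// A pared-down sleeper: the real one also stores the task's `Waker`.
struct Sleeper {
    id: usize,
    runner: Option<usize>,
}

/// Remove and return the sleeper registered by `runner`, if any.
/// `swap_remove` is O(1); the order of the remaining sleepers does not matter.
fn notify_runner(wakers: &mut Vec<Sleeper>, runner: usize) -> Option<Sleeper> {
    let posn = wakers.iter().position(|s| s.runner == Some(runner))?;
    Some(wakers.swap_remove(posn))
}

fn main() {
    let mut wakers = vec![
        Sleeper { id: 1, runner: None },
        Sleeper { id: 2, runner: Some(7) },
    ];
    assert_eq!(notify_runner(&mut wakers, 7).map(|s| s.id), Some(2));
    assert!(notify_runner(&mut wakers, 7).is_none());
}
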
@@ -817,12 +879,28 @@ struct Ticker<'a> {
     /// 2a) Sleeping and unnotified.
     /// 2b) Sleeping and notified.
     sleeping: usize,
+
+    /// Unique runner ID, if this is a runner.
+    runner: Option<usize>,
 }
 
 impl Ticker<'_> {
     /// Creates a ticker.
     fn new(state: &State) -> Ticker<'_> {
-        Ticker { state, sleeping: 0 }
+        Ticker {
+            state,
+            sleeping: 0,
+            runner: None,
+        }
+    }
+
+    /// Creates a ticker for a runner.
+    fn for_runner(state: &State, runner: usize) -> Ticker<'_> {
+        Ticker {
+            state,
+            sleeping: 0,
+            runner: Some(runner),
+        }
     }
 
     /// Moves the ticker into sleeping and unnotified state.
@@ -834,12 +912,12 @@ impl Ticker<'_> {
         match self.sleeping {
             // Move to sleeping state.
             0 => {
-                self.sleeping = sleepers.insert(waker);
+                self.sleeping = sleepers.insert(waker, self.runner);
             }
 
             // Already sleeping, check if notified.
             id => {
-                if !sleepers.update(id, waker) {
+                if !sleepers.update(id, waker, self.runner) {
                     return false;
                 }
             }
@@ -929,8 +1007,11 @@ struct Runner<'a> {
     /// Inner ticker.
     ticker: Ticker<'a>,
 
+    /// The ID of the thread we originated from.
+    origin_id: ThreadId,
+
     /// The local queue.
-    local: Arc<ConcurrentQueue<Runnable>>,
+    local: Arc<LocalQueue>,
 
     /// Bumped every time a runnable task is found.
     ticks: usize,
@@ -939,16 +1020,26 @@ struct Runner<'a> {
 impl Runner<'_> {
     /// Creates a runner and registers it in the executor state.
     fn new(state: &State) -> Runner<'_> {
+        static ID_GENERATOR: AtomicUsize = AtomicUsize::new(0);
+        let runner_id = ID_GENERATOR.fetch_add(1, Ordering::SeqCst);
+
+        let origin_id = thread::current().id();
         let runner = Runner {
             state,
-            ticker: Ticker::new(state),
-            local: Arc::new(ConcurrentQueue::bounded(512)),
+            ticker: Ticker::for_runner(state, runner_id),
+            local: Arc::new(LocalQueue {
+                queue: ConcurrentQueue::bounded(512),
+                runner_id,
+            }),
             ticks: 0,
+            origin_id,
         };
         state
             .local_queues
             .write()
             .unwrap()
+            .entry(origin_id)
+            .or_default()
             .push(runner.local.clone());
         runner
     }
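
`Runner::new` above draws its `runner_id` from a function-local `static AtomicUsize`, so every runner gets a distinct ID without any extra locking. A tiny sketch of that counter pattern in isolation (the wrapper function is hypothetical; only the constant name mirrors the diff):

use std::sync::atomic::{AtomicUsize, Ordering};

/// Hand out process-wide unique IDs.
fn next_runner_id() -> usize {
    static ID_GENERATOR: AtomicUsize = AtomicUsize::new(0);
    // `fetch_add` returns the previous value, so IDs start at 0 and never repeat
    // (barring wrap-around after `usize::MAX` allocations).
    ID_GENERATOR.fetch_add(1, Ordering::SeqCst)
}

fn main() {
    assert_eq!(next_runner_id(), 0);
    assert_eq!(next_runner_id(), 1);
}
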
@@ -959,13 +1050,13 @@ impl Runner<'_> {
             .ticker
             .runnable_with(|| {
                 // Try the local queue.
-                if let Ok(r) = self.local.pop() {
+                if let Ok(r) = self.local.queue.pop() {
                     return Some(r);
                 }
 
                 // Try stealing from the global queue.
                 if let Ok(r) = self.state.queue.pop() {
-                    steal(&self.state.queue, &self.local);
+                    steal(&self.state.queue, &self.local.queue);
                     return Some(r);
                 }
 
@@ -977,7 +1068,8 @@ impl Runner<'_> {
                 let start = rng.usize(..n);
                 let iter = local_queues
                     .iter()
-                    .chain(local_queues.iter())
+                    .flat_map(|(_, list)| list)
+                    .chain(local_queues.iter().flat_map(|(_, list)| list))
                     .skip(start)
                     .take(n);
 
@@ -986,8 +1078,8 @@ impl Runner<'_> {
 
                 // Try stealing from each local queue in the list.
                 for local in iter {
-                    steal(&local.queue, &self.local.queue);
-                    if let Ok(r) = self.local.queue.pop() {
+                    steal(&local.queue, &self.local.queue);
+                    if let Ok(r) = self.local.queue.pop() {
                         return Some(r);
                     }
                 }
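
With `local_queues` now keyed by `ThreadId`, the stealing code above flattens the per-thread lists with `flat_map` and still begins at a random offset: chaining the iterator with itself and applying `skip(start).take(n)` visits the queues once each, wrapping around the end. A small sketch of that wrap-around traversal, with string labels standing in for queues and a fixed offset in place of the executor's RNG:

fn main() {
    // Pretend these labels are the flattened local queues.
    let queues = ["a", "b", "c", "d"];
    let n = queues.len();
    let start = 2; // the executor picks this offset with a fast RNG

    // Chaining the iterator with itself and taking `n` items after `skip(start)`
    // visits every queue exactly once, starting at `start` and wrapping around.
    let order: Vec<&&str> = queues
        .iter()
        .chain(queues.iter())
        .skip(start)
        .take(n)
        .collect();
    assert_eq!(order, [&"c", &"d", &"a", &"b"]);
}
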
@@ -1001,7 +1093,7 @@ impl Runner<'_> {
 
         if self.ticks % 64 == 0 {
             // Steal tasks from the global queue to ensure fair task scheduling.
-            steal(&self.state.queue, &self.local);
+            steal(&self.state.queue, &self.local.queue);
         }
 
         runnable
@@ -1015,15 +1107,26 @@ impl Drop for Runner<'_> {
             .local_queues
             .write()
             .unwrap()
+            .get_mut(&self.origin_id)
+            .unwrap()
             .retain(|local| !Arc::ptr_eq(local, &self.local));
 
         // Re-schedule remaining tasks in the local queue.
-        while let Ok(r) = self.local.pop() {
+        while let Ok(r) = self.local.queue.pop() {
             r.schedule();
         }
     }
 }
 
+/// Data associated with a local queue.
+struct LocalQueue {
+    /// Concurrent queue of active tasks.
+    queue: ConcurrentQueue<Runnable>,
+
+    /// Unique ID associated with this runner.
+    runner_id: usize,
+}
+
 /// Steals some items from one queue into another.
 fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
     // Half of `src`'s length rounded up.
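
In `Drop`, the runner now has to locate its own thread's list via `get_mut(&self.origin_id)` before `retain`ing its queue away, and `Arc::ptr_eq` ensures only that exact allocation is removed. A small hedged sketch of the register/unregister steps against a `HashMap<ThreadId, Vec<Arc<_>>>`, with `u32` standing in for `LocalQueue` and hypothetical variable names:

use std::collections::HashMap;
use std::sync::Arc;
use std::thread::{self, ThreadId};

fn main() {
    // The registry mirrors the shape of `State::local_queues`.
    let origin_id: ThreadId = thread::current().id();
    let mine: Arc<u32> = Arc::new(0);
    let mut registry: HashMap<ThreadId, Vec<Arc<u32>>> = HashMap::new();

    // Registration, as in `Runner::new`: create the per-thread list on demand.
    registry.entry(origin_id).or_default().push(mine.clone());
    registry.entry(origin_id).or_default().push(Arc::new(0));

    // Unregistration, as in `Drop`: find this thread's list, then drop exactly
    // our own allocation. `Arc::ptr_eq` compares pointers, so the other queue
    // holding an equal value survives.
    registry
        .get_mut(&origin_id)
        .unwrap()
        .retain(|q| !Arc::ptr_eq(q, &mine));
    assert_eq!(registry[&origin_id].len(), 1);
}
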
@@ -1082,14 +1185,18 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
     }
 
     /// Debug wrapper for the local runners.
-    struct LocalRunners<'a>(&'a RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>);
+    struct LocalRunners<'a>(&'a RwLock<HashMap<ThreadId, Vec<Arc<LocalQueue>>>>);
 
     impl fmt::Debug for LocalRunners<'_> {
         fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
             match self.0.try_read() {
                 Ok(lock) => f
                     .debug_list()
-                    .entries(lock.iter().map(|queue| queue.len()))
+                    .entries(
+                        lock.iter()
+                            .flat_map(|(_, list)| list)
+                            .map(|queue| queue.queue.len()),
+                    )
                     .finish(),
                 Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
                 Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),