11// Copyright © Aptos Foundation
22// SPDX-License-Identifier: Apache-2.0
33
4- use crate :: metrics:: TIMER ;
4+ use crate :: metrics:: { GAUGE , TIMER } ;
55use aptos_infallible:: Mutex ;
6- use aptos_metrics_core:: TimerHelper ;
7- use std:: sync:: mpsc:: { channel, Receiver , Sender } ;
6+ use aptos_metrics_core:: { IntGaugeHelper , TimerHelper } ;
7+ use std:: sync:: {
8+ mpsc:: { channel, Receiver , Sender } ,
9+ Arc , Condvar ,
10+ } ;
811use threadpool:: ThreadPool ;
912
1013/// A helper to send things to a thread pool for asynchronous dropping.
@@ -15,26 +18,17 @@ use threadpool::ThreadPool;
1518/// to another thing being waiting for a slot to be available.
1619pub struct AsyncConcurrentDropper {
1720 name : & ' static str ,
18- token_tx : Sender < ( ) > ,
19- token_rx : Mutex < Receiver < ( ) > > ,
21+ num_tasks_tracker : Arc < NumTasksTracker > ,
2022 /// use dedicated threadpool to minimize the possibility of dead lock
2123 thread_pool : ThreadPool ,
2224}
2325
2426impl AsyncConcurrentDropper {
25- pub fn new ( name : & ' static str , max_async_drops : usize , num_threads : usize ) -> Self {
26- let ( token_tx, token_rx) = channel ( ) ;
27- for _ in 0 ..max_async_drops {
28- token_tx
29- . send ( ( ) )
30- . expect ( "DropHelper: Failed to buffer initial tokens." ) ;
31- }
32- let thread_pool = ThreadPool :: new ( num_threads) ;
27+ pub fn new ( name : & ' static str , max_tasks : usize , num_threads : usize ) -> Self {
3328 Self {
3429 name,
35- token_tx,
36- token_rx : Mutex :: new ( token_rx) ,
37- thread_pool,
30+ num_tasks_tracker : Arc :: new ( NumTasksTracker :: new ( max_tasks) ) ,
31+ thread_pool : ThreadPool :: new ( num_threads) ,
3832 }
3933 }
4034
@@ -48,13 +42,19 @@ impl AsyncConcurrentDropper {
4842 rx
4943 }
5044
45+ pub fn wait_for_backlog_drop ( & self , no_more_than : usize ) {
46+ let _timer = TIMER . timer_with ( & [ self . name , "wait_for_backlog_drop" ] ) ;
47+ self . num_tasks_tracker . wait_for_backlog_drop ( no_more_than) ;
48+ }
49+
5150 fn schedule_drop_impl < V : Send + ' static > ( & self , v : V , notif_sender_opt : Option < Sender < ( ) > > ) {
5251 let _timer = TIMER . timer_with ( & [ self . name , "enqueue_drop" ] ) ;
52+ let num_tasks = self . num_tasks_tracker . inc ( ) ;
53+ GAUGE . set_with ( & [ self . name , "num_tasks" ] , num_tasks as i64 ) ;
5354
54- self . token_rx . lock ( ) . recv ( ) . unwrap ( ) ;
55-
56- let token_tx = self . token_tx . clone ( ) ;
5755 let name = self . name ;
56+ let num_tasks_tracker = self . num_tasks_tracker . clone ( ) ;
57+
5858 self . thread_pool . execute ( move || {
5959 let _timer = TIMER . timer_with ( & [ name, "real_drop" ] ) ;
6060
@@ -64,15 +64,54 @@ impl AsyncConcurrentDropper {
6464 sender. send ( ( ) ) . ok ( ) ;
6565 }
6666
67- token_tx . send ( ( ) ) . ok ( ) ;
67+ num_tasks_tracker . dec ( ) ;
6868 } )
6969 }
7070}
7171
72+ struct NumTasksTracker {
73+ lock : Mutex < usize > ,
74+ cvar : Condvar ,
75+ max_tasks : usize ,
76+ }
77+
78+ impl NumTasksTracker {
79+ fn new ( max_tasks : usize ) -> Self {
80+ Self {
81+ lock : Mutex :: new ( 0 ) ,
82+ cvar : Condvar :: new ( ) ,
83+ max_tasks,
84+ }
85+ }
86+
87+ fn inc ( & self ) -> usize {
88+ let mut num_tasks = self . lock . lock ( ) ;
89+ while * num_tasks >= self . max_tasks {
90+ num_tasks = self . cvar . wait ( num_tasks) . expect ( "lock poisoned." ) ;
91+ }
92+ * num_tasks += 1 ;
93+ * num_tasks
94+ }
95+
96+ fn dec ( & self ) {
97+ let mut num_tasks = self . lock . lock ( ) ;
98+ * num_tasks -= 1 ;
99+ self . cvar . notify_all ( ) ;
100+ }
101+
102+ fn wait_for_backlog_drop ( & self , no_more_than : usize ) {
103+ let mut num_tasks = self . lock . lock ( ) ;
104+ while * num_tasks > no_more_than {
105+ num_tasks = self . cvar . wait ( num_tasks) . expect ( "lock poisoned." ) ;
106+ }
107+ }
108+ }
109+
72110#[ cfg( test) ]
73111mod tests {
74112 use crate :: AsyncConcurrentDropper ;
75- use std:: { thread:: sleep, time:: Duration } ;
113+ use std:: { sync:: Arc , thread:: sleep, time:: Duration } ;
114+ use threadpool:: ThreadPool ;
76115
77116 struct SlowDropper ;
78117
@@ -117,4 +156,43 @@ mod tests {
117156 s. schedule_drop ( SlowDropper ) ;
118157 assert ! ( now. elapsed( ) < Duration :: from_millis( 400 ) ) ;
119158 }
159+
160+ fn async_wait (
161+ thread_pool : & ThreadPool ,
162+ dropper : & Arc < AsyncConcurrentDropper > ,
163+ no_more_than : usize ,
164+ ) {
165+ let dropper = Arc :: clone ( dropper) ;
166+ thread_pool. execute ( move || dropper. wait_for_backlog_drop ( no_more_than) ) ;
167+ }
168+
169+ #[ test]
170+ fn test_wait_for_backlog_drop ( ) {
171+ let s = Arc :: new ( AsyncConcurrentDropper :: new ( "test" , 8 , 4 ) ) ;
172+ let t = ThreadPool :: new ( 4 ) ;
173+ let now = std:: time:: Instant :: now ( ) ;
174+ for _ in 0 ..8 {
175+ s. schedule_drop ( SlowDropper ) ;
176+ }
177+ assert ! ( now. elapsed( ) < Duration :: from_millis( 200 ) ) ;
178+ s. wait_for_backlog_drop ( 8 ) ;
179+ assert ! ( now. elapsed( ) < Duration :: from_millis( 200 ) ) ;
180+ async_wait ( & t, & s, 8 ) ;
181+ async_wait ( & t, & s, 8 ) ;
182+ async_wait ( & t, & s, 7 ) ;
183+ async_wait ( & t, & s, 4 ) ;
184+ t. join ( ) ;
185+ assert ! ( now. elapsed( ) > Duration :: from_millis( 200 ) ) ;
186+ assert ! ( now. elapsed( ) < Duration :: from_millis( 400 ) ) ;
187+ s. wait_for_backlog_drop ( 4 ) ;
188+ assert ! ( now. elapsed( ) < Duration :: from_millis( 400 ) ) ;
189+ async_wait ( & t, & s, 3 ) ;
190+ async_wait ( & t, & s, 2 ) ;
191+ async_wait ( & t, & s, 1 ) ;
192+ t. join ( ) ;
193+ assert ! ( now. elapsed( ) > Duration :: from_millis( 400 ) ) ;
194+ assert ! ( now. elapsed( ) < Duration :: from_millis( 600 ) ) ;
195+ s. wait_for_backlog_drop ( 0 ) ;
196+ assert ! ( now. elapsed( ) < Duration :: from_millis( 600 ) ) ;
197+ }
120198}
0 commit comments