@@ -8,6 +8,8 @@ use crate::networking::{AsyncNetworking, LocalAsyncNetworking};
88use crate :: replicated:: { RepSetup , ReplicatedPlacement } ;
99use crate :: storage:: { AsyncStorage , LocalAsyncStorage } ;
1010use futures:: future:: { Map , Shared } ;
11+ use futures:: stream:: FuturesUnordered ;
12+ use futures:: StreamExt ;
1113use std:: collections:: { HashMap , HashSet } ;
1214use std:: convert:: TryFrom ;
1315use std:: sync:: { Arc , RwLock } ;
@@ -51,7 +53,7 @@ pub(crate) fn map_receive_error<T>(_: T) -> Error {
5153}
5254
5355pub struct AsyncSessionHandle {
54- pub tasks : Arc < RwLock < Vec < crate :: execution :: AsyncTask > > > ,
56+ pub tasks : Arc < RwLock < FuturesUnordered < AsyncTask > > > ,
5557}
5658
5759impl AsyncSessionHandle {
@@ -63,18 +65,10 @@ impl AsyncSessionHandle {
6365
6466 pub async fn join_on_first_error ( self ) -> anyhow:: Result < ( ) > {
6567 use crate :: error:: Error :: { OperandUnavailable , ResultUnused } ;
66- // use futures::StreamExt;
6768
6869 let mut tasks_guard = self . tasks . write ( ) . unwrap ( ) ;
69- // TODO (lvorona): should really find a way to use FuturesUnordered here
70- // let mut tasks = (*tasks_guard)
71- // .into_iter()
72- // .collect::<futures::stream::FuturesUnordered<_>>();
73-
74- let mut tasks = tasks_guard. iter_mut ( ) ;
75-
76- while let Some ( x) = tasks. next ( ) {
77- let x = x. await ;
70+ let mut maybe_error = None ;
71+ while let Some ( x) = tasks_guard. next ( ) . await {
7872 match x {
7973 Ok ( Ok ( _) ) => {
8074 continue ;
@@ -87,26 +81,30 @@ impl AsyncSessionHandle {
8781 OperandUnavailable => continue ,
8882 ResultUnused => continue ,
8983 _ => {
90- for task in tasks {
91- task. abort ( ) ;
92- }
93- return Err ( anyhow:: Error :: from ( e) ) ;
84+ maybe_error = Some ( Err ( anyhow:: Error :: from ( e) ) ) ;
85+ break ;
9486 }
9587 }
9688 }
9789 Err ( e) => {
9890 if e. is_cancelled ( ) {
9991 continue ;
10092 } else if e. is_panic ( ) {
101- for task in tasks {
102- task. abort ( ) ;
103- }
104- return Err ( anyhow:: Error :: from ( e) ) ;
93+ maybe_error = Some ( Err ( anyhow:: Error :: from ( e) ) ) ;
94+ break ;
10595 }
10696 }
10797 }
10898 }
109- Ok ( ( ) )
99+
100+ if let Some ( e) = maybe_error {
101+ for task in tasks_guard. iter_mut ( ) {
102+ task. abort ( ) ;
103+ }
104+ e
105+ } else {
106+ Ok ( ( ) )
107+ }
110108 }
111109}
112110
@@ -118,7 +116,7 @@ pub struct AsyncSession {
118116 pub role_assignments : Arc < HashMap < Role , Identity > > ,
119117 pub networking : AsyncNetworkingImpl ,
120118 pub storage : AsyncStorageImpl ,
121- pub tasks : Arc < RwLock < Vec < crate :: execution:: AsyncTask > > > ,
119+ pub tasks : Arc < RwLock < FuturesUnordered < crate :: execution:: AsyncTask > > > ,
122120}
123121
124122impl AsyncSession {
@@ -178,7 +176,7 @@ impl AsyncSession {
178176 map_send_result ( sender. send ( value) ) ?;
179177 Ok ( ( ) )
180178 } ) ;
181- let mut tasks = self . tasks . write ( ) . unwrap ( ) ;
179+ let tasks = self . tasks . read ( ) . unwrap ( ) ;
182180 tasks. push ( task) ;
183181
184182 Ok ( receiver)
@@ -216,7 +214,7 @@ impl AsyncSession {
216214 map_send_result ( sender. send ( result. into ( ) ) ) ?;
217215 Ok ( ( ) )
218216 } ) ;
219- let mut tasks = self . tasks . write ( ) . unwrap ( ) ;
217+ let tasks = self . tasks . read ( ) . unwrap ( ) ;
220218 tasks. push ( task) ;
221219
222220 Ok ( receiver)
@@ -244,7 +242,7 @@ impl AsyncSession {
244242 map_send_result ( sender. send ( value) ) ?;
245243 Ok ( ( ) )
246244 } ) ;
247- let mut tasks = self . tasks . write ( ) . unwrap ( ) ;
245+ let tasks = self . tasks . read ( ) . unwrap ( ) ;
248246 tasks. push ( task) ;
249247
250248 Ok ( receiver)
@@ -279,7 +277,7 @@ impl AsyncSession {
279277 map_send_result ( sender. send ( result. into ( ) ) ) ?;
280278 Ok ( ( ) )
281279 } ) ;
282- let mut tasks = self . tasks . write ( ) . unwrap ( ) ;
280+ let tasks = self . tasks . read ( ) . unwrap ( ) ;
283281 tasks. push ( task) ;
284282
285283 Ok ( receiver)
@@ -638,8 +636,11 @@ impl AsyncTestRuntime {
638636 session_handles. push ( AsyncSessionHandle :: for_session ( & moose_session) )
639637 }
640638
641- for handle in session_handles {
642- let result = rt. block_on ( handle. join_on_first_error ( ) ) ;
639+ let mut futures: FuturesUnordered < _ > = session_handles
640+ . into_iter ( )
641+ . map ( |h| h. join_on_first_error ( ) )
642+ . collect ( ) ;
643+ while let Some ( result) = rt. block_on ( futures. next ( ) ) {
643644 if let Err ( e) = result {
644645 return Err ( Error :: TestRuntime ( e. to_string ( ) ) ) ;
645646 }
0 commit comments