@@ -144,7 +144,7 @@ const EMPTY_PLAN_ERROR: RequestError = RequestError::EmptyPlan;
144144pub ( crate ) async fn execute < QueryFut , ResT > (
145145 policy : & dyn SpeculativeExecutionPolicy ,
146146 context : & Context ,
147- query_runner_generator : impl Fn ( bool ) -> QueryFut ,
147+ mut query_runner_generator : impl FnMut ( bool ) -> QueryFut ,
148148) -> Result < ( ResT , Coordinator ) , RequestError >
149149where
150150 QueryFut : Future < Output = Option < Result < ( ResT , Coordinator ) , RequestError > > > ,
@@ -195,3 +195,131 @@ where
195195 }
196196 }
197197}
198+
199+ #[ cfg( test) ]
200+ mod tests {
201+ // Important to start tests with paused clock. If starting unpaused, and calling `tokio::time::pause()`, then
202+ // things like `sleep` will advance the timer not fully accurately (I have no idea why), causing
203+ // few ms added clock advancement at the end of the test.
204+ // Starting paused is done with `#[tokio::test(flavor = "current_thread", start_paused = true)]`.
205+ // Pausing can only be done with current_thread executor.
206+
207+ #[ cfg( feature = "metrics" ) ]
208+ use std:: sync:: Arc ;
209+ use std:: sync:: LazyLock ;
210+ use std:: time:: Duration ;
211+
212+ use assert_matches:: assert_matches;
213+
214+ use crate :: errors:: { RequestAttemptError , RequestError } ;
215+ #[ cfg( feature = "metrics" ) ]
216+ use crate :: observability:: metrics:: Metrics ;
217+ use crate :: policies:: speculative_execution:: { Context , SimpleSpeculativeExecutionPolicy } ;
218+ use crate :: response:: Coordinator ;
219+
220+ static EMPTY_CONTEXT : LazyLock < Context > = LazyLock :: new ( || Context {
221+ #[ cfg( feature = "metrics" ) ]
222+ metrics : Arc :: new ( Metrics :: new ( ) ) ,
223+ } ) ;
224+
225+ static IGNORABLE_ERROR : Option < Result < ( ( ) , Coordinator ) , RequestError > > = Some ( Err (
226+ RequestError :: LastAttemptError ( RequestAttemptError :: UnableToAllocStreamId ) ,
227+ ) ) ;
228+
229+ #[ tokio:: test( flavor = "current_thread" , start_paused = true ) ]
230+ async fn test_exhausted_plan_with_running_fibers ( ) {
231+ let policy = SimpleSpeculativeExecutionPolicy {
232+ max_retry_count : 5 ,
233+ retry_interval : Duration :: from_secs ( 1 ) ,
234+ } ;
235+
236+ let generator = {
237+ // Index of the fiber, 0 for first execution.
238+ let mut counter = 0 ;
239+ move |_first : bool | {
240+ let future = {
241+ let fiber_idx = counter;
242+ async move {
243+ if fiber_idx < 4 {
244+ tokio:: time:: sleep ( Duration :: from_secs ( 5 ) ) . await ;
245+ IGNORABLE_ERROR . clone ( )
246+ } else if fiber_idx == 4 {
247+ None
248+ } else {
249+ panic ! ( "Too many speculative executions - expected 4" ) ;
250+ }
251+ }
252+ } ;
253+ counter += 1 ;
254+ future
255+ }
256+ } ;
257+
258+ let now = tokio:: time:: Instant :: now ( ) ;
259+ let res = super :: execute ( & policy, & EMPTY_CONTEXT , generator) . await ;
260+ assert_matches ! (
261+ res,
262+ Err ( RequestError :: LastAttemptError (
263+ RequestAttemptError :: UnableToAllocStreamId
264+ ) )
265+ ) ;
266+ // t - now
267+ // First execution is started at t
268+ // Speculative executions - at t+1, t+2, t+3, t+4
269+ // The one at t+4 will return first, with None, preventing starting new one at t+5.
270+ // Then execute should wait on spawned fibers. Last one will be the one spawned at t+3, finishing at t+8.
271+ assert_eq ! (
272+ tokio:: time:: Instant :: now( ) ,
273+ now. checked_add( Duration :: from_secs( 8 ) ) . unwrap( )
274+ )
275+ }
276+
277+ #[ tokio:: test( flavor = "current_thread" , start_paused = true ) ]
278+ async fn test_exhausted_plan_last_running_fiber ( ) {
279+ let policy = SimpleSpeculativeExecutionPolicy {
280+ max_retry_count : 5 ,
281+ // Each attempt will finish before next starts
282+ retry_interval : Duration :: from_secs ( 6 ) ,
283+ } ;
284+
285+ let generator = {
286+ // Index of the fiber, 0 for first execution.
287+ let mut counter = 0 ;
288+ move |_first : bool | {
289+ let future = {
290+ let fiber_idx = counter;
291+ async move {
292+ if fiber_idx < 4 {
293+ tokio:: time:: sleep ( Duration :: from_secs ( 5 ) ) . await ;
294+ IGNORABLE_ERROR . clone ( )
295+ } else if fiber_idx == 4 {
296+ None
297+ } else {
298+ panic ! ( "Too many speculative executions - expected 4" ) ;
299+ }
300+ }
301+ } ;
302+ counter += 1 ;
303+ future
304+ }
305+ } ;
306+
307+ let now = tokio:: time:: Instant :: now ( ) ;
308+ let res = super :: execute ( & policy, & EMPTY_CONTEXT , generator) . await ;
309+ assert_matches ! (
310+ res,
311+ Err ( RequestError :: LastAttemptError (
312+ RequestAttemptError :: UnableToAllocStreamId
313+ ) )
314+ ) ;
315+ // t - now
316+ // First execution is started at t
317+ // Speculative executions - at t+6, t+12, t+18, t+24
318+ // Each execution finishes before next starts. The one at t+24 finishes instantly with
319+ // None, so the next one should not be started.
320+ assert_eq ! (
321+ tokio:: time:: Instant :: now( ) ,
322+ now. checked_add( Duration :: from_secs( 24 ) ) . unwrap( )
323+ )
324+ }
325+ }
0 commit comments