@@ -21,6 +21,7 @@ pub struct Context {
2121
2222/// The policy that decides if the driver will send speculative queries to the
2323/// next targets when the current target takes too long to respond.
24+ // TODO(2.0): Consider renaming the methods to get rid of "retry" naming.
2425pub trait SpeculativeExecutionPolicy : std:: fmt:: Debug + Send + Sync {
2526 /// The maximum number of speculative executions that will be triggered
2627 /// for a given request (does not include the initial request)
@@ -144,7 +145,7 @@ const EMPTY_PLAN_ERROR: RequestError = RequestError::EmptyPlan;
144145pub ( crate ) async fn execute < QueryFut , ResT > (
145146 policy : & dyn SpeculativeExecutionPolicy ,
146147 context : & Context ,
147- query_runner_generator : impl Fn ( bool ) -> QueryFut ,
148+ mut query_runner_generator : impl FnMut ( bool ) -> QueryFut ,
148149) -> Result < ( ResT , Coordinator ) , RequestError >
149150where
150151 QueryFut : Future < Output = Option < Result < ( ResT , Coordinator ) , RequestError > > > ,
@@ -180,6 +181,11 @@ where
180181 } else {
181182 last_error = Some ( r)
182183 }
184+ } else {
185+ // The only case where None is returned is when execution plan was exhausted.
186+ // If so, there is no reason to start any more fibers.
187+ // We can't always return - there may still be fibers running.
188+ retries_remaining = 0 ;
183189 }
184190 if async_tasks. is_empty( ) && retries_remaining == 0 {
185191 return last_error. unwrap_or( {
@@ -190,3 +196,166 @@ where
190196 }
191197 }
192198}
199+
#[cfg(test)]
mod tests {
    // All tests in this module must start with a paused clock. If a test starts unpaused and only
    // then calls `tokio::time::pause()`, things like `sleep` do not advance the timer fully
    // accurately (the reason is unknown), which leaves the clock a few ms further ahead than
    // expected at the end of the test.
    // Starting paused is done with `#[tokio::test(flavor = "current_thread", start_paused = true)]`.
    // Pausing is only possible with the current_thread executor.

    #[cfg(feature = "metrics")]
    use std::sync::Arc;
    use std::sync::LazyLock;
    use std::time::Duration;

    use assert_matches::assert_matches;

    use crate::errors::{RequestAttemptError, RequestError};
    #[cfg(feature = "metrics")]
    use crate::observability::metrics::Metrics;
    use crate::policies::speculative_execution::{Context, SimpleSpeculativeExecutionPolicy};
    use crate::response::Coordinator;

    /// Value produced by a single fiber in these tests.
    ///
    /// `None` signals that the execution plan was exhausted; `Some` carries the request result.
    type FiberOutput = Option<Result<((), Coordinator), RequestError>>;

    /// Context shared by all tests; it only carries a fresh `Metrics` instance
    /// (and nothing at all when the `metrics` feature is disabled).
    static EMPTY_CONTEXT: LazyLock<Context> = LazyLock::new(|| Context {
        #[cfg(feature = "metrics")]
        metrics: Arc::new(Metrics::new()),
    });

    /// Failed fiber result used by the tests below. They expect `execute` not to return
    /// immediately when a fiber yields it, but to hand it back as the final error once
    /// no more fibers are running and no more may be started.
    static IGNORABLE_ERROR: FiberOutput = Some(Err(RequestError::LastAttemptError(
        RequestAttemptError::UnableToAllocStreamId,
    )));

    /// Body of the fiber number `fiber_idx` (0 is the initial execution) for the tests
    /// that simulate an execution plan running out.
    ///
    /// * Fibers 0..=3 sleep for 5 seconds and then fail with [`IGNORABLE_ERROR`].
    /// * Fiber 4 immediately reports an exhausted plan by returning `None`.
    /// * Any further fiber should never be started, so it panics.
    async fn fiber_until_plan_exhausted(fiber_idx: u32) -> FiberOutput {
        match fiber_idx {
            0..=3 => {
                tokio::time::sleep(Duration::from_secs(5)).await;
                IGNORABLE_ERROR.clone()
            }
            4 => None,
            _ => panic!("Too many speculative executions - expected 4"),
        }
    }

    /// The plan gets exhausted while earlier fibers are still running: no new fiber may be
    /// started afterwards, but `execute` has to wait for the ones already in flight.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_exhausted_plan_with_running_fibers() {
        let policy = SimpleSpeculativeExecutionPolicy {
            max_retry_count: 5,
            retry_interval: Duration::from_secs(1),
        };

        // Index of the next fiber to be created, 0 for the first execution.
        let mut next_fiber_idx = 0;
        let generator = move |_first: bool| {
            let fiber_idx = next_fiber_idx;
            next_fiber_idx += 1;
            fiber_until_plan_exhausted(fiber_idx)
        };

        let start = tokio::time::Instant::now();
        let res = super::execute(&policy, &EMPTY_CONTEXT, generator).await;
        assert_matches!(
            res,
            Err(RequestError::LastAttemptError(
                RequestAttemptError::UnableToAllocStreamId
            ))
        );
        // t - start
        // First execution is started at t.
        // Speculative executions - at t+1, t+2, t+3, t+4.
        // The one at t+4 will return first, with None, preventing starting a new one at t+5.
        // Then execute should wait on the spawned fibers. The last one to finish is the one
        // spawned at t+3, finishing at t+8.
        assert_eq!(start.elapsed(), Duration::from_secs(8));
    }

    /// The plan gets exhausted by the only fiber that is running at that moment:
    /// `execute` should return right away instead of waiting for another retry interval.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_exhausted_plan_last_running_fiber() {
        let policy = SimpleSpeculativeExecutionPolicy {
            max_retry_count: 5,
            // Longer than the 5 second sleep of a fiber, so each attempt finishes
            // before the next one starts.
            retry_interval: Duration::from_secs(6),
        };

        // Index of the next fiber to be created, 0 for the first execution.
        let mut next_fiber_idx = 0;
        let generator = move |_first: bool| {
            let fiber_idx = next_fiber_idx;
            next_fiber_idx += 1;
            fiber_until_plan_exhausted(fiber_idx)
        };

        let start = tokio::time::Instant::now();
        let res = super::execute(&policy, &EMPTY_CONTEXT, generator).await;
        assert_matches!(
            res,
            Err(RequestError::LastAttemptError(
                RequestAttemptError::UnableToAllocStreamId
            ))
        );
        // t - start
        // First execution is started at t.
        // Speculative executions - at t+6, t+12, t+18, t+24.
        // Each execution finishes before the next one starts. The one at t+24 finishes
        // instantly with None, so the next one should not be started.
        assert_eq!(start.elapsed(), Duration::from_secs(24));
    }

    // Regression test for https://github.com/scylladb/scylla-rust-driver/issues/1085
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_se_panic_on_ignorable_errors() {
        let policy = SimpleSpeculativeExecutionPolicy {
            max_retry_count: 5,
            // Shorter than the 5 second sleep of a fiber, so the attempts overlap:
            // a new one is started every second while the previous ones are still running.
            retry_interval: Duration::from_secs(1),
        };

        // Every fiber, no matter which one, sleeps 5 seconds and then fails.
        let generator = move |_first: bool| async move {
            tokio::time::sleep(Duration::from_secs(5)).await;
            IGNORABLE_ERROR.clone()
        };

        let start = tokio::time::Instant::now();
        let res = super::execute(&policy, &EMPTY_CONTEXT, generator).await;
        assert_matches!(
            res,
            Err(RequestError::LastAttemptError(
                RequestAttemptError::UnableToAllocStreamId
            ))
        );
        // t - start
        // First execution is started at t.
        // Speculative executions - at t+1, t+2, t+3, t+4, t+5.
        // Each execution sleeps 5 seconds and returns an ignorable error.
        // The last execution should finish at t+10.
        assert_eq!(start.elapsed(), Duration::from_secs(10));
    }
}
0 commit comments