@@ -8,6 +8,7 @@ use std::ops::ControlFlow;
88use std:: pin:: Pin ;
99use std:: sync:: Arc ;
1010use std:: task:: { Context , Poll } ;
11+ use std:: time:: Duration ;
1112
1213use futures:: Stream ;
1314use scylla_cql:: Consistency ;
@@ -148,6 +149,7 @@ struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> {
148149 query_is_idempotent : bool ,
149150 query_consistency : Consistency ,
150151 retry_session : Box < dyn RetrySession > ,
152+ timeout : Option < Duration > ,
151153 #[ cfg( feature = "metrics" ) ]
152154 metrics : Arc < Metrics > ,
153155
@@ -178,6 +180,8 @@ where
178180 let mut current_consistency: Consistency = self . query_consistency ;
179181
180182 self . log_request_start ( ) ;
183+ let start_instant = tokio:: time:: Instant :: now ( ) ;
184+ let mut timeout_instant = self . timeout . map ( |t| start_instant + t) ;
181185
182186 ' nodes_in_plan: for ( node, shard) in query_plan {
183187 let span = trace_span ! ( parent: & self . parent_span, "Executing query" , node = %node. address, shard = %shard) ;
@@ -208,27 +212,50 @@ where
208212 Coordinator :: new ( node, node. sharder ( ) . is_some ( ) . then_some ( shard) , & connection) ;
209213
210214 // Query pages until an error occurs
211- let queries_result: Result < PageSendAttemptedProof , RequestAttemptError > = self
212- . query_pages ( & connection, current_consistency, node, coordinator. clone ( ) )
215+ let queries_result: Result <
216+ Result < PageSendAttemptedProof , RequestAttemptError > ,
217+ RequestTimeoutError ,
218+ > = self
219+ . query_pages (
220+ & connection,
221+ current_consistency,
222+ node,
223+ coordinator. clone ( ) ,
224+ timeout_instant. as_mut ( ) ,
225+ )
213226 . instrument ( span. clone ( ) )
214227 . await ;
215228
216229 let request_error: RequestAttemptError = match queries_result {
217- Ok ( proof) => {
230+ Ok ( Ok ( proof) ) => {
218231 trace ! ( parent: & span, "Request succeeded" ) ;
219232 // query_pages returned Ok(Ok(_)), so we are guaranteed
220233 // that it attempted to send at least one page
221234 // through self.sender and we can safely return now.
222235 return proof;
223236 }
224- Err ( error) => {
237+ Ok ( Err ( error) ) => {
225238 trace ! (
226239 parent: & span,
227240 error = %error,
228241 "Request failed"
229242 ) ;
230243 error
231244 }
245+ Err ( request_timeout) => {
246+ let request_error = RequestError :: RequestTimeout ( request_timeout. 0 ) ;
247+ self . log_request_error ( & request_error) ;
248+ trace ! (
249+ parent: & span,
250+ error = %request_error,
251+ "Request timed out"
252+ ) ;
253+ let ( proof, _) = self
254+ . sender
255+ . send ( Err ( NextPageError :: RequestFailure ( request_error) ) )
256+ . await ;
257+ return proof;
258+ }
232259 } ;
233260
234261 // Use retry policy to decide what to do next
@@ -269,7 +296,7 @@ where
269296 // Although we are in an awkward situation (_iter
270297 // interface isn't meant for sending writes),
271298 // we must attempt to send something because
272- // the iterator expects it.
299+ // QueryPager expects it.
273300 let ( proof, _) = self
274301 . sender
275302 . send_empty_page ( None , Some ( coordinator. clone ( ) ) )
@@ -299,7 +326,8 @@ where
299326 consistency : Consistency ,
300327 node : NodeRef < ' _ > ,
301328 coordinator : Coordinator ,
302- ) -> Result < PageSendAttemptedProof , RequestAttemptError > {
329+ mut timeout_instant : Option < & mut tokio:: time:: Instant > ,
330+ ) -> Result < Result < PageSendAttemptedProof , RequestAttemptError > , RequestTimeoutError > {
303331 loop {
304332 let request_span = ( self . span_creator ) ( ) ;
305333 match self
@@ -309,12 +337,31 @@ where
309337 node,
310338 coordinator. clone ( ) ,
311339 & request_span,
340+ timeout_instant. as_ref ( ) . map ( |instant| * * instant) ,
312341 )
313342 . instrument ( request_span. span ( ) . clone ( ) )
314- . await ?
343+ . await
315344 {
316- ControlFlow :: Break ( proof) => return Ok ( proof) ,
317- ControlFlow :: Continue ( _) => { }
345+ Ok ( Ok ( ControlFlow :: Break ( proof) ) ) => {
346+ // Successfully queried the last remaining page.
347+ return Ok ( Ok ( proof) ) ;
348+ }
349+
350+ Ok ( Ok ( ControlFlow :: Continue ( ( ) ) ) ) => {
351+ // Successfully queried one page, and there are more to fetch.
352+ // Reset the timeout_instant for the next page fetch.
353+ if let Some ( timeout) = self . timeout
354+ && let Some ( ref mut instant) = timeout_instant
355+ {
356+ * * instant = tokio:: time:: Instant :: now ( ) + timeout;
357+ }
358+ }
359+ Ok ( Err ( request_attempt_error) ) => {
360+ return Ok ( Err ( request_attempt_error) ) ;
361+ }
362+ Err ( request_timeout_error) => {
363+ return Err ( request_timeout_error) ;
364+ }
318365 }
319366 }
320367 }
@@ -326,7 +373,11 @@ where
326373 node : NodeRef < ' _ > ,
327374 coordinator : Coordinator ,
328375 request_span : & RequestSpan ,
329- ) -> Result < ControlFlow < PageSendAttemptedProof , ( ) > , RequestAttemptError > {
376+ timeout_instant : Option < tokio:: time:: Instant > ,
377+ ) -> Result <
378+ Result < ControlFlow < PageSendAttemptedProof , ( ) > , RequestAttemptError > ,
379+ RequestTimeoutError ,
380+ > {
330381 #[ cfg( feature = "metrics" ) ]
331382 self . metrics . inc_total_paged_queries ( ) ;
332383 let query_start = std:: time:: Instant :: now ( ) ;
@@ -338,10 +389,26 @@ where
338389 ) ;
339390 self . log_attempt_start ( connect_address) ;
340391
341- let query_response =
392+ let runner = async {
342393 ( self . page_query ) ( connection. clone ( ) , consistency, self . paging_state . clone ( ) )
343394 . await
344- . and_then ( QueryResponse :: into_non_error_query_response) ;
395+ . and_then ( QueryResponse :: into_non_error_query_response)
396+ } ;
397+ let query_response = match ( self . timeout , timeout_instant) {
398+ ( Some ( timeout) , Some ( instant) ) => {
399+ match tokio:: time:: timeout_at ( instant, runner) . await {
400+ Ok ( res) => res,
401+ Err ( _) /* tokio::time::error::Elapsed */ => {
402+ #[ cfg( feature = "metrics" ) ]
403+ self . metrics . inc_request_timeouts ( ) ;
404+ return Err ( RequestTimeoutError ( timeout) ) ;
405+ }
406+ }
407+ }
408+
409+ ( None , None ) => runner. await ,
410+ _ => unreachable ! ( "timeout_instant must be Some iff self.timeout is Some" ) ,
411+ } ;
345412
346413 let elapsed = query_start. elapsed ( ) ;
347414
@@ -373,7 +440,7 @@ where
373440 let ( proof, res) = self . sender . send ( Ok ( received_page) ) . await ;
374441 if res. is_err ( ) {
375442 // channel was closed, QueryPager was dropped - should shutdown
376- return Ok ( ControlFlow :: Break ( proof) ) ;
443+ return Ok ( Ok ( ControlFlow :: Break ( proof) ) ) ;
377444 }
378445
379446 match paging_state_response. into_paging_control_flow ( ) {
@@ -382,15 +449,15 @@ where
382449 }
383450 ControlFlow :: Break ( ( ) ) => {
384451 // Reached the last query, shutdown
385- return Ok ( ControlFlow :: Break ( proof) ) ;
452+ return Ok ( Ok ( ControlFlow :: Break ( proof) ) ) ;
386453 }
387454 }
388455
389456 // Query succeeded, reset retry policy for future retries
390457 self . retry_session . reset ( ) ;
391458 self . log_request_start ( ) ;
392459
393- Ok ( ControlFlow :: Continue ( ( ) ) )
460+ Ok ( Ok ( ControlFlow :: Continue ( ( ) ) ) )
394461 }
395462 Err ( err) => {
396463 #[ cfg( feature = "metrics" ) ]
@@ -401,7 +468,7 @@ where
401468 node,
402469 & err,
403470 ) ;
404- Err ( err)
471+ Ok ( Err ( err) )
405472 }
406473 Ok ( NonErrorQueryResponse {
407474 response : NonErrorResponse :: Result ( _) ,
@@ -416,7 +483,7 @@ where
416483 . sender
417484 . send_empty_page ( tracing_id, Some ( coordinator) )
418485 . await ;
419- Ok ( ControlFlow :: Break ( proof) )
486+ Ok ( Ok ( ControlFlow :: Break ( proof) ) )
420487 }
421488 Ok ( response) => {
422489 #[ cfg( feature = "metrics" ) ]
@@ -429,7 +496,7 @@ where
429496 node,
430497 & err,
431498 ) ;
432- Err ( err)
499+ Ok ( Err ( err) )
433500 }
434501 }
435502 }
@@ -735,6 +802,10 @@ If you are using this API, you are probably doing something wrong."
735802 . serial_consistency
736803 . unwrap_or ( execution_profile. serial_consistency ) ;
737804
805+ let timeout = statement
806+ . get_request_timeout ( )
807+ . or ( execution_profile. request_timeout ) ;
808+
738809 let page_size = statement. get_validated_page_size ( ) ;
739810
740811 let routing_info = RoutingInfo {
@@ -791,6 +862,7 @@ If you are using this API, you are probably doing something wrong."
791862 query_consistency : consistency,
792863 load_balancing_policy,
793864 retry_session,
865+ timeout,
794866 #[ cfg( feature = "metrics" ) ]
795867 metrics,
796868 paging_state : PagingState :: start ( ) ,
@@ -823,6 +895,11 @@ If you are using this API, you are probably doing something wrong."
823895 . serial_consistency
824896 . unwrap_or ( config. execution_profile . serial_consistency ) ;
825897
898+ let timeout = config
899+ . prepared
900+ . get_request_timeout ( )
901+ . or ( config. execution_profile . request_timeout ) ;
902+
826903 let page_size = config. prepared . get_validated_page_size ( ) ;
827904
828905 let load_balancing_policy = Arc :: clone (
@@ -919,6 +996,7 @@ If you are using this API, you are probably doing something wrong."
919996 query_consistency : consistency,
920997 load_balancing_policy,
921998 retry_session,
999+ timeout,
9221000 #[ cfg( feature = "metrics" ) ]
9231001 metrics : config. metrics ,
9241002 paging_state : PagingState :: start ( ) ,
@@ -1138,6 +1216,14 @@ where
11381216 }
11391217}
11401218
1219+ /// Failed to run a request within a provided client timeout; field `.0` is that timeout, which the error message reports in whole milliseconds.
1220+ #[ derive( Error , Debug , Clone ) ]
1221+ #[ error(
1222+ "Request execution exceeded a client timeout of {}ms" ,
1223+ std:: time:: Duration :: as_millis( . 0 )
1224+ ) ]
1225+ struct RequestTimeoutError ( std:: time:: Duration ) ;
1226+
11411227/// An error returned that occurred during next page fetch.
11421228#[ derive( Error , Debug , Clone ) ]
11431229#[ non_exhaustive]
0 commit comments