@@ -8,7 +8,6 @@ use std::ops::ControlFlow;
88use std:: pin:: Pin ;
99use std:: sync:: Arc ;
1010use std:: task:: { Context , Poll } ;
11- use std:: time:: Duration ;
1211
1312use futures:: Stream ;
1413use scylla_cql:: Consistency ;
@@ -135,6 +134,50 @@ use crate::response::Coordinator;
135134
136135type PageSendAttemptedProof = SendAttemptedProof < Result < ReceivedPage , NextPageError > > ;
137136
137+ mod timeouter {
138+ use std:: time:: Duration ;
139+
140+ use tokio:: time:: Instant ;
141+
142+ /// Encapsulation of a timeout for paging queries.
143+ pub ( super ) struct PageQueryTimeouter {
144+ timeout : Duration ,
145+ timeout_instant : Instant ,
146+ }
147+
148+ impl PageQueryTimeouter {
149+ /// Creates a new PageQueryTimeouter with the given timeout duration,
150+ /// starting from now.
151+ pub ( super ) fn new ( timeout : Duration ) -> Self {
152+ Self {
153+ timeout,
154+ timeout_instant : Instant :: now ( ) + timeout,
155+ }
156+ }
157+
158+ /// Returns the timeout duration.
159+ pub ( super ) fn timeout_duration ( & self ) -> Duration {
160+ self . timeout
161+ }
162+
163+ /// Returns the instant at which the timeout will elapse.
164+ ///
165+ /// This can be used with `tokio::time::timeout_at`.
166+ pub ( super ) fn deadline ( & self ) -> Instant {
167+ self . timeout_instant
168+ }
169+
170+ /// Resets the timeout countdown.
171+ ///
172+ /// This should be called right before beginning first page fetch
173+ /// and after each successful page fetch.
174+ pub ( super ) fn reset ( & mut self ) {
175+ self . timeout_instant = Instant :: now ( ) + self . timeout ;
176+ }
177+ }
178+ }
179+ use timeouter:: PageQueryTimeouter ;
180+
138181// PagerWorker works in the background to fetch pages
139182// QueryPager receives them through a channel
140183struct PagerWorker < ' a , QueryFunc , SpanCreatorFunc > {
@@ -149,7 +192,7 @@ struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> {
149192 query_is_idempotent : bool ,
150193 query_consistency : Consistency ,
151194 retry_session : Box < dyn RetrySession > ,
152- timeout : Option < Duration > ,
195+ timeouter : Option < PageQueryTimeouter > ,
153196 #[ cfg( feature = "metrics" ) ]
154197 metrics : Arc < Metrics > ,
155198
@@ -180,8 +223,7 @@ where
180223 let mut current_consistency: Consistency = self . query_consistency ;
181224
182225 self . log_request_start ( ) ;
183- let start_instant = tokio:: time:: Instant :: now ( ) ;
184- let mut timeout_instant = self . timeout . map ( |t| start_instant + t) ;
226+ self . timeouter . as_mut ( ) . map ( PageQueryTimeouter :: reset) ;
185227
186228 ' nodes_in_plan: for ( node, shard) in query_plan {
187229 let span = trace_span ! ( parent: & self . parent_span, "Executing query" , node = %node. address, shard = %shard) ;
@@ -216,13 +258,7 @@ where
216258 Result < PageSendAttemptedProof , RequestAttemptError > ,
217259 RequestTimeoutError ,
218260 > = self
219- . query_pages (
220- & connection,
221- current_consistency,
222- node,
223- coordinator. clone ( ) ,
224- timeout_instant. as_mut ( ) ,
225- )
261+ . query_pages ( & connection, current_consistency, node, coordinator. clone ( ) )
226262 . instrument ( span. clone ( ) )
227263 . await ;
228264
@@ -326,7 +362,6 @@ where
326362 consistency : Consistency ,
327363 node : NodeRef < ' _ > ,
328364 coordinator : Coordinator ,
329- mut timeout_instant : Option < & mut tokio:: time:: Instant > ,
330365 ) -> Result < Result < PageSendAttemptedProof , RequestAttemptError > , RequestTimeoutError > {
331366 loop {
332367 let request_span = ( self . span_creator ) ( ) ;
@@ -337,7 +372,6 @@ where
337372 node,
338373 coordinator. clone ( ) ,
339374 & request_span,
340- timeout_instant. as_ref ( ) . map ( |instant| * * instant) ,
341375 )
342376 . instrument ( request_span. span ( ) . clone ( ) )
343377 . await
@@ -350,11 +384,7 @@ where
350384 Ok ( Ok ( ControlFlow :: Continue ( ( ) ) ) ) => {
351385 // Successfully queried one page, and there are more to fetch.
352386 // Reset the timeout_instant for the next page fetch.
353- if let Some ( timeout) = self . timeout
354- && let Some ( ref mut instant) = timeout_instant
355- {
356- * * instant = tokio:: time:: Instant :: now ( ) + timeout;
357- }
387+ self . timeouter . as_mut ( ) . map ( PageQueryTimeouter :: reset) ;
358388 }
359389 Ok ( Err ( request_attempt_error) ) => {
360390 return Ok ( Err ( request_attempt_error) ) ;
@@ -373,7 +403,6 @@ where
373403 node : NodeRef < ' _ > ,
374404 coordinator : Coordinator ,
375405 request_span : & RequestSpan ,
376- timeout_instant : Option < tokio:: time:: Instant > ,
377406 ) -> Result <
378407 Result < ControlFlow < PageSendAttemptedProof , ( ) > , RequestAttemptError > ,
379408 RequestTimeoutError ,
@@ -394,20 +423,19 @@ where
394423 . await
395424 . and_then ( QueryResponse :: into_non_error_query_response)
396425 } ;
397- let query_response = match ( self . timeout , timeout_instant) {
398- ( Some ( timeout) , Some ( instant) ) => {
399- match tokio:: time:: timeout_at ( instant, runner) . await {
400- Ok ( res) => res,
401- Err ( _) /* tokio::time::error::Elapsed */ => {
402- #[ cfg( feature = "metrics" ) ]
403- self . metrics . inc_request_timeouts ( ) ;
404- return Err ( RequestTimeoutError ( timeout) ) ;
426+ let query_response = match self . timeouter {
427+ Some ( ref timeouter) => {
428+ match tokio:: time:: timeout_at ( timeouter. deadline ( ) , runner) . await {
429+ Ok ( res) => res,
430+ Err ( _) /* tokio::time::error::Elapsed */ => {
431+ #[ cfg( feature = "metrics" ) ]
432+ self . metrics . inc_request_timeouts ( ) ;
433+ return Err ( RequestTimeoutError ( timeouter. timeout_duration ( ) ) ) ;
434+ }
405435 }
406436 }
407- }
408437
409- ( None , None ) => runner. await ,
410- _ => unreachable ! ( "timeout_instant must be Some iff self.timeout is Some" ) ,
438+ None => runner. await ,
411439 } ;
412440
413441 let elapsed = query_start. elapsed ( ) ;
@@ -802,9 +830,10 @@ If you are using this API, you are probably doing something wrong."
802830 . serial_consistency
803831 . unwrap_or ( execution_profile. serial_consistency ) ;
804832
805- let timeout = statement
833+ let timeouter = statement
806834 . get_request_timeout ( )
807- . or ( execution_profile. request_timeout ) ;
835+ . or ( execution_profile. request_timeout )
836+ . map ( PageQueryTimeouter :: new) ;
808837
809838 let page_size = statement. get_validated_page_size ( ) ;
810839
@@ -862,7 +891,7 @@ If you are using this API, you are probably doing something wrong."
862891 query_consistency : consistency,
863892 load_balancing_policy,
864893 retry_session,
865- timeout ,
894+ timeouter ,
866895 #[ cfg( feature = "metrics" ) ]
867896 metrics,
868897 paging_state : PagingState :: start ( ) ,
@@ -895,10 +924,11 @@ If you are using this API, you are probably doing something wrong."
895924 . serial_consistency
896925 . unwrap_or ( config. execution_profile . serial_consistency ) ;
897926
898- let timeout = config
927+ let timeouter = config
899928 . prepared
900929 . get_request_timeout ( )
901- . or ( config. execution_profile . request_timeout ) ;
930+ . or ( config. execution_profile . request_timeout )
931+ . map ( PageQueryTimeouter :: new) ;
902932
903933 let page_size = config. prepared . get_validated_page_size ( ) ;
904934
@@ -996,7 +1026,7 @@ If you are using this API, you are probably doing something wrong."
9961026 query_consistency : consistency,
9971027 load_balancing_policy,
9981028 retry_session,
999- timeout ,
1029+ timeouter ,
10001030 #[ cfg( feature = "metrics" ) ]
10011031 metrics : config. metrics ,
10021032 paging_state : PagingState :: start ( ) ,
0 commit comments