@@ -8,6 +8,7 @@ use std::ops::ControlFlow;
88use  std:: pin:: Pin ; 
99use  std:: sync:: Arc ; 
1010use  std:: task:: { Context ,  Poll } ; 
11+ use  std:: time:: Duration ; 
1112
1213use  futures:: Stream ; 
1314use  scylla_cql:: Consistency ; 
@@ -616,6 +617,7 @@ where
616617struct  SingleConnectionPagerWorker < Fetcher >  { 
617618    sender :  ProvingSender < Result < ReceivedPage ,  NextPageError > > , 
618619    fetcher :  Fetcher , 
620+     timeout :  Option < Duration > , 
619621} 
620622
621623impl < Fetcher ,  FetchFut >  SingleConnectionPagerWorker < Fetcher > 
@@ -625,8 +627,8 @@ where
625627{ 
626628    async  fn  work ( mut  self )  -> PageSendAttemptedProof  { 
627629        match  self . do_work ( ) . await  { 
628-             Ok ( proof)  => proof, 
629-             Err ( err)  => { 
630+             Ok ( Ok ( proof) )  => proof, 
631+             Ok ( Err ( err) )  => { 
630632                let  ( proof,  _)  = self 
631633                    . sender 
632634                    . send ( Err ( NextPageError :: RequestFailure ( 
@@ -635,15 +637,46 @@ where
635637                    . await ; 
636638                proof
637639            } 
640+             Err ( RequestTimeoutError ( timeout) )  => { 
641+                 let  ( proof,  _)  = self 
642+                     . sender 
643+                     . send ( Err ( NextPageError :: RequestFailure ( 
644+                         RequestError :: RequestTimeout ( timeout) , 
645+                     ) ) ) 
646+                     . await ; 
647+                 proof
648+             } 
638649        } 
639650    } 
640651
641-     async  fn  do_work ( & mut  self )  -> Result < PageSendAttemptedProof ,  RequestAttemptError >  { 
652+     async  fn  do_work ( 
653+         & mut  self , 
654+     )  -> Result < Result < PageSendAttemptedProof ,  RequestAttemptError > ,  RequestTimeoutError >  { 
642655        let  mut  paging_state = PagingState :: start ( ) ; 
643656        loop  { 
644-             let  response = ( self . fetcher ) ( paging_state) 
645-                 . await 
646-                 . and_then ( QueryResponse :: into_non_error_query_response) ?; 
657+             let  runner = async  { 
658+                 ( self . fetcher ) ( paging_state) 
659+                     . await 
660+                     . and_then ( QueryResponse :: into_non_error_query_response) 
661+             } ; 
662+             let  response_res = match  self . timeout  { 
663+                 Some ( timeout)  => { 
664+                     match  tokio:: time:: timeout ( timeout,  runner) . await  { 
665+                         Ok ( res)  => res, 
666+                         Err ( _)  /* tokio::time::error::Elapsed */  => { 
667+                             return  Err ( RequestTimeoutError ( timeout) ) ; 
668+                         } 
669+                     } 
670+                 } 
671+ 
672+                 None  => runner. await , 
673+             } ; 
674+             let  response = match  response_res { 
675+                 Ok ( resp)  => resp, 
676+                 Err ( err)  => { 
677+                     return  Ok ( Err ( err) ) ; 
678+                 } 
679+             } ; 
647680
648681            match  response. response  { 
649682                NonErrorResponse :: Result ( result:: Result :: Rows ( ( rows,  paging_state_response) ) )  => { 
@@ -658,7 +691,7 @@ where
658691
659692                    if  send_result. is_err ( )  { 
660693                        // channel was closed, QueryPager was dropped - should shutdown 
661-                         return  Ok ( proof) ; 
694+                         return  Ok ( Ok ( proof) ) ; 
662695                    } 
663696
664697                    match  paging_state_response. into_paging_control_flow ( )  { 
@@ -667,7 +700,7 @@ where
667700                        } 
668701                        ControlFlow :: Break ( ( ) )  => { 
669702                            // Reached the last query, shutdown 
670-                             return  Ok ( proof) ; 
703+                             return  Ok ( Ok ( proof) ) ; 
671704                        } 
672705                    } 
673706                } 
@@ -677,12 +710,12 @@ where
677710
678711                    // We must attempt to send something because the iterator expects it. 
679712                    let  ( proof,  _)  = self . sender . send_empty_page ( response. tracing_id ,  None ) . await ; 
680-                     return  Ok ( proof) ; 
713+                     return  Ok ( Ok ( proof) ) ; 
681714                } 
682715                _ => { 
683-                     return  Err ( RequestAttemptError :: UnexpectedResponse ( 
716+                     return  Ok ( Err ( RequestAttemptError :: UnexpectedResponse ( 
684717                        response. response . to_response_kind ( ) , 
685-                     ) ) ; 
718+                     ) ) ) ; 
686719                } 
687720            } 
688721        } 
@@ -1054,6 +1087,12 @@ If you are using this API, you are probably doing something wrong."
10541087        let  ( sender,  receiver)  = mpsc:: channel :: < Result < ReceivedPage ,  NextPageError > > ( 1 ) ; 
10551088
10561089        let  page_size = query. get_validated_page_size ( ) ; 
1090+         let  timeout = query. get_request_timeout ( ) . or_else ( || { 
1091+             query
1092+                 . get_execution_profile_handle ( ) ?
1093+                 . access ( ) 
1094+                 . request_timeout 
1095+         } ) ; 
10571096
10581097        let  worker_task = async  move  { 
10591098            let  worker = SingleConnectionPagerWorker  { 
@@ -1067,6 +1106,7 @@ If you are using this API, you are probably doing something wrong."
10671106                        paging_state, 
10681107                    ) 
10691108                } , 
1109+                 timeout, 
10701110            } ; 
10711111            worker. work ( ) . await 
10721112        } ; 
@@ -1084,6 +1124,12 @@ If you are using this API, you are probably doing something wrong."
10841124        let  ( sender,  receiver)  = mpsc:: channel :: < Result < ReceivedPage ,  NextPageError > > ( 1 ) ; 
10851125
10861126        let  page_size = prepared. get_validated_page_size ( ) ; 
1127+         let  timeout = prepared. get_request_timeout ( ) . or_else ( || { 
1128+             prepared
1129+                 . get_execution_profile_handle ( ) ?
1130+                 . access ( ) 
1131+                 . request_timeout 
1132+         } ) ; 
10871133
10881134        let  worker_task = async  move  { 
10891135            let  worker = SingleConnectionPagerWorker  { 
@@ -1098,6 +1144,7 @@ If you are using this API, you are probably doing something wrong."
10981144                        paging_state, 
10991145                    ) 
11001146                } , 
1147+                 timeout, 
11011148            } ; 
11021149            worker. work ( ) . await 
11031150        } ; 
0 commit comments