@@ -8,6 +8,7 @@ use std::ops::ControlFlow;
88use std:: pin:: Pin ;
99use std:: sync:: Arc ;
1010use std:: task:: { Context , Poll } ;
11+ use std:: time:: Duration ;
1112
1213use futures:: Stream ;
1314use scylla_cql:: Consistency ;
@@ -618,6 +619,7 @@ where
618619struct SingleConnectionPagerWorker < Fetcher > {
619620 sender : ProvingSender < Result < ReceivedPage , NextPageError > > ,
620621 fetcher : Fetcher ,
622+ timeout : Option < Duration > ,
621623}
622624
623625impl < Fetcher , FetchFut > SingleConnectionPagerWorker < Fetcher >
@@ -627,8 +629,8 @@ where
627629{
628630 async fn work ( mut self ) -> PageSendAttemptedProof {
629631 match self . do_work ( ) . await {
630- Ok ( proof) => proof,
631- Err ( err) => {
632+ Ok ( Ok ( proof) ) => proof,
633+ Ok ( Err ( err) ) => {
632634 let ( proof, _) = self
633635 . sender
634636 . send ( Err ( NextPageError :: RequestFailure (
@@ -637,15 +639,46 @@ where
637639 . await ;
638640 proof
639641 }
642+ Err ( RequestTimeoutError ( timeout) ) => {
643+ let ( proof, _) = self
644+ . sender
645+ . send ( Err ( NextPageError :: RequestFailure (
646+ RequestError :: RequestTimeout ( timeout) ,
647+ ) ) )
648+ . await ;
649+ proof
650+ }
640651 }
641652 }
642653
643- async fn do_work ( & mut self ) -> Result < PageSendAttemptedProof , RequestAttemptError > {
654+ async fn do_work (
655+ & mut self ,
656+ ) -> Result < Result < PageSendAttemptedProof , RequestAttemptError > , RequestTimeoutError > {
644657 let mut paging_state = PagingState :: start ( ) ;
645658 loop {
646- let response = ( self . fetcher ) ( paging_state)
647- . await
648- . and_then ( QueryResponse :: into_non_error_query_response) ?;
659+ let runner = async {
660+ ( self . fetcher ) ( paging_state)
661+ . await
662+ . and_then ( QueryResponse :: into_non_error_query_response)
663+ } ;
664+ let response_res = match self . timeout {
665+ Some ( timeout) => {
666+ match tokio:: time:: timeout ( timeout, runner) . await {
667+ Ok ( res) => res,
668+ Err ( _) /* tokio::time::error::Elapsed */ => {
669+ return Err ( RequestTimeoutError ( timeout) ) ;
670+ }
671+ }
672+ }
673+
674+ None => runner. await ,
675+ } ;
676+ let response = match response_res {
677+ Ok ( resp) => resp,
678+ Err ( err) => {
679+ return Ok ( Err ( err) ) ;
680+ }
681+ } ;
649682
650683 match response. response {
651684 NonErrorResponseWithDeserializedMetadata :: Result (
@@ -662,7 +695,7 @@ where
662695
663696 if send_result. is_err ( ) {
664697 // channel was closed, QueryPager was dropped - should shutdown
665- return Ok ( proof) ;
698+ return Ok ( Ok ( proof) ) ;
666699 }
667700
668701 match paging_state_response. into_paging_control_flow ( ) {
@@ -671,7 +704,7 @@ where
671704 }
672705 ControlFlow :: Break ( ( ) ) => {
673706 // Reached the last query, shutdown
674- return Ok ( proof) ;
707+ return Ok ( Ok ( proof) ) ;
675708 }
676709 }
677710 }
@@ -681,12 +714,12 @@ where
681714
682715 // We must attempt to send something because the iterator expects it.
683716 let ( proof, _) = self . sender . send_empty_page ( response. tracing_id , None ) . await ;
684- return Ok ( proof) ;
717+ return Ok ( Ok ( proof) ) ;
685718 }
686719 _ => {
687- return Err ( RequestAttemptError :: UnexpectedResponse (
720+ return Ok ( Err ( RequestAttemptError :: UnexpectedResponse (
688721 response. response . to_response_kind ( ) ,
689- ) ) ;
722+ ) ) ) ;
690723 }
691724 }
692725 }
@@ -1054,6 +1087,12 @@ If you are using this API, you are probably doing something wrong."
10541087 let ( sender, receiver) = mpsc:: channel :: < Result < ReceivedPage , NextPageError > > ( 1 ) ;
10551088
10561089 let page_size = query. get_validated_page_size ( ) ;
1090+ let timeout = query. get_request_timeout ( ) . or_else ( || {
1091+ query
1092+ . get_execution_profile_handle ( ) ?
1093+ . access ( )
1094+ . request_timeout
1095+ } ) ;
10571096
10581097 let worker_task = async move {
10591098 let worker = SingleConnectionPagerWorker {
@@ -1067,6 +1106,7 @@ If you are using this API, you are probably doing something wrong."
10671106 paging_state,
10681107 )
10691108 } ,
1109+ timeout,
10701110 } ;
10711111 worker. work ( ) . await
10721112 } ;
@@ -1084,6 +1124,12 @@ If you are using this API, you are probably doing something wrong."
10841124 let ( sender, receiver) = mpsc:: channel :: < Result < ReceivedPage , NextPageError > > ( 1 ) ;
10851125
10861126 let page_size = prepared. get_validated_page_size ( ) ;
1127+ let timeout = prepared. get_request_timeout ( ) . or_else ( || {
1128+ prepared
1129+ . get_execution_profile_handle ( ) ?
1130+ . access ( )
1131+ . request_timeout
1132+ } ) ;
10871133
10881134 let worker_task = async move {
10891135 let worker = SingleConnectionPagerWorker {
@@ -1098,6 +1144,7 @@ If you are using this API, you are probably doing something wrong."
10981144 paging_state,
10991145 )
11001146 } ,
1147+ timeout,
11011148 } ;
11021149 worker. work ( ) . await
11031150 } ;
0 commit comments