Skip to content

Commit c4ee9bf

Browse files
committed
pager: add timeouts to SingleConnectionPagerWorker
The SingleConnectionPagerWorker now supports request timeouts when fetching pages. This is for consistency with the main PagerWorker implementation. Note that for now, no metadata queries specify timeouts, so the default per-Session execution profile's timeout will be used. Nothing prevents us from specifying custom timeouts for metadata queries in the future, thanks to this commit.
1 parent 6a08119 commit c4ee9bf

File tree

1 file changed

+58
-11
lines changed

1 file changed

+58
-11
lines changed

scylla/src/client/pager.rs

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::ops::ControlFlow;
88
use std::pin::Pin;
99
use std::sync::Arc;
1010
use std::task::{Context, Poll};
11+
use std::time::Duration;
1112

1213
use futures::Stream;
1314
use scylla_cql::Consistency;
@@ -616,6 +617,7 @@ where
616617
struct SingleConnectionPagerWorker<Fetcher> {
617618
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,
618619
fetcher: Fetcher,
620+
timeout: Option<Duration>,
619621
}
620622

621623
impl<Fetcher, FetchFut> SingleConnectionPagerWorker<Fetcher>
@@ -625,8 +627,8 @@ where
625627
{
626628
async fn work(mut self) -> PageSendAttemptedProof {
627629
match self.do_work().await {
628-
Ok(proof) => proof,
629-
Err(err) => {
630+
Ok(Ok(proof)) => proof,
631+
Ok(Err(err)) => {
630632
let (proof, _) = self
631633
.sender
632634
.send(Err(NextPageError::RequestFailure(
@@ -635,15 +637,46 @@ where
635637
.await;
636638
proof
637639
}
640+
Err(RequestTimeoutError(timeout)) => {
641+
let (proof, _) = self
642+
.sender
643+
.send(Err(NextPageError::RequestFailure(
644+
RequestError::RequestTimeout(timeout),
645+
)))
646+
.await;
647+
proof
648+
}
638649
}
639650
}
640651

641-
async fn do_work(&mut self) -> Result<PageSendAttemptedProof, RequestAttemptError> {
652+
async fn do_work(
653+
&mut self,
654+
) -> Result<Result<PageSendAttemptedProof, RequestAttemptError>, RequestTimeoutError> {
642655
let mut paging_state = PagingState::start();
643656
loop {
644-
let response = (self.fetcher)(paging_state)
645-
.await
646-
.and_then(QueryResponse::into_non_error_query_response)?;
657+
let runner = async {
658+
(self.fetcher)(paging_state)
659+
.await
660+
.and_then(QueryResponse::into_non_error_query_response)
661+
};
662+
let response_res = match self.timeout {
663+
Some(timeout) => {
664+
match tokio::time::timeout(timeout, runner).await {
665+
Ok(res) => res,
666+
Err(_) /* tokio::time::error::Elapsed */ => {
667+
return Err(RequestTimeoutError(timeout));
668+
}
669+
}
670+
}
671+
672+
None => runner.await,
673+
};
674+
let response = match response_res {
675+
Ok(resp) => resp,
676+
Err(err) => {
677+
return Ok(Err(err));
678+
}
679+
};
647680

648681
match response.response {
649682
NonErrorResponse::Result(result::Result::Rows((rows, paging_state_response))) => {
@@ -658,7 +691,7 @@ where
658691

659692
if send_result.is_err() {
660693
// channel was closed, QueryPager was dropped - should shutdown
661-
return Ok(proof);
694+
return Ok(Ok(proof));
662695
}
663696

664697
match paging_state_response.into_paging_control_flow() {
@@ -667,7 +700,7 @@ where
667700
}
668701
ControlFlow::Break(()) => {
669702
// Reached the last query, shutdown
670-
return Ok(proof);
703+
return Ok(Ok(proof));
671704
}
672705
}
673706
}
@@ -677,12 +710,12 @@ where
677710

678711
// We must attempt to send something because the iterator expects it.
679712
let (proof, _) = self.sender.send_empty_page(response.tracing_id, None).await;
680-
return Ok(proof);
713+
return Ok(Ok(proof));
681714
}
682715
_ => {
683-
return Err(RequestAttemptError::UnexpectedResponse(
716+
return Ok(Err(RequestAttemptError::UnexpectedResponse(
684717
response.response.to_response_kind(),
685-
));
718+
)));
686719
}
687720
}
688721
}
@@ -1054,6 +1087,12 @@ If you are using this API, you are probably doing something wrong."
10541087
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
10551088

10561089
let page_size = query.get_validated_page_size();
1090+
let timeout = query.get_request_timeout().or_else(|| {
1091+
query
1092+
.get_execution_profile_handle()?
1093+
.access()
1094+
.request_timeout
1095+
});
10571096

10581097
let worker_task = async move {
10591098
let worker = SingleConnectionPagerWorker {
@@ -1067,6 +1106,7 @@ If you are using this API, you are probably doing something wrong."
10671106
paging_state,
10681107
)
10691108
},
1109+
timeout,
10701110
};
10711111
worker.work().await
10721112
};
@@ -1084,6 +1124,12 @@ If you are using this API, you are probably doing something wrong."
10841124
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
10851125

10861126
let page_size = prepared.get_validated_page_size();
1127+
let timeout = prepared.get_request_timeout().or_else(|| {
1128+
prepared
1129+
.get_execution_profile_handle()?
1130+
.access()
1131+
.request_timeout
1132+
});
10871133

10881134
let worker_task = async move {
10891135
let worker = SingleConnectionPagerWorker {
@@ -1098,6 +1144,7 @@ If you are using this API, you are probably doing something wrong."
10981144
paging_state,
10991145
)
11001146
},
1147+
timeout,
11011148
};
11021149
worker.work().await
11031150
};

0 commit comments

Comments
 (0)