Skip to content

Commit 58cb688

Browse files
committed
pager: add timeouts to SingleConnectionPagerWorker
The SingleConnectionPagerWorker now supports request timeouts when fetching pages. This is for consistency with the main PagerWorker implementation. Note that for now, no metadata queries specify timeouts, so the default per-Session execution profile's timeout will be used. Nothing prevents us from specifying custom timeouts for metadata queries in the future, thanks to this commit.
1 parent 617ea92 commit 58cb688

File tree

1 file changed

+58
-11
lines changed

1 file changed

+58
-11
lines changed

scylla/src/client/pager.rs

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::ops::ControlFlow;
88
use std::pin::Pin;
99
use std::sync::Arc;
1010
use std::task::{Context, Poll};
11+
use std::time::Duration;
1112

1213
use futures::Stream;
1314
use scylla_cql::Consistency;
@@ -618,6 +619,7 @@ where
618619
struct SingleConnectionPagerWorker<Fetcher> {
619620
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,
620621
fetcher: Fetcher,
622+
timeout: Option<Duration>,
621623
}
622624

623625
impl<Fetcher, FetchFut> SingleConnectionPagerWorker<Fetcher>
@@ -627,8 +629,8 @@ where
627629
{
628630
async fn work(mut self) -> PageSendAttemptedProof {
629631
match self.do_work().await {
630-
Ok(proof) => proof,
631-
Err(err) => {
632+
Ok(Ok(proof)) => proof,
633+
Ok(Err(err)) => {
632634
let (proof, _) = self
633635
.sender
634636
.send(Err(NextPageError::RequestFailure(
@@ -637,15 +639,46 @@ where
637639
.await;
638640
proof
639641
}
642+
Err(RequestTimeoutError(timeout)) => {
643+
let (proof, _) = self
644+
.sender
645+
.send(Err(NextPageError::RequestFailure(
646+
RequestError::RequestTimeout(timeout),
647+
)))
648+
.await;
649+
proof
650+
}
640651
}
641652
}
642653

643-
async fn do_work(&mut self) -> Result<PageSendAttemptedProof, RequestAttemptError> {
654+
async fn do_work(
655+
&mut self,
656+
) -> Result<Result<PageSendAttemptedProof, RequestAttemptError>, RequestTimeoutError> {
644657
let mut paging_state = PagingState::start();
645658
loop {
646-
let response = (self.fetcher)(paging_state)
647-
.await
648-
.and_then(QueryResponse::into_non_error_query_response)?;
659+
let runner = async {
660+
(self.fetcher)(paging_state)
661+
.await
662+
.and_then(QueryResponse::into_non_error_query_response)
663+
};
664+
let response_res = match self.timeout {
665+
Some(timeout) => {
666+
match tokio::time::timeout(timeout, runner).await {
667+
Ok(res) => res,
668+
Err(_) /* tokio::time::error::Elapsed */ => {
669+
return Err(RequestTimeoutError(timeout));
670+
}
671+
}
672+
}
673+
674+
None => runner.await,
675+
};
676+
let response = match response_res {
677+
Ok(resp) => resp,
678+
Err(err) => {
679+
return Ok(Err(err));
680+
}
681+
};
649682

650683
match response.response {
651684
NonErrorResponseWithDeserializedMetadata::Result(
@@ -662,7 +695,7 @@ where
662695

663696
if send_result.is_err() {
664697
// channel was closed, QueryPager was dropped - should shutdown
665-
return Ok(proof);
698+
return Ok(Ok(proof));
666699
}
667700

668701
match paging_state_response.into_paging_control_flow() {
@@ -671,7 +704,7 @@ where
671704
}
672705
ControlFlow::Break(()) => {
673706
// Reached the last query, shutdown
674-
return Ok(proof);
707+
return Ok(Ok(proof));
675708
}
676709
}
677710
}
@@ -681,12 +714,12 @@ where
681714

682715
// We must attempt to send something because the iterator expects it.
683716
let (proof, _) = self.sender.send_empty_page(response.tracing_id, None).await;
684-
return Ok(proof);
717+
return Ok(Ok(proof));
685718
}
686719
_ => {
687-
return Err(RequestAttemptError::UnexpectedResponse(
720+
return Ok(Err(RequestAttemptError::UnexpectedResponse(
688721
response.response.to_response_kind(),
689-
));
722+
)));
690723
}
691724
}
692725
}
@@ -1054,6 +1087,12 @@ If you are using this API, you are probably doing something wrong."
10541087
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
10551088

10561089
let page_size = query.get_validated_page_size();
1090+
let timeout = query.get_request_timeout().or_else(|| {
1091+
query
1092+
.get_execution_profile_handle()?
1093+
.access()
1094+
.request_timeout
1095+
});
10571096

10581097
let worker_task = async move {
10591098
let worker = SingleConnectionPagerWorker {
@@ -1067,6 +1106,7 @@ If you are using this API, you are probably doing something wrong."
10671106
paging_state,
10681107
)
10691108
},
1109+
timeout,
10701110
};
10711111
worker.work().await
10721112
};
@@ -1084,6 +1124,12 @@ If you are using this API, you are probably doing something wrong."
10841124
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
10851125

10861126
let page_size = prepared.get_validated_page_size();
1127+
let timeout = prepared.get_request_timeout().or_else(|| {
1128+
prepared
1129+
.get_execution_profile_handle()?
1130+
.access()
1131+
.request_timeout
1132+
});
10871133

10881134
let worker_task = async move {
10891135
let worker = SingleConnectionPagerWorker {
@@ -1098,6 +1144,7 @@ If you are using this API, you are probably doing something wrong."
10981144
paging_state,
10991145
)
11001146
},
1147+
timeout,
11011148
};
11021149
worker.work().await
11031150
};

0 commit comments

Comments
 (0)