Skip to content

Commit 4e0693e

Browse files
committed
pager: encapsulate timeout logic
PageQueryTimeouter is introduced to encapsulate timeout logic for paging queries. It was clear to me that such encapsulation in a new type improves code readability and maintainability, as the timeout logic is now clearly separated from the rest of the paging logic.
1 parent 809c7d1 commit 4e0693e

File tree

1 file changed

+66
-36
lines changed

1 file changed

+66
-36
lines changed

scylla/src/client/pager.rs

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use std::ops::ControlFlow;
88
use std::pin::Pin;
99
use std::sync::Arc;
1010
use std::task::{Context, Poll};
11-
use std::time::Duration;
1211

1312
use futures::Stream;
1413
use scylla_cql::Consistency;
@@ -135,6 +134,50 @@ use crate::response::Coordinator;
135134

136135
type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, NextPageError>>;
137136

137+
mod timeouter {
138+
use std::time::Duration;
139+
140+
use tokio::time::Instant;
141+
142+
/// Encapsulation of a timeout for paging queries.
143+
pub(super) struct PageQueryTimeouter {
144+
timeout: Duration,
145+
timeout_instant: Instant,
146+
}
147+
148+
impl PageQueryTimeouter {
149+
/// Creates a new PageQueryTimeouter with the given timeout duration,
150+
/// starting from now.
151+
pub(super) fn new(timeout: Duration) -> Self {
152+
Self {
153+
timeout,
154+
timeout_instant: Instant::now() + timeout,
155+
}
156+
}
157+
158+
/// Returns the timeout duration.
159+
pub(super) fn timeout_duration(&self) -> Duration {
160+
self.timeout
161+
}
162+
163+
/// Returns the instant at which the timeout will elapse.
164+
///
165+
/// This can be used with `tokio::time::timeout_at`.
166+
pub(super) fn deadline(&self) -> Instant {
167+
self.timeout_instant
168+
}
169+
170+
/// Resets the timeout countdown.
171+
///
172+
/// This should be called right before beginning first page fetch
173+
/// and after each successful page fetch.
174+
pub(super) fn reset(&mut self) {
175+
self.timeout_instant = Instant::now() + self.timeout;
176+
}
177+
}
178+
}
179+
use timeouter::PageQueryTimeouter;
180+
138181
// PagerWorker works in the background to fetch pages
139182
// QueryPager receives them through a channel
140183
struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> {
@@ -149,7 +192,7 @@ struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> {
149192
query_is_idempotent: bool,
150193
query_consistency: Consistency,
151194
retry_session: Box<dyn RetrySession>,
152-
timeout: Option<Duration>,
195+
timeouter: Option<PageQueryTimeouter>,
153196
#[cfg(feature = "metrics")]
154197
metrics: Arc<Metrics>,
155198

@@ -180,8 +223,7 @@ where
180223
let mut current_consistency: Consistency = self.query_consistency;
181224

182225
self.log_request_start();
183-
let start_instant = tokio::time::Instant::now();
184-
let mut timeout_instant = self.timeout.map(|t| start_instant + t);
226+
self.timeouter.as_mut().map(PageQueryTimeouter::reset);
185227

186228
'nodes_in_plan: for (node, shard) in query_plan {
187229
let span = trace_span!(parent: &self.parent_span, "Executing query", node = %node.address, shard = %shard);
@@ -216,13 +258,7 @@ where
216258
Result<PageSendAttemptedProof, RequestAttemptError>,
217259
RequestTimeoutError,
218260
> = self
219-
.query_pages(
220-
&connection,
221-
current_consistency,
222-
node,
223-
coordinator.clone(),
224-
timeout_instant.as_mut(),
225-
)
261+
.query_pages(&connection, current_consistency, node, coordinator.clone())
226262
.instrument(span.clone())
227263
.await;
228264

@@ -326,7 +362,6 @@ where
326362
consistency: Consistency,
327363
node: NodeRef<'_>,
328364
coordinator: Coordinator,
329-
mut timeout_instant: Option<&mut tokio::time::Instant>,
330365
) -> Result<Result<PageSendAttemptedProof, RequestAttemptError>, RequestTimeoutError> {
331366
loop {
332367
let request_span = (self.span_creator)();
@@ -337,7 +372,6 @@ where
337372
node,
338373
coordinator.clone(),
339374
&request_span,
340-
timeout_instant.as_ref().map(|instant| **instant),
341375
)
342376
.instrument(request_span.span().clone())
343377
.await
@@ -350,11 +384,7 @@ where
350384
Ok(Ok(ControlFlow::Continue(()))) => {
351385
// Successfully queried one page, and there are more to fetch.
352386
// Reset the timeout_instant for the next page fetch.
353-
if let Some(timeout) = self.timeout
354-
&& let Some(ref mut instant) = timeout_instant
355-
{
356-
**instant = tokio::time::Instant::now() + timeout;
357-
}
387+
self.timeouter.as_mut().map(PageQueryTimeouter::reset);
358388
}
359389
Ok(Err(request_attempt_error)) => {
360390
return Ok(Err(request_attempt_error));
@@ -373,7 +403,6 @@ where
373403
node: NodeRef<'_>,
374404
coordinator: Coordinator,
375405
request_span: &RequestSpan,
376-
timeout_instant: Option<tokio::time::Instant>,
377406
) -> Result<
378407
Result<ControlFlow<PageSendAttemptedProof, ()>, RequestAttemptError>,
379408
RequestTimeoutError,
@@ -394,20 +423,19 @@ where
394423
.await
395424
.and_then(QueryResponse::into_non_error_query_response)
396425
};
397-
let query_response = match (self.timeout, timeout_instant) {
398-
(Some(timeout), Some(instant)) => {
399-
match tokio::time::timeout_at(instant, runner).await {
400-
Ok(res) => res,
401-
Err(_) /* tokio::time::error::Elapsed */ => {
402-
#[cfg(feature = "metrics")]
403-
self.metrics.inc_request_timeouts();
404-
return Err(RequestTimeoutError(timeout));
426+
let query_response = match self.timeouter {
427+
Some(ref timeouter) => {
428+
match tokio::time::timeout_at(timeouter.deadline(), runner).await {
429+
Ok(res) => res,
430+
Err(_) /* tokio::time::error::Elapsed */ => {
431+
#[cfg(feature = "metrics")]
432+
self.metrics.inc_request_timeouts();
433+
return Err(RequestTimeoutError(timeouter.timeout_duration()));
434+
}
405435
}
406436
}
407-
}
408437

409-
(None, None) => runner.await,
410-
_ => unreachable!("timeout_instant must be Some iff self.timeout is Some"),
438+
None => runner.await,
411439
};
412440

413441
let elapsed = query_start.elapsed();
@@ -802,9 +830,10 @@ If you are using this API, you are probably doing something wrong."
802830
.serial_consistency
803831
.unwrap_or(execution_profile.serial_consistency);
804832

805-
let timeout = statement
833+
let timeouter = statement
806834
.get_request_timeout()
807-
.or(execution_profile.request_timeout);
835+
.or(execution_profile.request_timeout)
836+
.map(PageQueryTimeouter::new);
808837

809838
let page_size = statement.get_validated_page_size();
810839

@@ -862,7 +891,7 @@ If you are using this API, you are probably doing something wrong."
862891
query_consistency: consistency,
863892
load_balancing_policy,
864893
retry_session,
865-
timeout,
894+
timeouter,
866895
#[cfg(feature = "metrics")]
867896
metrics,
868897
paging_state: PagingState::start(),
@@ -895,10 +924,11 @@ If you are using this API, you are probably doing something wrong."
895924
.serial_consistency
896925
.unwrap_or(config.execution_profile.serial_consistency);
897926

898-
let timeout = config
927+
let timeouter = config
899928
.prepared
900929
.get_request_timeout()
901-
.or(config.execution_profile.request_timeout);
930+
.or(config.execution_profile.request_timeout)
931+
.map(PageQueryTimeouter::new);
902932

903933
let page_size = config.prepared.get_validated_page_size();
904934

@@ -996,7 +1026,7 @@ If you are using this API, you are probably doing something wrong."
9961026
query_consistency: consistency,
9971027
load_balancing_policy,
9981028
retry_session,
999-
timeout,
1029+
timeouter,
10001030
#[cfg(feature = "metrics")]
10011031
metrics: config.metrics,
10021032
paging_state: PagingState::start(),

0 commit comments

Comments
 (0)