diff --git a/scylla-proxy/src/proxy.rs b/scylla-proxy/src/proxy.rs index a43734ab7d..e5a8f5b9e9 100644 --- a/scylla-proxy/src/proxy.rs +++ b/scylla-proxy/src/proxy.rs @@ -1361,7 +1361,7 @@ impl ProxyWorker { event_registered_flag: Arc, ) { let shard = self.shard; - self.run_until_interrupted("request_processor", |driver_addr, _, real_addr| async move { + self.run_until_interrupted("response_processor", |driver_addr, _, real_addr| async move { 'mainloop: loop { match responses_rx.recv().await { Some(response) => { @@ -1374,7 +1374,7 @@ impl ProxyWorker { let mut guard = response_rules.lock().unwrap(); '_ruleloop: for (i, response_rule) in guard.iter_mut().enumerate() { if response_rule.0.eval(&ctx) { - info!("Applying rule no={} to request ({} -> {} ({})).", i, DisplayableRealAddrOption(real_addr), driver_addr, DisplayableShard(shard)); + info!("Applying rule no={} to response ({} -> {} ({})).", i, DisplayableRealAddrOption(real_addr), driver_addr, DisplayableShard(shard)); debug!("-> Applied rule: {:?}", response_rule); debug!("-> To response: {:?}", ctx.opcode); trace!("{:?}", response); diff --git a/scylla/src/client/pager.rs b/scylla/src/client/pager.rs index 4d8588a647..2b8945111b 100644 --- a/scylla/src/client/pager.rs +++ b/scylla/src/client/pager.rs @@ -8,6 +8,7 @@ use std::ops::ControlFlow; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use std::time::Duration; use futures::Stream; use scylla_cql::Consistency; @@ -134,6 +135,50 @@ use crate::response::Coordinator; type PageSendAttemptedProof = SendAttemptedProof>; +mod timeouter { + use std::time::Duration; + + use tokio::time::Instant; + + /// Encapsulation of a timeout for paging queries. + pub(super) struct PageQueryTimeouter { + timeout: Duration, + timeout_instant: Instant, + } + + impl PageQueryTimeouter { + /// Creates a new PageQueryTimeouter with the given timeout duration, + /// starting from now. + pub(super) fn new(timeout: Duration) -> Self { + Self { + timeout, + timeout_instant: Instant::now() + timeout, + } + } + + /// Returns the timeout duration. + pub(super) fn timeout_duration(&self) -> Duration { + self.timeout + } + + /// Returns the instant at which the timeout will elapse. + /// + /// This can be used with `tokio::time::timeout_at`. + pub(super) fn deadline(&self) -> Instant { + self.timeout_instant + } + + /// Resets the timeout countdown. + /// + /// This should be called right before beginning first page fetch + /// and after each successful page fetch. + pub(super) fn reset(&mut self) { + self.timeout_instant = Instant::now() + self.timeout; + } + } +} +use timeouter::PageQueryTimeouter; + // PagerWorker works in the background to fetch pages // QueryPager receives them through a channel struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> { @@ -148,6 +193,7 @@ struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> { query_is_idempotent: bool, query_consistency: Consistency, retry_session: Box, + timeouter: Option, #[cfg(feature = "metrics")] metrics: Arc, @@ -178,6 +224,7 @@ where let mut current_consistency: Consistency = self.query_consistency; self.log_request_start(); + self.timeouter.as_mut().map(PageQueryTimeouter::reset); 'nodes_in_plan: for (node, shard) in query_plan { let span = trace_span!(parent: &self.parent_span, "Executing query", node = %node.address, shard = %shard); @@ -208,20 +255,23 @@ where Coordinator::new(node, node.sharder().is_some().then_some(shard), &connection); // Query pages until an error occurs - let queries_result: Result = self + let queries_result: Result< + Result, + RequestTimeoutError, + > = self .query_pages(&connection, current_consistency, node, coordinator.clone()) .instrument(span.clone()) .await; let request_error: RequestAttemptError = match queries_result { - Ok(proof) => { + Ok(Ok(proof)) => { trace!(parent: &span, "Request succeeded"); // query_pages returned Ok, so we are guaranteed // that it attempted to send at least one page // through self.sender and we can safely return now. return proof; } - Err(error) => { + Ok(Err(error)) => { trace!( parent: &span, error = %error, @@ -229,6 +279,20 @@ where ); error } + Err(RequestTimeoutError(timeout)) => { + let request_error = RequestError::RequestTimeout(timeout); + self.log_request_error(&request_error); + trace!( + parent: &span, + error = %request_error, + "Request timed out" + ); + let (proof, _) = self + .sender + .send(Err(NextPageError::RequestFailure(request_error))) + .await; + return proof; + } }; // Use retry policy to decide what to do next @@ -269,7 +333,7 @@ where // Although we are in an awkward situation (_iter // interface isn't meant for sending writes), // we must attempt to send something because - // the iterator expects it. + // QueryPager expects it. let (proof, _) = self .sender .send_empty_page(None, Some(coordinator.clone())) @@ -299,7 +363,7 @@ where consistency: Consistency, node: NodeRef<'_>, coordinator: Coordinator, - ) -> Result { + ) -> Result, RequestTimeoutError> { loop { let request_span = (self.span_creator)(); match self @@ -311,10 +375,24 @@ where &request_span, ) .instrument(request_span.span().clone()) - .await? + .await { - ControlFlow::Break(proof) => return Ok(proof), - ControlFlow::Continue(_) => {} + Ok(Ok(ControlFlow::Break(proof))) => { + // Successfully queried the last remaining page. + return Ok(Ok(proof)); + } + + Ok(Ok(ControlFlow::Continue(()))) => { + // Successfully queried one page, and there are more to fetch. + // Reset the timeout_instant for the next page fetch. + self.timeouter.as_mut().map(PageQueryTimeouter::reset); + } + Ok(Err(request_attempt_error)) => { + return Ok(Err(request_attempt_error)); + } + Err(request_timeout_error) => { + return Err(request_timeout_error); + } } } } @@ -326,7 +404,10 @@ where node: NodeRef<'_>, coordinator: Coordinator, request_span: &RequestSpan, - ) -> Result, RequestAttemptError> { + ) -> Result< + Result, RequestAttemptError>, + RequestTimeoutError, + > { #[cfg(feature = "metrics")] self.metrics.inc_total_paged_queries(); let query_start = std::time::Instant::now(); @@ -338,10 +419,25 @@ where ); self.log_attempt_start(connect_address); - let query_response = + let runner = async { (self.page_query)(connection.clone(), consistency, self.paging_state.clone()) .await - .and_then(QueryResponse::into_non_error_query_response); + .and_then(QueryResponse::into_non_error_query_response) + }; + let query_response = match self.timeouter { + Some(ref timeouter) => { + match tokio::time::timeout_at(timeouter.deadline(), runner).await { + Ok(res) => res, + Err(_) /* tokio::time::error::Elapsed */ => { + #[cfg(feature = "metrics")] + self.metrics.inc_request_timeouts(); + return Err(RequestTimeoutError(timeouter.timeout_duration())); + } + } + } + + None => runner.await, + }; let elapsed = query_start.elapsed(); @@ -373,7 +469,7 @@ where let (proof, res) = self.sender.send(Ok(received_page)).await; if res.is_err() { // channel was closed, QueryPager was dropped - should shutdown - return Ok(ControlFlow::Break(proof)); + return Ok(Ok(ControlFlow::Break(proof))); } match paging_state_response.into_paging_control_flow() { @@ -382,7 +478,7 @@ where } ControlFlow::Break(()) => { // Reached the last query, shutdown - return Ok(ControlFlow::Break(proof)); + return Ok(Ok(ControlFlow::Break(proof))); } } @@ -390,7 +486,7 @@ where self.retry_session.reset(); self.log_request_start(); - Ok(ControlFlow::Continue(())) + Ok(Ok(ControlFlow::Continue(()))) } Err(err) => { #[cfg(feature = "metrics")] @@ -401,7 +497,7 @@ where node, &err, ); - Err(err) + Ok(Err(err)) } Ok(NonErrorQueryResponse { response: NonErrorResponse::Result(_), @@ -416,7 +512,7 @@ where .sender .send_empty_page(tracing_id, Some(coordinator)) .await; - Ok(ControlFlow::Break(proof)) + Ok(Ok(ControlFlow::Break(proof))) } Ok(response) => { #[cfg(feature = "metrics")] @@ -429,7 +525,7 @@ where node, &err, ); - Err(err) + Ok(Err(err)) } } } @@ -521,6 +617,7 @@ where struct SingleConnectionPagerWorker { sender: ProvingSender>, fetcher: Fetcher, + timeout: Option, } impl SingleConnectionPagerWorker @@ -530,8 +627,8 @@ where { async fn work(mut self) -> PageSendAttemptedProof { match self.do_work().await { - Ok(proof) => proof, - Err(err) => { + Ok(Ok(proof)) => proof, + Ok(Err(err)) => { let (proof, _) = self .sender .send(Err(NextPageError::RequestFailure( @@ -540,14 +637,47 @@ where .await; proof } + Err(RequestTimeoutError(timeout)) => { + let (proof, _) = self + .sender + .send(Err(NextPageError::RequestFailure( + RequestError::RequestTimeout(timeout), + ))) + .await; + proof + } } } - async fn do_work(&mut self) -> Result { + async fn do_work( + &mut self, + ) -> Result, RequestTimeoutError> { let mut paging_state = PagingState::start(); loop { - let result = (self.fetcher)(paging_state).await?; - let response = result.into_non_error_query_response()?; + let runner = async { + (self.fetcher)(paging_state) + .await + .and_then(QueryResponse::into_non_error_query_response) + }; + let response_res = match self.timeout { + Some(timeout) => { + match tokio::time::timeout(timeout, runner).await { + Ok(res) => res, + Err(_) /* tokio::time::error::Elapsed */ => { + return Err(RequestTimeoutError(timeout)); + } + } + } + + None => runner.await, + }; + let response = match response_res { + Ok(resp) => resp, + Err(err) => { + return Ok(Err(err)); + } + }; + match response.response { NonErrorResponse::Result(result::Result::Rows((rows, paging_state_response))) => { let (proof, send_result) = self @@ -561,7 +691,7 @@ where if send_result.is_err() { // channel was closed, QueryPager was dropped - should shutdown - return Ok(proof); + return Ok(Ok(proof)); } match paging_state_response.into_paging_control_flow() { @@ -570,7 +700,7 @@ where } ControlFlow::Break(()) => { // Reached the last query, shutdown - return Ok(proof); + return Ok(Ok(proof)); } } } @@ -580,12 +710,12 @@ where // We must attempt to send something because the iterator expects it. let (proof, _) = self.sender.send_empty_page(response.tracing_id, None).await; - return Ok(proof); + return Ok(Ok(proof)); } _ => { - return Err(RequestAttemptError::UnexpectedResponse( + return Ok(Err(RequestAttemptError::UnexpectedResponse( response.response.to_response_kind(), - )); + ))); } } } @@ -735,6 +865,11 @@ If you are using this API, you are probably doing something wrong." .serial_consistency .unwrap_or(execution_profile.serial_consistency); + let timeouter = statement + .get_request_timeout() + .or(execution_profile.request_timeout) + .map(PageQueryTimeouter::new); + let page_size = statement.get_validated_page_size(); let routing_info = RoutingInfo { @@ -791,6 +926,7 @@ If you are using this API, you are probably doing something wrong." query_consistency: consistency, load_balancing_policy, retry_session, + timeouter, #[cfg(feature = "metrics")] metrics, paging_state: PagingState::start(), @@ -823,6 +959,12 @@ If you are using this API, you are probably doing something wrong." .serial_consistency .unwrap_or(config.execution_profile.serial_consistency); + let timeouter = config + .prepared + .get_request_timeout() + .or(config.execution_profile.request_timeout) + .map(PageQueryTimeouter::new); + let page_size = config.prepared.get_validated_page_size(); let load_balancing_policy = Arc::clone( @@ -919,6 +1061,7 @@ If you are using this API, you are probably doing something wrong." query_consistency: consistency, load_balancing_policy, retry_session, + timeouter, #[cfg(feature = "metrics")] metrics: config.metrics, paging_state: PagingState::start(), @@ -944,6 +1087,12 @@ If you are using this API, you are probably doing something wrong." let (sender, receiver) = mpsc::channel::>(1); let page_size = query.get_validated_page_size(); + let timeout = query.get_request_timeout().or_else(|| { + query + .get_execution_profile_handle()? + .access() + .request_timeout + }); let worker_task = async move { let worker = SingleConnectionPagerWorker { @@ -957,6 +1106,7 @@ If you are using this API, you are probably doing something wrong." paging_state, ) }, + timeout, }; worker.work().await }; @@ -974,6 +1124,12 @@ If you are using this API, you are probably doing something wrong." let (sender, receiver) = mpsc::channel::>(1); let page_size = prepared.get_validated_page_size(); + let timeout = prepared.get_request_timeout().or_else(|| { + prepared + .get_execution_profile_handle()? + .access() + .request_timeout + }); let worker_task = async move { let worker = SingleConnectionPagerWorker { @@ -988,6 +1144,7 @@ If you are using this API, you are probably doing something wrong." paging_state, ) }, + timeout, }; worker.work().await }; @@ -1138,6 +1295,14 @@ where } } +/// Failed to run a request within a provided client timeout. +#[derive(Error, Debug, Clone)] +#[error( + "Request execution exceeded a client timeout of {}ms", + std::time::Duration::as_millis(.0) +)] +struct RequestTimeoutError(std::time::Duration); + /// An error returned that occurred during next page fetch. #[derive(Error, Debug, Clone)] #[non_exhaustive] diff --git a/scylla/tests/integration/session/pager.rs b/scylla/tests/integration/session/pager.rs index 5368b1d1c9..a04e372c9b 100644 --- a/scylla/tests/integration/session/pager.rs +++ b/scylla/tests/integration/session/pager.rs @@ -1,22 +1,33 @@ -use std::sync::{ - Arc, - atomic::{AtomicBool, Ordering}, +use std::{ + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, + time::Duration, }; use assert_matches::assert_matches; use futures::{StreamExt as _, TryStreamExt as _}; -use scylla::errors::{NextPageError, NextRowError}; use scylla::{ client::execution_profile::ExecutionProfile, policies::retry::{RequestInfo, RetryDecision, RetryPolicy, RetrySession}, statement::Statement, value::Row, }; +use scylla::{ + client::{session::Session, session_builder::SessionBuilder}, + errors::{NextPageError, NextRowError, PagerExecutionError, RequestError}, +}; use scylla_cql::Consistency; +use scylla_proxy::{ + Condition, ProxyError, Reaction as _, RequestOpcode, RequestReaction, RequestRule, WorkerError, + example_db_errors, +}; +use tracing::info; use crate::utils::{ PerformDDL as _, create_new_session_builder, scylla_supports_tablets, setup_tracing, - unique_keyspace_name, + test_with_3_node_cluster, unique_keyspace_name, }; // Reproduces the problem with execute_iter mentioned in #608. @@ -220,3 +231,179 @@ async fn test_iter_methods_when_altering_table() { session.ddl(format!("DROP KEYSPACE {ks}")).await.unwrap(); } + +#[tokio::test] +#[cfg_attr(scylla_cloud_tests, ignore)] +async fn test_pager_timeouts() { + setup_tracing(); + + let res = test_with_3_node_cluster( + scylla_proxy::ShardAwareness::QueryNode, + |proxy_uris, translation_map, mut running_proxy| async move { + /* Prepare phase */ + let ks = unique_keyspace_name(); + + let session: Session = SessionBuilder::new() + .known_node(proxy_uris[0].as_str()) + .address_translator(Arc::new(translation_map)) + .build() + .await + .unwrap(); + + session + .ddl(format!( + "CREATE KEYSPACE IF NOT EXISTS {ks} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}" + )) + .await + .unwrap(); + session.use_keyspace(ks.clone(), true).await.unwrap(); + + session + .ddl("CREATE TABLE IF NOT EXISTS t (a int PRIMARY KEY)") + .await + .unwrap(); + + for i in 0..5 { + session + .query_unpaged("INSERT INTO t (a) VALUES (?)", (i,)) + .await + .unwrap(); + } + + let mut prepared = session.prepare("SELECT a FROM t").await.unwrap(); + // Important to have multiple pages. + prepared.set_page_size(1); + // Important for retries to fire. + prepared.set_is_idempotent(true); + + /* Test phase */ + + // Case 1: the first page fetch times out. + { + let timeout = Duration::from_secs(1); + prepared.set_request_timeout(Some(timeout)); + + running_proxy.running_nodes.iter_mut().for_each(|node| { + node.change_request_rules(Some(vec![ + RequestRule( + Condition::RequestOpcode(RequestOpcode::Execute) + .and(Condition::not(Condition::ConnectionRegisteredAnyEvent)), + RequestReaction::delay(timeout + Duration::from_secs(1)) + ) + ])); + }); + + let pager_err = session.execute_iter(prepared.clone(), ()).await.unwrap_err(); + let PagerExecutionError::NextPageError(NextPageError::RequestFailure( + RequestError::RequestTimeout(got_timeout), + )) = pager_err + else { + panic!("Expected RequestTimeout error, got: {:?}", pager_err); + }; + assert_eq!(got_timeout, timeout); + info!("Case 1 passed."); + } + + // Case 2: the second page fetch times out. + { + let timeout = Duration::from_secs(1); + prepared.set_request_timeout(Some(timeout)); + + running_proxy.running_nodes.iter_mut().for_each(|node| { + node.change_request_rules(Some(vec![ + // Pass one frame, then delay all subsequent ones. + RequestRule( + Condition::RequestOpcode(RequestOpcode::Execute) + .and(Condition::not(Condition::ConnectionRegisteredAnyEvent)) + .and(Condition::TrueForLimitedTimes(1)), + RequestReaction::noop() + ), + RequestRule( + Condition::RequestOpcode(RequestOpcode::Execute) + .and(Condition::not(Condition::ConnectionRegisteredAnyEvent)), + RequestReaction::delay(timeout + Duration::from_secs(1)) + ) + ])); + }); + + let mut row_stream = session + .execute_iter(prepared.clone(), ()) + .await + .unwrap() + .rows_stream::<(i32,)>() + .unwrap(); + + // Observation that is not critical to the test, but good to note: + // at this point, at most two pages have been fetched: + // - the first page, fetched eagerly by execute_iter; + // - possibly the second page, fetched lazily by rows_stream; + // - no more pages may have been fetched yet, because the second page would be + // stuck on channel.send(), waiting for us to consume the first row. + + // First page (1 row) must have been fetched successfully. + let (_a,) = row_stream.next().await.unwrap().unwrap(); + + // The second page fetch must time out. + let row_err = row_stream.next().await.unwrap().unwrap_err(); + let NextRowError::NextPageError(NextPageError::RequestFailure( + RequestError::RequestTimeout(got_timeout), + )) = row_err + else { + panic!("Expected RequestTimeout error, got: {:?}", row_err); + }; + assert_eq!(got_timeout, timeout); + info!("Case 2 passed."); + } + + // Case 3: retries' cumulative duration exceed the timeout. + { + // Here, each retry will be delayed by 200ms. + // With a 500ms timeout, this means that after 3 retries (600ms total delay), + // the timeout will be exceeded. + let per_retry_delay = Duration::from_millis(200); + let timeout = Duration::from_millis(500); + + // Set timeout through the execution profile. + { + let profile = ExecutionProfile::builder().request_timeout(Some(timeout)).build(); + let handle = profile.into_handle(); + prepared.set_execution_profile_handle(Some(handle)); + prepared.set_request_timeout(None); + } + + running_proxy.running_nodes.iter_mut().for_each(|node| { + node.change_request_rules(Some(vec![ + RequestRule( + Condition::RequestOpcode(RequestOpcode::Execute) + .and(Condition::not(Condition::ConnectionRegisteredAnyEvent)), + RequestReaction::forge_with_error_lazy_delay( + Box::new(example_db_errors::overloaded), + Some(per_retry_delay)) + ) + ])); + }); + + let pager_err = session.execute_iter(prepared, ()).await.unwrap_err(); + let PagerExecutionError::NextPageError(NextPageError::RequestFailure( + RequestError::RequestTimeout(got_timeout), + )) = pager_err + else { + panic!("Expected RequestTimeout error, got: {:?}", pager_err); + }; + assert_eq!(got_timeout, timeout); + info!("Case 3 passed."); + } + + /* Teardown */ + session.ddl(format!("DROP KEYSPACE {ks}")).await.unwrap(); + + running_proxy + }, + ).await; + + match res { + Ok(()) => (), + Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (), + Err(err) => panic!("{}", err), + } +}