Skip to content

Commit d40e184

Browse files
authored
Merge pull request #1160 from muzarski/iterator-errors-refactor
errors: pager API errors refactor
2 parents 918e522 + 04c7eec commit d40e184

File tree

6 files changed

+117
-93
lines changed

6 files changed

+117
-93
lines changed

scylla/src/client/pager.rs

Lines changed: 51 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ use crate::cluster::{ClusterState, NodeRef};
2929
#[allow(deprecated)]
3030
use crate::cql_to_rust::{FromRow, FromRowError};
3131
use crate::deserialize::DeserializeOwnedRow;
32-
use crate::errors::{ProtocolError, RequestError};
33-
use crate::errors::{QueryError, RequestAttemptError};
32+
use crate::errors::{RequestAttemptError, RequestError};
3433
use crate::frame::response::result;
3534
use crate::network::Connection;
3635
use crate::observability::driver_tracing::RequestSpan;
3736
use crate::observability::history::{self, HistoryListener};
3837
use crate::observability::metrics::Metrics;
3938
use crate::policies::load_balancing::{self, RoutingInfo};
4039
use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession};
40+
use crate::prepared_statement::PartitionKeyError;
4141
use crate::response::query_result::ColumnSpecs;
4242
use crate::response::{NonErrorQueryResponse, QueryResponse};
4343
use crate::statement::{prepared_statement::PreparedStatement, query::Query};
@@ -79,9 +79,7 @@ mod checked_channel_sender {
7979
use tokio::sync::mpsc;
8080
use uuid::Uuid;
8181

82-
use crate::errors::QueryError;
83-
84-
use super::ReceivedPage;
82+
use super::{NextPageError, ReceivedPage};
8583

8684
/// A value whose existence proves that there was an attempt
8785
/// to send an item of type T through a channel.
@@ -106,7 +104,7 @@ mod checked_channel_sender {
106104
}
107105
}
108106

109-
type ResultPage = Result<ReceivedPage, QueryError>;
107+
type ResultPage = Result<ReceivedPage, NextPageError>;
110108

111109
impl ProvingSender<ResultPage> {
112110
pub(crate) async fn send_empty_page(
@@ -127,12 +125,12 @@ mod checked_channel_sender {
127125

128126
use checked_channel_sender::{ProvingSender, SendAttemptedProof};
129127

130-
type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, QueryError>>;
128+
type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, NextPageError>>;
131129

132130
// PagerWorker works in the background to fetch pages
133131
// QueryPager receives them through a channel
134132
struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> {
135-
sender: ProvingSender<Result<ReceivedPage, QueryError>>,
133+
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,
136134

137135
// Closure used to perform a single page query
138136
// AsyncFn(Arc<Connection>, Option<Arc<[u8]>>) -> Result<QueryResponse, RequestAttemptError>
@@ -267,7 +265,10 @@ where
267265
}
268266

269267
self.log_request_error(&last_error);
270-
let (proof, _) = self.sender.send(Err(last_error.into_query_error())).await;
268+
let (proof, _) = self
269+
.sender
270+
.send(Err(NextPageError::RequestFailure(last_error)))
271+
.await;
271272
proof
272273
}
273274

@@ -477,7 +478,7 @@ where
477478
/// any complicated logic related to retries, it just fetches pages from
478479
/// a single connection.
479480
struct SingleConnectionPagerWorker<Fetcher> {
480-
sender: ProvingSender<Result<ReceivedPage, QueryError>>,
481+
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,
481482
fetcher: Fetcher,
482483
}
483484

@@ -490,21 +491,22 @@ where
490491
match self.do_work().await {
491492
Ok(proof) => proof,
492493
Err(err) => {
493-
let (proof, _) = self.sender.send(Err(err)).await;
494+
let (proof, _) = self
495+
.sender
496+
.send(Err(NextPageError::RequestFailure(
497+
RequestError::LastAttemptError(err),
498+
)))
499+
.await;
494500
proof
495501
}
496502
}
497503
}
498504

499-
async fn do_work(&mut self) -> Result<PageSendAttemptedProof, QueryError> {
505+
async fn do_work(&mut self) -> Result<PageSendAttemptedProof, RequestAttemptError> {
500506
let mut paging_state = PagingState::start();
501507
loop {
502-
let result = (self.fetcher)(paging_state)
503-
.await
504-
.map_err(RequestAttemptError::into_query_error)?;
505-
let response = result
506-
.into_non_error_query_response()
507-
.map_err(RequestAttemptError::into_query_error)?;
508+
let result = (self.fetcher)(paging_state).await?;
509+
let response = result.into_non_error_query_response()?;
508510
match response.response {
509511
NonErrorResponse::Result(result::Result::Rows((rows, paging_state_response))) => {
510512
let (proof, send_result) = self
@@ -539,10 +541,9 @@ where
539541
return Ok(proof);
540542
}
541543
_ => {
542-
return Err(ProtocolError::UnexpectedResponse(
544+
return Err(RequestAttemptError::UnexpectedResponse(
543545
response.response.to_response_kind(),
544-
)
545-
.into());
546+
));
546547
}
547548
}
548549
}
@@ -565,7 +566,7 @@ where
565566
/// is not the intended target type.
566567
pub struct QueryPager {
567568
current_page: RawRowLendingIterator,
568-
page_receiver: mpsc::Receiver<Result<ReceivedPage, QueryError>>,
569+
page_receiver: mpsc::Receiver<Result<ReceivedPage, NextPageError>>,
569570
tracing_ids: Vec<Uuid>,
570571
}
571572

@@ -583,7 +584,7 @@ impl QueryPager {
583584
/// borrows from self.
584585
///
585586
/// This is cancel-safe.
586-
async fn next(&mut self) -> Option<Result<ColumnIterator, QueryError>> {
587+
async fn next(&mut self) -> Option<Result<ColumnIterator, NextRowError>> {
587588
let res = std::future::poll_fn(|cx| Pin::new(&mut *self).poll_fill_page(cx)).await;
588589
match res {
589590
Some(Ok(())) => {}
@@ -596,15 +597,15 @@ impl QueryPager {
596597
self.current_page
597598
.next()
598599
.unwrap()
599-
.map_err(|err| NextRowError::RowDeserializationError(err).into()),
600+
.map_err(NextRowError::RowDeserializationError),
600601
)
601602
}
602603

603604
/// Tries to acquire a non-empty page, if current page is exhausted.
604605
fn poll_fill_page<'r>(
605606
mut self: Pin<&'r mut Self>,
606607
cx: &mut Context<'_>,
607-
) -> Poll<Option<Result<(), QueryError>>> {
608+
) -> Poll<Option<Result<(), NextRowError>>> {
608609
if !self.is_current_page_exhausted() {
609610
return Poll::Ready(Some(Ok(())));
610611
}
@@ -627,14 +628,11 @@ impl QueryPager {
627628
fn poll_next_page<'r>(
628629
mut self: Pin<&'r mut Self>,
629630
cx: &mut Context<'_>,
630-
) -> Poll<Option<Result<(), QueryError>>> {
631+
) -> Poll<Option<Result<(), NextRowError>>> {
631632
let mut s = self.as_mut();
632633

633634
let received_page = ready_some_ok!(Pin::new(&mut s.page_receiver).poll_recv(cx));
634635

635-
// TODO: see my other comment next to QueryError::NextRowError
636-
// This is the place where conversion happens. To fix this, we need to refactor error types in iterator API.
637-
// The `page_receiver`'s error type should be narrowed from QueryError to some other error type.
638636
let raw_rows_with_deserialized_metadata =
639637
received_page.rows.deserialize_metadata().map_err(|err| {
640638
NextRowError::NextPageError(NextPageError::ResultMetadataParseError(err))
@@ -689,8 +687,8 @@ impl QueryPager {
689687
execution_profile: Arc<ExecutionProfileInner>,
690688
cluster_data: Arc<ClusterState>,
691689
metrics: Arc<Metrics>,
692-
) -> Result<Self, QueryError> {
693-
let (sender, receiver) = mpsc::channel(1);
690+
) -> Result<Self, NextRowError> {
691+
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
694692

695693
let consistency = query
696694
.config
@@ -768,8 +766,8 @@ impl QueryPager {
768766

769767
pub(crate) async fn new_for_prepared_statement(
770768
config: PreparedIteratorConfig,
771-
) -> Result<Self, QueryError> {
772-
let (sender, receiver) = mpsc::channel(1);
769+
) -> Result<Self, NextRowError> {
770+
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
773771

774772
let consistency = config
775773
.prepared
@@ -803,7 +801,9 @@ impl QueryPager {
803801
) {
804802
Ok(res) => res.unzip(),
805803
Err(err) => {
806-
let (proof, _res) = ProvingSender::from(sender).send(Err(err)).await;
804+
let (proof, _res) = ProvingSender::from(sender)
805+
.send(Err(NextPageError::PartitionKeyError(err)))
806+
.await;
807807
return proof;
808808
}
809809
};
@@ -889,8 +889,8 @@ impl QueryPager {
889889
connection: Arc<Connection>,
890890
consistency: Consistency,
891891
serial_consistency: Option<SerialConsistency>,
892-
) -> Result<Self, QueryError> {
893-
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);
892+
) -> Result<Self, NextRowError> {
893+
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
894894

895895
let page_size = query.get_validated_page_size();
896896

@@ -919,8 +919,8 @@ impl QueryPager {
919919
connection: Arc<Connection>,
920920
consistency: Consistency,
921921
serial_consistency: Option<SerialConsistency>,
922-
) -> Result<Self, QueryError> {
923-
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);
922+
) -> Result<Self, NextRowError> {
923+
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
924924

925925
let page_size = prepared.get_validated_page_size();
926926

@@ -946,8 +946,8 @@ impl QueryPager {
946946

947947
async fn new_from_worker_future(
948948
worker_task: impl Future<Output = PageSendAttemptedProof> + Send + 'static,
949-
mut receiver: mpsc::Receiver<Result<ReceivedPage, QueryError>>,
950-
) -> Result<Self, QueryError> {
949+
mut receiver: mpsc::Receiver<Result<ReceivedPage, NextPageError>>,
950+
) -> Result<Self, NextRowError> {
951951
tokio::task::spawn(worker_task);
952952

953953
// This unwrap is safe because:
@@ -1035,14 +1035,14 @@ impl<RowT> Stream for TypedRowStream<RowT>
10351035
where
10361036
RowT: DeserializeOwnedRow,
10371037
{
1038-
type Item = Result<RowT, QueryError>;
1038+
type Item = Result<RowT, NextRowError>;
10391039

10401040
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
10411041
let next_fut = async {
10421042
self.raw_row_lending_stream.next().await.map(|res| {
10431043
res.and_then(|column_iterator| {
10441044
<RowT as DeserializeRow>::deserialize(column_iterator)
1045-
.map_err(|err| NextRowError::RowDeserializationError(err).into())
1045+
.map_err(NextRowError::RowDeserializationError)
10461046
})
10471047
})
10481048
};
@@ -1057,12 +1057,17 @@ where
10571057
#[derive(Error, Debug, Clone)]
10581058
#[non_exhaustive]
10591059
pub enum NextPageError {
1060+
/// PK extraction and/or token calculation error. Applies only for prepared statements.
1061+
#[error("Failed to extract PK and compute token required for routing: {0}")]
1062+
PartitionKeyError(#[from] PartitionKeyError),
1063+
1064+
/// Failed to run a request responsible for fetching new page.
1065+
#[error(transparent)]
1066+
RequestFailure(#[from] RequestError),
1067+
10601068
/// Failed to deserialize result metadata associated with next page response.
10611069
#[error("Failed to deserialize result metadata associated with next page response: {0}")]
10621070
ResultMetadataParseError(#[from] ResultMetadataAndRowsCountParseError),
1063-
// TODO: This should also include a variant representing an error that occurred during
1064-
// query that fetches the next page. However, as of now, it would require that we include QueryError here.
1065-
// This would introduce a cyclic dependency: QueryError -> NextRowError -> NextPageError -> QueryError.
10661071
}
10671072

10681073
/// An error returned by async iterator API.
@@ -1172,7 +1177,7 @@ mod legacy {
11721177
pub enum LegacyNextRowError {
11731178
/// Query to fetch next page has failed
11741179
#[error(transparent)]
1175-
QueryError(#[from] QueryError),
1180+
NextRowError(#[from] NextRowError),
11761181

11771182
/// Parsing values in row as given types failed
11781183
#[error(transparent)]

scylla/src/client/session.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use crate::policies::host_filter::HostFilter;
3131
use crate::policies::load_balancing::{self, RoutingInfo};
3232
use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession};
3333
use crate::policies::speculative_execution;
34-
use crate::prepared_statement::PreparedStatement;
34+
use crate::prepared_statement::{PartitionKeyError, PreparedStatement};
3535
use crate::query::Query;
3636
#[allow(deprecated)]
3737
use crate::response::legacy_query_result::LegacyQueryResult;
@@ -1235,6 +1235,7 @@ where
12351235
self.metrics.clone(),
12361236
)
12371237
.await
1238+
.map_err(QueryError::from)
12381239
} else {
12391240
// Making QueryPager::new_for_query work with values is too hard (if even possible)
12401241
// so instead of sending one prepare to a specific connection on each iterator query,
@@ -1249,6 +1250,7 @@ where
12491250
metrics: self.metrics.clone(),
12501251
})
12511252
.await
1253+
.map_err(QueryError::from)
12521254
}
12531255
}
12541256

@@ -1394,7 +1396,8 @@ where
13941396
let paging_state_ref = &paging_state;
13951397

13961398
let (partition_key, token) = prepared
1397-
.extract_partition_key_and_calculate_token(prepared.get_partitioner_name(), values_ref)?
1399+
.extract_partition_key_and_calculate_token(prepared.get_partitioner_name(), values_ref)
1400+
.map_err(PartitionKeyError::into_query_error)?
13981401
.unzip();
13991402

14001403
let execution_profile = prepared
@@ -1503,6 +1506,7 @@ where
15031506
metrics: self.metrics.clone(),
15041507
})
15051508
.await
1509+
.map_err(QueryError::from)
15061510
}
15071511

15081512
async fn do_batch(

0 commit comments

Comments
 (0)