Skip to content

Commit a140f7f

Browse files
committed
iterator: narrow error type of internal items
Narrowed the error types in multiple places in internal API of iterator module. Now the error type we manipulate on mainly is `NextPageError` (instead of `QueryError`). I did not change the return type of public methods yet. I want to do it in a separate commit.
1 parent b7e8ec5 commit a140f7f

File tree

4 files changed

+49
-35
lines changed

4 files changed

+49
-35
lines changed

scylla/src/client/pager.rs

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ use crate::observability::history::{self, HistoryListener};
3737
use crate::observability::metrics::Metrics;
3838
use crate::policies::load_balancing::{self, RoutingInfo};
3939
use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession};
40+
use crate::prepared_statement::PartitionKeyError;
4041
use crate::response::query_result::ColumnSpecs;
4142
use crate::response::{NonErrorQueryResponse, QueryResponse};
4243
use crate::statement::{prepared_statement::PreparedStatement, query::Query};
@@ -78,9 +79,7 @@ mod checked_channel_sender {
7879
use tokio::sync::mpsc;
7980
use uuid::Uuid;
8081

81-
use crate::errors::QueryError;
82-
83-
use super::ReceivedPage;
82+
use super::{NextPageError, ReceivedPage};
8483

8584
/// A value whose existence proves that there was an attempt
8685
/// to send an item of type T through a channel.
@@ -105,7 +104,7 @@ mod checked_channel_sender {
105104
}
106105
}
107106

108-
type ResultPage = Result<ReceivedPage, QueryError>;
107+
type ResultPage = Result<ReceivedPage, NextPageError>;
109108

110109
impl ProvingSender<ResultPage> {
111110
pub(crate) async fn send_empty_page(
@@ -126,12 +125,12 @@ mod checked_channel_sender {
126125

127126
use checked_channel_sender::{ProvingSender, SendAttemptedProof};
128127

129-
type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, QueryError>>;
128+
type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, NextPageError>>;
130129

131130
// PagerWorker works in the background to fetch pages
132131
// QueryPager receives them through a channel
133132
struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> {
134-
sender: ProvingSender<Result<ReceivedPage, QueryError>>,
133+
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,
135134

136135
// Closure used to perform a single page query
137136
// AsyncFn(Arc<Connection>, Option<Arc<[u8]>>) -> Result<QueryResponse, RequestAttemptError>
@@ -266,7 +265,10 @@ where
266265
}
267266

268267
self.log_request_error(&last_error);
269-
let (proof, _) = self.sender.send(Err(last_error.into_query_error())).await;
268+
let (proof, _) = self
269+
.sender
270+
.send(Err(NextPageError::RequestFailure(last_error)))
271+
.await;
270272
proof
271273
}
272274

@@ -476,7 +478,7 @@ where
476478
/// any complicated logic related to retries, it just fetches pages from
477479
/// a single connection.
478480
struct SingleConnectionPagerWorker<Fetcher> {
479-
sender: ProvingSender<Result<ReceivedPage, QueryError>>,
481+
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,
480482
fetcher: Fetcher,
481483
}
482484

@@ -489,7 +491,12 @@ where
489491
match self.do_work().await {
490492
Ok(proof) => proof,
491493
Err(err) => {
492-
let (proof, _) = self.sender.send(Err(err.into_query_error())).await;
494+
let (proof, _) = self
495+
.sender
496+
.send(Err(NextPageError::RequestFailure(
497+
RequestError::LastAttemptError(err),
498+
)))
499+
.await;
493500
proof
494501
}
495502
}
@@ -559,7 +566,7 @@ where
559566
/// is not the intended target type.
560567
pub struct QueryPager {
561568
current_page: RawRowLendingIterator,
562-
page_receiver: mpsc::Receiver<Result<ReceivedPage, QueryError>>,
569+
page_receiver: mpsc::Receiver<Result<ReceivedPage, NextPageError>>,
563570
tracing_ids: Vec<Uuid>,
564571
}
565572

@@ -581,7 +588,7 @@ impl QueryPager {
581588
let res = std::future::poll_fn(|cx| Pin::new(&mut *self).poll_fill_page(cx)).await;
582589
match res {
583590
Some(Ok(())) => {}
584-
Some(Err(err)) => return Some(Err(err)),
591+
Some(Err(err)) => return Some(Err(err.into())),
585592
None => return None,
586593
}
587594

@@ -598,7 +605,7 @@ impl QueryPager {
598605
fn poll_fill_page<'r>(
599606
mut self: Pin<&'r mut Self>,
600607
cx: &mut Context<'_>,
601-
) -> Poll<Option<Result<(), QueryError>>> {
608+
) -> Poll<Option<Result<(), NextRowError>>> {
602609
if !self.is_current_page_exhausted() {
603610
return Poll::Ready(Some(Ok(())));
604611
}
@@ -621,14 +628,11 @@ impl QueryPager {
621628
fn poll_next_page<'r>(
622629
mut self: Pin<&'r mut Self>,
623630
cx: &mut Context<'_>,
624-
) -> Poll<Option<Result<(), QueryError>>> {
631+
) -> Poll<Option<Result<(), NextRowError>>> {
625632
let mut s = self.as_mut();
626633

627634
let received_page = ready_some_ok!(Pin::new(&mut s.page_receiver).poll_recv(cx));
628635

629-
// TODO: see my other comment next to QueryError::NextRowError
630-
// This is the place where conversion happens. To fix this, we need to refactor error types in iterator API.
631-
// The `page_receiver`'s error type should be narrowed from QueryError to some other error type.
632636
let raw_rows_with_deserialized_metadata =
633637
received_page.rows.deserialize_metadata().map_err(|err| {
634638
NextRowError::NextPageError(NextPageError::ResultMetadataParseError(err))
@@ -683,8 +687,8 @@ impl QueryPager {
683687
execution_profile: Arc<ExecutionProfileInner>,
684688
cluster_data: Arc<ClusterState>,
685689
metrics: Arc<Metrics>,
686-
) -> Result<Self, QueryError> {
687-
let (sender, receiver) = mpsc::channel(1);
690+
) -> Result<Self, NextRowError> {
691+
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
688692

689693
let consistency = query
690694
.config
@@ -762,8 +766,8 @@ impl QueryPager {
762766

763767
pub(crate) async fn new_for_prepared_statement(
764768
config: PreparedIteratorConfig,
765-
) -> Result<Self, QueryError> {
766-
let (sender, receiver) = mpsc::channel(1);
769+
) -> Result<Self, NextRowError> {
770+
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
767771

768772
let consistency = config
769773
.prepared
@@ -798,7 +802,7 @@ impl QueryPager {
798802
Ok(res) => res.unzip(),
799803
Err(err) => {
800804
let (proof, _res) = ProvingSender::from(sender)
801-
.send(Err(err.into_query_error()))
805+
.send(Err(NextPageError::PartitionKeyError(err)))
802806
.await;
803807
return proof;
804808
}
@@ -885,8 +889,8 @@ impl QueryPager {
885889
connection: Arc<Connection>,
886890
consistency: Consistency,
887891
serial_consistency: Option<SerialConsistency>,
888-
) -> Result<Self, QueryError> {
889-
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);
892+
) -> Result<Self, NextRowError> {
893+
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
890894

891895
let page_size = query.get_validated_page_size();
892896

@@ -915,8 +919,8 @@ impl QueryPager {
915919
connection: Arc<Connection>,
916920
consistency: Consistency,
917921
serial_consistency: Option<SerialConsistency>,
918-
) -> Result<Self, QueryError> {
919-
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);
922+
) -> Result<Self, NextRowError> {
923+
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);
920924

921925
let page_size = prepared.get_validated_page_size();
922926

@@ -942,8 +946,8 @@ impl QueryPager {
942946

943947
async fn new_from_worker_future(
944948
worker_task: impl Future<Output = PageSendAttemptedProof> + Send + 'static,
945-
mut receiver: mpsc::Receiver<Result<ReceivedPage, QueryError>>,
946-
) -> Result<Self, QueryError> {
949+
mut receiver: mpsc::Receiver<Result<ReceivedPage, NextPageError>>,
950+
) -> Result<Self, NextRowError> {
947951
tokio::task::spawn(worker_task);
948952

949953
// This unwrap is safe because:
@@ -1053,12 +1057,17 @@ where
10531057
#[derive(Error, Debug, Clone)]
10541058
#[non_exhaustive]
10551059
pub enum NextPageError {
1060+
/// PK extraction and/or token calculation error. Applies only for prepared statements.
1061+
#[error("Failed to extract PK and compute token required for routing: {0}")]
1062+
PartitionKeyError(#[from] PartitionKeyError),
1063+
1064+
/// Failed to run a request responsible for fetching new page.
1065+
#[error(transparent)]
1066+
RequestFailure(#[from] RequestError),
1067+
10561068
/// Failed to deserialize result metadata associated with next page response.
10571069
#[error("Failed to deserialize result metadata associated with next page response: {0}")]
10581070
ResultMetadataParseError(#[from] ResultMetadataAndRowsCountParseError),
1059-
// TODO: This should also include a variant representing an error that occurred during
1060-
// query that fetches the next page. However, as of now, it would require that we include QueryError here.
1061-
// This would introduce a cyclic dependency: QueryError -> NextRowError -> NextPageError -> QueryError.
10621071
}
10631072

10641073
/// An error returned by async iterator API.

scylla/src/client/session.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,7 @@ where
12351235
self.metrics.clone(),
12361236
)
12371237
.await
1238+
.map_err(QueryError::from)
12381239
} else {
12391240
// Making QueryPager::new_for_query work with values is too hard (if even possible)
12401241
// so instead of sending one prepare to a specific connection on each iterator query,
@@ -1249,6 +1250,7 @@ where
12491250
metrics: self.metrics.clone(),
12501251
})
12511252
.await
1253+
.map_err(QueryError::from)
12521254
}
12531255
}
12541256

@@ -1504,6 +1506,7 @@ where
15041506
metrics: self.metrics.clone(),
15051507
})
15061508
.await
1509+
.map_err(QueryError::from)
15071510
}
15081511

15091512
async fn do_batch(

scylla/src/cluster/metadata.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,7 @@ where
982982
let mut query = Query::new(query_str);
983983
query.set_page_size(METADATA_QUERY_PAGE_SIZE);
984984

985-
conn.query_iter(query).await
985+
conn.query_iter(query).await.map_err(QueryError::from)
986986
} else {
987987
let keyspaces = &[keyspaces_to_fetch] as &[&[String]];
988988
let query_str = format!("{query_str} where keyspace_name in ?");
@@ -995,7 +995,9 @@ where
995995
.await
996996
.map_err(RequestAttemptError::into_query_error)?;
997997
let serialized_values = prepared.serialize_values(&keyspaces)?;
998-
conn.execute_iter(prepared, serialized_values).await
998+
conn.execute_iter(prepared, serialized_values)
999+
.await
1000+
.map_err(QueryError::from)
9991001
}
10001002
}
10011003

scylla/src/network/connection.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::authentication::AuthenticatorProvider;
22
use crate::batch::{Batch, BatchStatement};
3-
use crate::client::pager::QueryPager;
3+
use crate::client::pager::{NextRowError, QueryPager};
44
use crate::client::Compression;
55
use crate::client::SelfIdentity;
66
#[cfg(feature = "cloud")]
@@ -988,7 +988,7 @@ impl Connection {
988988
pub(crate) async fn query_iter(
989989
self: Arc<Self>,
990990
query: Query,
991-
) -> Result<QueryPager, QueryError> {
991+
) -> Result<QueryPager, NextRowError> {
992992
let consistency = query
993993
.config
994994
.determine_consistency(self.config.default_consistency);
@@ -1004,7 +1004,7 @@ impl Connection {
10041004
self: Arc<Self>,
10051005
prepared_statement: PreparedStatement,
10061006
values: SerializedValues,
1007-
) -> Result<QueryPager, QueryError> {
1007+
) -> Result<QueryPager, NextRowError> {
10081008
let consistency = prepared_statement
10091009
.config
10101010
.determine_consistency(self.config.default_consistency);

0 commit comments

Comments
 (0)