Skip to content

Commit 3b845f3

Browse files
authored
Merge pull request #1449 from Lorak-mmk/pager-each-page-typecheck
Typecheck each page in execute/query_iter
2 parents 8cd86f3 + a3af3d6 commit 3b845f3

File tree

3 files changed

+126
-14
lines changed

3 files changed

+126
-14
lines changed

scylla-cql/src/deserialize/row.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ pub struct RawColumn<'frame, 'metadata> {
2929
/// Iterates over columns of a single row.
3030
#[derive(Clone, Debug)]
3131
pub struct ColumnIterator<'frame, 'metadata> {
32-
specs: std::iter::Enumerate<std::slice::Iter<'metadata, ColumnSpec<'metadata>>>,
32+
index: std::ops::RangeFrom<usize>,
33+
specs: std::slice::Iter<'metadata, ColumnSpec<'metadata>>,
3334
slice: FrameSlice<'frame>,
3435
}
3536

@@ -41,7 +42,8 @@ impl<'frame, 'metadata> ColumnIterator<'frame, 'metadata> {
4142
#[inline]
4243
pub fn new(specs: &'metadata [ColumnSpec<'metadata>], slice: FrameSlice<'frame>) -> Self {
4344
Self {
44-
specs: specs.iter().enumerate(),
45+
index: 0usize..,
46+
specs: specs.iter(),
4547
slice,
4648
}
4749
}
@@ -52,14 +54,26 @@ impl<'frame, 'metadata> ColumnIterator<'frame, 'metadata> {
5254
pub fn columns_remaining(&self) -> usize {
5355
self.specs.len()
5456
}
57+
58+
/// Performs a type check (see [DeserializeRow::type_check]) on remaining columns.
59+
#[inline]
60+
pub fn type_check<RowT: DeserializeRow<'frame, 'metadata>>(
61+
&self,
62+
) -> Result<(), TypeCheckError> {
63+
<RowT as DeserializeRow<'frame, 'metadata>>::type_check(self.specs.as_slice())
64+
}
5565
}
5666

5767
impl<'frame, 'metadata> Iterator for ColumnIterator<'frame, 'metadata> {
5868
type Item = Result<RawColumn<'frame, 'metadata>, DeserializationError>;
5969

6070
#[inline]
6171
fn next(&mut self) -> Option<Self::Item> {
62-
let (column_index, spec) = self.specs.next()?;
72+
let spec = self.specs.next()?;
73+
let column_index = self
74+
.index
75+
.next()
76+
.expect("RangeFrom<usize> iterator exhausted: this indicates usize overflow (more than usize::MAX columns), which should be impossible in practice");
6377
Some(
6478
self.slice
6579
.read_cql_bytes()

scylla/src/client/pager.rs

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -615,36 +615,40 @@ pub struct QueryPager {
615615
impl QueryPager {
616616
/// Returns the next item (`ColumnIterator`) from the stream.
617617
///
618-
/// This can be used with `type_check() for manual deserialization - see example below.
618+
/// Because pages may have different result metadata, each one needs to be type-checked before deserialization.
619+
/// The bool returned as the second element of the tuple indicates whether the page is fresh (newly fetched) or not.
620+
/// This allows the caller to then perform the type check for fresh pages.
619621
///
620622
/// This is not a part of the `Stream` interface because the returned iterator
621623
/// borrows from self.
622624
///
623625
/// This is cancel-safe.
624-
async fn next(&mut self) -> Option<Result<ColumnIterator<'_, '_>, NextRowError>> {
626+
async fn next(&mut self) -> Option<Result<(ColumnIterator<'_, '_>, bool), NextRowError>> {
625627
let res = std::future::poll_fn(|cx| Pin::new(&mut *self).poll_fill_page(cx)).await;
626-
match res {
627-
Some(Ok(())) => {}
628+
let fresh_page = match res {
629+
Some(Ok(f)) => f,
628630
Some(Err(err)) => return Some(Err(err)),
629631
None => return None,
630-
}
632+
};
631633

632634
// We are guaranteed here to have a non-empty page, so unwrap
633635
Some(
634636
self.current_page
635637
.next()
636638
.unwrap()
637-
.map_err(NextRowError::RowDeserializationError),
639+
.map_err(NextRowError::RowDeserializationError)
640+
.map(|x| (x, fresh_page)),
638641
)
639642
}
640643

641644
/// Tries to acquire a non-empty page, if current page is exhausted.
645+
/// The boolean value `r` in `Some(Ok(r))` is `true` if a new page was fetched.
642646
fn poll_fill_page(
643647
mut self: Pin<&mut Self>,
644648
cx: &mut Context<'_>,
645-
) -> Poll<Option<Result<(), NextRowError>>> {
649+
) -> Poll<Option<Result<bool, NextRowError>>> {
646650
if !self.is_current_page_exhausted() {
647-
return Poll::Ready(Some(Ok(())));
651+
return Poll::Ready(Some(Ok(false)));
648652
}
649653
ready_some_ok!(self.as_mut().poll_next_page(cx));
650654
if self.is_current_page_exhausted() {
@@ -653,7 +657,7 @@ impl QueryPager {
653657
cx.waker().wake_by_ref();
654658
Poll::Pending
655659
} else {
656-
Poll::Ready(Some(Ok(())))
660+
Poll::Ready(Some(Ok(true)))
657661
}
658662
}
659663

@@ -691,6 +695,12 @@ impl QueryPager {
691695
/// This is automatically called upon transforming [QueryPager] into [TypedRowStream].
692696
// Can be used with `next()` for manual deserialization.
693697
#[inline]
698+
#[deprecated(
699+
since = "1.4.0",
700+
note = "Type check should be performed for each page, which is not possible with public API.
701+
Also, the only thing user can do (rows_stream) will take care of type check anyway.
702+
If you are using this API, you are probably doing something wrong."
703+
)]
694704
pub fn type_check<'frame, 'metadata, RowT: DeserializeRow<'frame, 'metadata>>(
695705
&self,
696706
) -> Result<(), TypeCheckError> {
@@ -1040,6 +1050,7 @@ impl QueryPager {
10401050
/// To use [Stream] API (only accessible for owned types), use [QueryPager::rows_stream].
10411051
pub struct TypedRowStream<RowT> {
10421052
raw_row_lending_stream: QueryPager,
1053+
current_page_typechecked: bool,
10431054
_phantom: std::marker::PhantomData<RowT>,
10441055
}
10451056

@@ -1061,10 +1072,12 @@ where
10611072
RowT: for<'frame, 'metadata> DeserializeRow<'frame, 'metadata>,
10621073
{
10631074
fn new(raw_stream: QueryPager) -> Result<Self, TypeCheckError> {
1075+
#[allow(deprecated)] // In TypedRowStream we take care to type check each page.
10641076
raw_stream.type_check::<RowT>()?;
10651077

10661078
Ok(Self {
10671079
raw_row_lending_stream: raw_stream,
1080+
current_page_typechecked: true,
10681081
_phantom: Default::default(),
10691082
})
10701083
}
@@ -1101,8 +1114,18 @@ where
11011114

11021115
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
11031116
let next_fut = async {
1104-
self.raw_row_lending_stream.next().await.map(|res| {
1105-
res.and_then(|column_iterator| {
1117+
let real_self: &mut Self = &mut self; // Self is Unpin, and this lets us perform partial borrows.
1118+
real_self.raw_row_lending_stream.next().await.map(|res| {
1119+
res.and_then(|(column_iterator, fresh_page)| {
1120+
if fresh_page {
1121+
real_self.current_page_typechecked = false;
1122+
}
1123+
if !real_self.current_page_typechecked {
1124+
column_iterator.type_check::<RowT>().map_err(|e| {
1125+
NextRowError::NextPageError(NextPageError::TypeCheckError(e))
1126+
})?;
1127+
real_self.current_page_typechecked = true;
1128+
}
11061129
<RowT as DeserializeRow>::deserialize(column_iterator)
11071130
.map_err(NextRowError::RowDeserializationError)
11081131
})
@@ -1130,6 +1153,10 @@ pub enum NextPageError {
11301153
/// Failed to deserialize result metadata associated with next page response.
11311154
#[error("Failed to deserialize result metadata associated with next page response: {0}")]
11321155
ResultMetadataParseError(#[from] ResultMetadataAndRowsCountParseError),
1156+
1157+
/// Failed to type check a received page.
1158+
#[error("Failed to type check a received page: {0}")]
1159+
TypeCheckError(#[from] TypeCheckError),
11331160
}
11341161

11351162
/// An error returned by async iterator API.

scylla/tests/integration/session/pager.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ use std::sync::{
33
atomic::{AtomicBool, Ordering},
44
};
55

6+
use assert_matches::assert_matches;
67
use futures::{StreamExt as _, TryStreamExt as _};
8+
use scylla::errors::{NextPageError, NextRowError};
79
use scylla::{
810
client::execution_profile::ExecutionProfile,
911
policies::retry::{RequestInfo, RetryDecision, RetryPolicy, RetrySession},
@@ -149,3 +151,72 @@ async fn test_iter_methods_with_modification_statements() {
149151

150152
session.ddl(format!("DROP KEYSPACE {ks}")).await.unwrap();
151153
}
154+
155+
// Regression test for https://github.com/scylladb/scylla-rust-driver/issues/1448
156+
// PR with fix: https://github.com/scylladb/scylla-rust-driver/pull/1449
157+
#[tokio::test]
158+
async fn test_iter_methods_when_altering_table() {
159+
let session = create_new_session_builder().build().await.unwrap();
160+
let ks = unique_keyspace_name();
161+
162+
session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {ks} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}")).await.unwrap();
163+
session
164+
.ddl(format!(
165+
"CREATE TABLE IF NOT EXISTS {ks}.t (a int, b int, d int, primary key (a, b))"
166+
))
167+
.await
168+
.unwrap();
169+
170+
let insert_stmt = session
171+
.prepare(format!("INSERT INTO {ks}.t (a, b, d) VALUES (?, ?, ?)"))
172+
.await
173+
.unwrap();
174+
// First let's insert some data
175+
for a in 0..10 {
176+
for b in 0..10 {
177+
session
178+
.execute_unpaged(&insert_stmt, (a, b, 1337))
179+
.await
180+
.unwrap();
181+
}
182+
}
183+
184+
let mut select_stmt = session
185+
.prepare(format!("SELECT * FROM {ks}.t",))
186+
.await
187+
.unwrap();
188+
select_stmt.set_page_size(10);
189+
select_stmt.set_use_cached_result_metadata(false);
190+
let pager = session.execute_iter(select_stmt, &[]).await.unwrap();
191+
let mut stream = pager.rows_stream::<(i32, i32, Option<i32>)>().unwrap();
192+
193+
// Let's fetch a few pages, but not all.
194+
for _ in 0..50 {
195+
let _row = stream.next().await.unwrap().unwrap();
196+
}
197+
198+
session
199+
.query_unpaged(format!("ALTER TABLE {ks}.t ADD c text"), &())
200+
.await
201+
.unwrap();
202+
203+
// With the bug (typecheck only being done for first page), the code panics!
204+
// At some point, requests will return pages with new schema.
205+
// It contains a new column, and the new schema was not type checked.
206+
// DeserializeRow::deserialize impl will panic because invariants that should
207+
// be enforced by type check are violated.
208+
let err = loop {
209+
match stream.next().await {
210+
None => panic!("No error. Expected typecheck error."),
211+
Some(Ok(_row)) => continue,
212+
Some(Err(e)) => break e,
213+
}
214+
};
215+
216+
assert_matches!(
217+
err,
218+
NextRowError::NextPageError(NextPageError::TypeCheckError(_))
219+
);
220+
221+
session.ddl(format!("DROP KEYSPACE {ks}")).await.unwrap();
222+
}

0 commit comments

Comments
 (0)