Skip to content

Commit 12ceb62

Browse files
authored
Merge pull request #1020 from rukai/parse_vector_type
cassandra 5.0 vector type CREATE/INSERT support
2 parents f59908c + 2685ab9 commit 12ceb62

File tree

5 files changed

+169
-2
lines changed

5 files changed

+169
-2
lines changed

.github/workflows/cassandra.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
run: cargo build --verbose --tests --features "full-serialization"
3232
- name: Run tests on cassandra
3333
run: |
34-
CDC='disabled' RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements
34+
CDC='disabled' RUSTFLAGS="--cfg cassandra_tests" RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements
3535
- name: Stop the cluster
3636
if: ${{ always() }}
3737
run: docker compose -f test/cluster/cassandra/docker-compose.yml stop

scylla/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,4 @@ harness = false
9898
[lints.rust]
9999
unnameable_types = "warn"
100100
unreachable_pub = "warn"
101-
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(scylla_cloud_tests)'] }
101+
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(scylla_cloud_tests)', 'cfg(cassandra_tests)'] }

scylla/src/transport/session_test.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3165,3 +3165,108 @@ async fn test_api_migration_session_sharing() {
31653165
assert!(matched);
31663166
}
31673167
}
3168+
3169+
/// Checks that `vector<T, N>` columns created on a Cassandra cluster are
/// reported in the refreshed cluster metadata as `CqlType::Vector` with the
/// matching element type and dimension count.
///
/// Only compiled when the `cassandra_tests` cfg is set.
#[cfg(cassandra_tests)]
#[tokio::test]
async fn test_vector_type_metadata() {
    setup_tracing();
    let session = create_new_session_builder().build().await.unwrap();
    let keyspace = unique_keyspace_name();

    // Schema setup: a fresh keyspace holding one table with two vector columns.
    let create_keyspace = format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", keyspace);
    session.query_unpaged(create_keyspace, &[]).await.unwrap();

    let create_table = format!(
        "CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector<int, 4>, c vector<text, 2>)",
        keyspace
    );
    session.query_unpaged(create_table, &[]).await.unwrap();

    // Refresh the metadata before inspecting the column definitions.
    session.refresh_metadata().await.unwrap();
    let cluster_data = session.get_cluster_data();
    let table_columns = &cluster_data.keyspaces[&keyspace].tables["t"].columns;

    let expected_b = CqlType::Vector {
        type_: Box::new(CqlType::Native(NativeType::Int)),
        dimensions: 4,
    };
    assert_eq!(table_columns["b"].type_, expected_b);

    let expected_c = CqlType::Vector {
        type_: Box::new(CqlType::Native(NativeType::Text)),
        dimensions: 2,
    };
    assert_eq!(table_columns["c"].type_, expected_c);
}
3206+
3207+
/// Checks that an unprepared `INSERT` with literal vector values succeeds
/// against a table that has `vector<int, 4>` and `vector<text, 2>` columns.
///
/// Only compiled when the `cassandra_tests` cfg is set.
#[cfg(cassandra_tests)]
#[tokio::test]
async fn test_vector_type_unprepared() {
    setup_tracing();
    let session = create_new_session_builder().build().await.unwrap();
    let keyspace = unique_keyspace_name();

    // Schema setup: a fresh keyspace holding one table with two vector columns.
    let create_keyspace = format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", keyspace);
    session.query_unpaged(create_keyspace, &[]).await.unwrap();

    let create_table = format!(
        "CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector<int, 4>, c vector<text, 2>)",
        keyspace
    );
    session.query_unpaged(create_table, &[]).await.unwrap();

    // The vector values are written as CQL literals, so no bind values are sent.
    let insert_row = format!(
        "INSERT INTO {}.t (a, b, c) VALUES (1, [1, 2, 3, 4], ['foo', 'bar'])",
        keyspace
    );
    session.query_unpaged(insert_row, &[]).await.unwrap();

    // TODO: Implement and test SELECT statements and bind values (`?`)
}
3239+
3240+
/// Checks that a prepared `INSERT` whose vector values are CQL literals (only
/// the primary key is a bind marker) can be prepared and executed against a
/// table that has `vector<int, 4>` and `vector<text, 2>` columns.
///
/// Only compiled when the `cassandra_tests` cfg is set.
#[cfg(cassandra_tests)]
#[tokio::test]
async fn test_vector_type_prepared() {
    setup_tracing();
    let session = create_new_session_builder().build().await.unwrap();
    let keyspace = unique_keyspace_name();

    // Schema setup: a fresh keyspace holding one table with two vector columns.
    let create_keyspace = format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", keyspace);
    session.query_unpaged(create_keyspace, &[]).await.unwrap();

    let create_table = format!(
        "CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector<int, 4>, c vector<text, 2>)",
        keyspace
    );
    session.query_unpaged(create_table, &[]).await.unwrap();

    // Only column `a` is bound; the vector columns stay literal in the statement text.
    let insert_row = format!(
        "INSERT INTO {}.t (a, b, c) VALUES (?, [11, 12, 13, 14], ['afoo', 'abar'])",
        keyspace
    );
    let insert = session.prepare(insert_row).await.unwrap();
    session.execute_unpaged(&insert, &(2,)).await.unwrap();

    // TODO: Implement and test SELECT statements and bind values (`?`)
}

scylla/src/transport/topology.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,12 @@ enum PreCqlType {
188188
type_: PreCollectionType,
189189
},
190190
Tuple(Vec<PreCqlType>),
191+
Vector {
192+
type_: Box<PreCqlType>,
193+
/// matches the datatype used by the java driver:
194+
/// <https://github.com/apache/cassandra-java-driver/blob/85bb4065098b887d2dda26eb14423ce4fc687045/core/src/main/java/com/datastax/oss/driver/api/core/type/DataTypes.java#L77>
195+
dimensions: i32,
196+
},
191197
UserDefinedType {
192198
frozen: bool,
193199
name: String,
@@ -211,6 +217,10 @@ impl PreCqlType {
211217
.map(|t| t.into_cql_type(keyspace_name, udts))
212218
.collect(),
213219
),
220+
PreCqlType::Vector { type_, dimensions } => CqlType::Vector {
221+
type_: Box::new(type_.into_cql_type(keyspace_name, udts)),
222+
dimensions,
223+
},
214224
PreCqlType::UserDefinedType { frozen, name } => {
215225
let definition = match udts
216226
.get(keyspace_name)
@@ -236,6 +246,12 @@ pub enum CqlType {
236246
type_: CollectionType,
237247
},
238248
Tuple(Vec<CqlType>),
249+
Vector {
250+
type_: Box<CqlType>,
251+
/// matches the datatype used by the java driver:
252+
/// <https://github.com/apache/cassandra-java-driver/blob/85bb4065098b887d2dda26eb14423ce4fc687045/core/src/main/java/com/datastax/oss/driver/api/core/type/DataTypes.java#L77>
253+
dimensions: i32,
254+
},
239255
UserDefinedType {
240256
frozen: bool,
241257
// Using Arc here in order not to have many copies of the same definition
@@ -1137,6 +1153,7 @@ fn topo_sort_udts(udts: &mut Vec<UdtRowWithParsedFieldTypes>) -> Result<(), Quer
11371153
PreCqlType::Tuple(types) => types
11381154
.iter()
11391155
.for_each(|type_| do_with_referenced_udts(what, type_)),
1156+
PreCqlType::Vector { type_, .. } => do_with_referenced_udts(what, type_),
11401157
PreCqlType::UserDefinedType { name, .. } => what(name),
11411158
}
11421159
}
@@ -1637,6 +1654,22 @@ fn parse_cql_type(p: ParserState<'_>) -> ParseResult<(PreCqlType, ParserState<'_
16371654
})?;
16381655

16391656
Ok((PreCqlType::Tuple(types), p))
1657+
} else if let Ok(p) = p.accept("vector<") {
1658+
let (inner_type, p) = parse_cql_type(p)?;
1659+
1660+
let p = p.skip_white();
1661+
let p = p.accept(",")?;
1662+
let p = p.skip_white();
1663+
let (size, p) = p.parse_i32()?;
1664+
let p = p.skip_white();
1665+
let p = p.accept(">")?;
1666+
1667+
let typ = PreCqlType::Vector {
1668+
type_: Box::new(inner_type),
1669+
dimensions: size,
1670+
};
1671+
1672+
Ok((typ, p))
16401673
} else if let Ok((typ, p)) = parse_native_type(p) {
16411674
Ok((PreCqlType::Native(typ), p))
16421675
} else if let Ok((name, p)) = parse_user_defined_type(p) {
@@ -1827,6 +1860,20 @@ mod tests {
18271860
PreCqlType::Native(NativeType::Varint),
18281861
]),
18291862
),
1863+
(
1864+
"vector<int, 5>",
1865+
PreCqlType::Vector {
1866+
type_: Box::new(PreCqlType::Native(NativeType::Int)),
1867+
dimensions: 5,
1868+
},
1869+
),
1870+
(
1871+
"vector<text, 1234>",
1872+
PreCqlType::Vector {
1873+
type_: Box::new(PreCqlType::Native(NativeType::Text)),
1874+
dimensions: 1234,
1875+
},
1876+
),
18301877
(
18311878
"com.scylladb.types.AwesomeType",
18321879
PreCqlType::UserDefinedType {

scylla/src/utils/parse.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,21 @@ impl<'s> ParserState<'s> {
8787
me
8888
}
8989

90+
/// Parses a sequence of digits and '-' as an integer.
91+
/// Consumes characters until it finds a character that is not a digit or '-'.
92+
///
93+
/// An error is returned if:
94+
/// * The first character is not a digit or '-'
95+
/// * The integer is larger than i32
96+
pub(crate) fn parse_i32(self) -> ParseResult<(i32, Self)> {
97+
let (digits, p) = self.take_while(|c| c.is_ascii_digit() || c == '-');
98+
if let Ok(value) = digits.parse() {
99+
Ok((value, p))
100+
} else {
101+
Err(p.error(ParseErrorCause::Other("Expected 32-bit signed integer")))
102+
}
103+
}
104+
90105
/// Skips characters from the beginning while they satisfy given predicate
91106
/// and returns new parser state which
92107
pub(crate) fn take_while(self, mut pred: impl FnMut(char) -> bool) -> (&'s str, Self) {

0 commit comments

Comments
 (0)