Skip to content

Commit a343677

Browse files
authored
feat: Minimal support for extension array casting (#3851)
I think this is the most minimal form of extension array casting that is easy and safe to execute. It only handles casting within the same extension dtype, changing nullability when appropriate. Actually casting the underlying storage array seems dangerous, given that we don't know how it's actually used. Signed-off-by: Adam Gutglick <[email protected]>
1 parent 42afa5e commit a343677

File tree

12 files changed

+192
-65
lines changed

12 files changed

+192
-65
lines changed

encodings/datetime-parts/src/compute/cast.rs

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,28 @@
44
use vortex_array::compute::{CastKernel, CastKernelAdapter, cast};
55
use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
66
use vortex_dtype::DType;
7-
use vortex_error::{VortexResult, vortex_bail};
7+
use vortex_error::VortexResult;
88

99
use crate::{DateTimePartsArray, DateTimePartsVTable};
1010

1111
impl CastKernel for DateTimePartsVTable {
12-
fn cast(&self, array: &DateTimePartsArray, dtype: &DType) -> VortexResult<ArrayRef> {
12+
fn cast(&self, array: &DateTimePartsArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
1313
if !array.dtype().eq_ignore_nullability(dtype) {
14-
vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype);
14+
return Ok(None);
1515
};
1616

17-
Ok(DateTimePartsArray::try_new(
18-
dtype.clone(),
19-
cast(
20-
array.days().as_ref(),
21-
&array.days().dtype().with_nullability(dtype.nullability()),
22-
)?,
23-
array.seconds().clone(),
24-
array.subseconds().clone(),
25-
)?
26-
.into_array())
17+
Ok(Some(
18+
DateTimePartsArray::try_new(
19+
dtype.clone(),
20+
cast(
21+
array.days().as_ref(),
22+
&array.days().dtype().with_nullability(dtype.nullability()),
23+
)?,
24+
array.seconds().clone(),
25+
array.subseconds().clone(),
26+
)?
27+
.into_array(),
28+
))
2729
}
2830
}
2931

@@ -86,10 +88,10 @@ mod tests {
8688
let array = date_time_array(validity);
8789
let result = cast(&array, &DType::Bool(Nullability::NonNullable));
8890
assert!(
89-
result
90-
.as_ref()
91-
.is_err_and(|err| err.to_string().contains("cannot cast from")),
92-
"{result:?}"
91+
result.as_ref().is_err_and(|err| err.to_string().contains(
92+
"No compute kernel to cast array vortex.ext with dtype ext(vortex.timestamp, i64, ExtMetadata([2, 3, 0, 85, 84, 67]))? to bool"
93+
)),
94+
"Got error: {result:?}"
9395
);
9496

9597
let result = cast(
@@ -100,7 +102,7 @@ mod tests {
100102
result.as_ref().is_err_and(|err| err
101103
.to_string()
102104
.contains("invalid cast from nullable to non-nullable")),
103-
"{result:?}"
105+
"Got error: {result:?}"
104106
);
105107
}
106108
}

vortex-array/src/arrays/bool/compute/cast.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_dtype::DType;
5-
use vortex_error::{VortexResult, vortex_bail};
5+
use vortex_error::VortexResult;
66

77
use crate::array::ArrayRef;
88
use crate::arrays::{BoolArray, BoolVTable};
@@ -11,14 +11,16 @@ use crate::register_kernel;
1111
use crate::vtable::ValidityHelper;
1212

1313
impl CastKernel for BoolVTable {
14-
fn cast(&self, array: &BoolArray, dtype: &DType) -> VortexResult<ArrayRef> {
14+
fn cast(&self, array: &BoolArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
1515
if !matches!(dtype, DType::Bool(_)) {
16-
vortex_bail!("Cannot cast {} to {}", array.dtype(), dtype);
16+
return Ok(None);
1717
}
1818

1919
let new_nullability = dtype.nullability();
2020
let new_validity = array.validity().clone().cast_nullability(new_nullability)?;
21-
Ok(BoolArray::new(array.boolean_buffer().clone(), new_validity).to_array())
21+
Ok(Some(
22+
BoolArray::new(array.boolean_buffer().clone(), new_validity).to_array(),
23+
))
2224
}
2325
}
2426

vortex-array/src/arrays/chunked/compute/cast.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ use crate::compute::{CastKernel, CastKernelAdapter, cast};
99
use crate::{ArrayRef, IntoArray, register_kernel};
1010

1111
impl CastKernel for ChunkedVTable {
12-
fn cast(&self, array: &ChunkedArray, dtype: &DType) -> VortexResult<ArrayRef> {
12+
fn cast(&self, array: &ChunkedArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
1313
let mut cast_chunks = Vec::new();
1414
for chunk in array.chunks() {
1515
cast_chunks.push(cast(chunk, dtype)?);
1616
}
1717

18-
Ok(ChunkedArray::new_unchecked(cast_chunks, dtype.clone()).into_array())
18+
Ok(Some(
19+
ChunkedArray::new_unchecked(cast_chunks, dtype.clone()).into_array(),
20+
))
1921
}
2022
}
2123

vortex-array/src/arrays/constant/compute/cast.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ use crate::compute::{CastKernel, CastKernelAdapter};
99
use crate::{ArrayRef, IntoArray, register_kernel};
1010

1111
impl CastKernel for ConstantVTable {
12-
fn cast(&self, array: &ConstantArray, dtype: &DType) -> VortexResult<ArrayRef> {
13-
Ok(ConstantArray::new(array.scalar().cast(dtype)?, array.len()).into_array())
12+
fn cast(&self, array: &ConstantArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13+
match array.scalar().cast(dtype) {
14+
Ok(scalar) => Ok(Some(ConstantArray::new(scalar, array.len()).into_array())),
15+
Err(_e) => Ok(None),
16+
}
1417
}
1518
}
1619

vortex-array/src/arrays/extension/compute/cast.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::DType;
5+
6+
use crate::arrays::{ExtensionArray, ExtensionVTable};
7+
use crate::compute::{CastKernel, CastKernelAdapter, cast};
8+
use crate::{ArrayRef, IntoArray, register_kernel};
9+
10+
impl CastKernel for ExtensionVTable {
11+
fn cast(
12+
&self,
13+
array: &ExtensionArray,
14+
dtype: &DType,
15+
) -> vortex_error::VortexResult<Option<ArrayRef>> {
16+
if !array.dtype().eq_ignore_nullability(dtype) {
17+
return Ok(None);
18+
}
19+
20+
let DType::Extension(ext_dtype) = dtype else {
21+
unreachable!("Already verified we have an extension dtype");
22+
};
23+
24+
let new_storage = match cast(array.storage(), ext_dtype.storage_dtype()) {
25+
Ok(arr) => arr,
26+
Err(e) => {
27+
log::warn!("Failed to cast storage array: {e}");
28+
return Ok(None);
29+
}
30+
};
31+
32+
Ok(Some(
33+
ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
34+
))
35+
}
36+
}
37+
38+
register_kernel!(CastKernelAdapter(ExtensionVTable).lift());
39+
40+
#[cfg(test)]
41+
mod tests {
42+
use std::sync::Arc;
43+
44+
use vortex_dtype::datetime::{TIMESTAMP_ID, TemporalMetadata, TimeUnit};
45+
use vortex_dtype::{ExtDType, Nullability, PType};
46+
47+
use super::*;
48+
use crate::arrays::PrimitiveArray;
49+
50+
#[test]
51+
fn cast_same_ext_dtype() {
52+
let ext_dtype = Arc::new(ExtDType::new(
53+
TIMESTAMP_ID.clone(),
54+
Arc::new(PType::I64.into()),
55+
Some(TemporalMetadata::Timestamp(TimeUnit::Ms, None).into()),
56+
));
57+
let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();
58+
59+
let arr = ExtensionArray::new(ext_dtype.clone(), storage);
60+
61+
let output = cast(arr.as_ref(), &DType::Extension(ext_dtype.clone())).unwrap();
62+
assert_eq!(arr.len(), output.len());
63+
assert_eq!(arr.dtype(), output.dtype());
64+
assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
65+
}
66+
67+
#[test]
68+
fn cast_same_ext_dtype_differet_nullability() {
69+
let ext_dtype = Arc::new(ExtDType::new(
70+
TIMESTAMP_ID.clone(),
71+
Arc::new(PType::I64.into()),
72+
Some(TemporalMetadata::Timestamp(TimeUnit::Ms, None).into()),
73+
));
74+
let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();
75+
76+
let arr = ExtensionArray::new(ext_dtype.clone(), storage);
77+
assert!(!arr.dtype.is_nullable());
78+
79+
let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);
80+
81+
let output = cast(arr.as_ref(), &new_dtype).unwrap();
82+
assert_eq!(arr.len(), output.len());
83+
assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
84+
assert_eq!(output.dtype(), &new_dtype);
85+
}
86+
87+
#[test]
88+
fn cast_different_ext_dtype() {
89+
let original_dtype = Arc::new(ExtDType::new(
90+
TIMESTAMP_ID.clone(),
91+
Arc::new(PType::I64.into()),
92+
Some(TemporalMetadata::Timestamp(TimeUnit::Ms, None).into()),
93+
));
94+
let target_dtype = Arc::new(ExtDType::new(
95+
TIMESTAMP_ID.clone(),
96+
Arc::new(PType::I64.into()),
97+
// Note NS here instead of MS
98+
Some(TemporalMetadata::Timestamp(TimeUnit::Ns, None).into()),
99+
));
100+
101+
let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();
102+
let arr = ExtensionArray::new(original_dtype, storage);
103+
104+
assert!(cast(arr.as_ref(), &DType::Extension(target_dtype)).is_err());
105+
}
106+
}

vortex-array/src/arrays/extension/compute/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
mod cast;
45
mod compare;
56

67
use vortex_error::VortexResult;

vortex-array/src/arrays/list/compute/cast.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_dtype::DType;
5-
use vortex_error::{VortexResult, vortex_bail};
5+
use vortex_error::VortexResult;
66

77
use crate::arrays::{ListArray, ListVTable};
88
use crate::compute::{CastKernel, CastKernelAdapter, cast};
99
use crate::vtable::ValidityHelper;
1010
use crate::{ArrayRef, register_kernel};
1111

1212
impl CastKernel for ListVTable {
13-
fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<ArrayRef> {
13+
fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
1414
let Some(target_element_type) = dtype.as_list_element() else {
15-
vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
15+
return Ok(None);
1616
};
1717

1818
let validity = array
@@ -25,7 +25,7 @@ impl CastKernel for ListVTable {
2525
array.offsets().clone(),
2626
validity,
2727
)
28-
.map(|a| a.to_array())
28+
.map(|a| Some(a.to_array()))
2929
}
3030
}
3131

vortex-array/src/arrays/primitive/compute/cast.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ use crate::vtable::ValidityHelper;
1313
use crate::{ArrayRef, IntoArray, register_kernel};
1414

1515
impl CastKernel for PrimitiveVTable {
16-
fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<ArrayRef> {
16+
fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
1717
let DType::Primitive(new_ptype, new_nullability) = dtype else {
18-
vortex_bail!(MismatchedTypes: "primitive type", dtype);
18+
return Ok(None);
1919
};
2020
let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
2121

@@ -36,17 +36,21 @@ impl CastKernel for PrimitiveVTable {
3636

3737
// If the bit width is the same, we can short-circuit and simply update the validity
3838
if array.ptype() == new_ptype {
39-
return Ok(PrimitiveArray::from_byte_buffer(
40-
array.byte_buffer().clone(),
41-
array.ptype(),
42-
new_validity,
43-
)
44-
.into_array());
39+
return Ok(Some(
40+
PrimitiveArray::from_byte_buffer(
41+
array.byte_buffer().clone(),
42+
array.ptype(),
43+
new_validity,
44+
)
45+
.into_array(),
46+
));
4547
}
4648

4749
// Otherwise, we need to cast the values one-by-one
4850
match_each_native_ptype!(new_ptype, |T| {
49-
Ok(PrimitiveArray::new(cast::<T>(array)?, new_validity).into_array())
51+
Ok(Some(
52+
PrimitiveArray::new(cast::<T>(array)?, new_validity).into_array(),
53+
))
5054
})
5155
}
5256
}

vortex-array/src/arrays/struct_/compute/cast.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ use crate::vtable::ValidityHelper;
1111
use crate::{ArrayRef, IntoArray, register_kernel};
1212

1313
impl CastKernel for StructVTable {
14-
fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<ArrayRef> {
14+
fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
1515
let Some(target_sdtype) = dtype.as_struct() else {
16-
vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
16+
return Ok(None);
1717
};
1818

1919
let source_sdtype = array
@@ -41,7 +41,7 @@ impl CastKernel for StructVTable {
4141
array.len(),
4242
validity,
4343
)
44-
.map(|a| a.into_array())
44+
.map(|a| Some(a.into_array()))
4545
}
4646
}
4747

vortex-array/src/arrays/varbin/compute/cast.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,31 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_dtype::DType;
5-
use vortex_error::{VortexResult, vortex_bail};
5+
use vortex_error::VortexResult;
66

77
use crate::arrays::{VarBinArray, VarBinVTable};
88
use crate::compute::{CastKernel, CastKernelAdapter};
99
use crate::vtable::ValidityHelper;
1010
use crate::{ArrayRef, IntoArray, register_kernel};
1111

1212
impl CastKernel for VarBinVTable {
13-
fn cast(&self, array: &VarBinArray, dtype: &DType) -> VortexResult<ArrayRef> {
13+
fn cast(&self, array: &VarBinArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
1414
if !array.dtype().eq_ignore_nullability(dtype) {
15-
vortex_bail!("Cannot cast {} to {}", array.dtype(), dtype);
15+
return Ok(None);
1616
}
1717

1818
let new_nullability = dtype.nullability();
1919
let new_validity = array.validity().clone().cast_nullability(new_nullability)?;
2020
let new_dtype = array.dtype().with_nullability(new_nullability);
21-
Ok(VarBinArray::try_new(
22-
array.offsets().clone(),
23-
array.bytes().clone(),
24-
new_dtype,
25-
new_validity,
26-
)?
27-
.into_array())
21+
Ok(Some(
22+
VarBinArray::try_new(
23+
array.offsets().clone(),
24+
array.bytes().clone(),
25+
new_dtype,
26+
new_validity,
27+
)?
28+
.into_array(),
29+
))
2830
}
2931
}
3032

0 commit comments

Comments
 (0)