Skip to content

Commit d647ffa

Browse files
committed
add cast infra
Signed-off-by: Connor Tsui <[email protected]>
1 parent e3eabde commit d647ffa

File tree

15 files changed

+690
-56
lines changed

15 files changed

+690
-56
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::DType;
5+
use vortex_error::VortexResult;
6+
use vortex_error::vortex_bail;
7+
use vortex_vector::Scalar;
8+
use vortex_vector::ScalarOps;
9+
use vortex_vector::Vector;
10+
use vortex_vector::VectorOps;
11+
use vortex_vector::binaryview::BinaryViewScalar;
12+
use vortex_vector::binaryview::BinaryViewType;
13+
use vortex_vector::binaryview::BinaryViewVector;
14+
15+
use crate::cast::Cast;
16+
use crate::cast::try_cast_scalar_common;
17+
use crate::cast::try_cast_vector_common;
18+
19+
impl<T: BinaryViewType> Cast for BinaryViewVector<T> {
20+
type Output = Vector;
21+
22+
/// Casts to Utf8 or Binary (identity cast with compatible nullability).
23+
fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
24+
if let Some(result) = try_cast_vector_common(self, target_dtype)? {
25+
return Ok(result);
26+
}
27+
28+
match target_dtype {
29+
// Identity cast: same type with compatible nullability.
30+
dt if T::matches_dtype(dt) && (dt.is_nullable() || self.validity().all_true()) => {
31+
Ok(self.clone().into())
32+
}
33+
// Cross-cast between Utf8 and Binary is not supported.
34+
DType::Utf8(_) | DType::Binary(_) => {
35+
vortex_bail!(
36+
"Cannot cast BinaryViewVector to {} (cross-cast not supported)",
37+
target_dtype
38+
);
39+
}
40+
_ => {
41+
vortex_bail!("Cannot cast BinaryViewVector to {}", target_dtype);
42+
}
43+
}
44+
}
45+
}
46+
47+
impl<T: BinaryViewType> Cast for BinaryViewScalar<T> {
48+
type Output = Scalar;
49+
50+
/// Casts to Utf8 or Binary (identity cast with compatible nullability).
51+
fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
52+
if let Some(result) = try_cast_scalar_common(self, target_dtype)? {
53+
return Ok(result);
54+
}
55+
56+
match target_dtype {
57+
// Identity cast: same type with compatible nullability.
58+
dt if T::matches_dtype(dt) && (dt.is_nullable() || self.is_valid()) => {
59+
Ok(self.clone().into())
60+
}
61+
// Cross-cast between Utf8 and Binary is not supported.
62+
DType::Utf8(_) | DType::Binary(_) => {
63+
vortex_bail!(
64+
"Cannot cast BinaryViewScalar to {} (cross-cast not supported)",
65+
target_dtype
66+
);
67+
}
68+
_ => {
69+
vortex_bail!("Cannot cast BinaryViewScalar to {}", target_dtype);
70+
}
71+
}
72+
}
73+
}

vortex-compute/src/cast/bool.rs

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,33 @@ use vortex_dtype::DType;
88
use vortex_dtype::match_each_native_ptype;
99
use vortex_error::VortexResult;
1010
use vortex_error::vortex_bail;
11+
use vortex_vector::Scalar;
12+
use vortex_vector::ScalarOps;
1113
use vortex_vector::Vector;
1214
use vortex_vector::VectorOps;
15+
use vortex_vector::bool::BoolScalar;
1316
use vortex_vector::bool::BoolVector;
14-
use vortex_vector::null::NullVector;
17+
use vortex_vector::primitive::PScalar;
1518
use vortex_vector::primitive::PVector;
1619

1720
use crate::cast::Cast;
21+
use crate::cast::try_cast_scalar_common;
22+
use crate::cast::try_cast_vector_common;
1823

1924
impl Cast for BoolVector {
2025
type Output = Vector;
2126

22-
fn cast(&self, dtype: &DType) -> VortexResult<Vector> {
23-
match dtype {
24-
DType::Null if self.validity().all_false() => {
25-
// Can cast an all-null BoolVector to NullVector.
26-
Ok(NullVector::new(self.len()).into())
27-
}
27+
/// Casts to Bool (identity) or Primitive (as 0/1).
28+
fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
29+
if let Some(result) = try_cast_vector_common(self, target_dtype)? {
30+
return Ok(result);
31+
}
32+
33+
match target_dtype {
2834
DType::Bool(n) if n.is_nullable() || self.validity().all_true() => {
29-
// If the target dtype is nullable, or if the source BoolVector has no nulls,
30-
// we can cast directly to BoolVector.
3135
Ok(self.clone().into())
3236
}
33-
DType::Primitive(ptype, _) => {
37+
DType::Primitive(ptype, n) if n.is_nullable() || self.validity().all_true() => {
3438
match_each_native_ptype!(ptype, |T| {
3539
Ok(PVector::<T>::new(
3640
Buffer::<T>::from_trusted_len_iter(
@@ -43,9 +47,31 @@ impl Cast for BoolVector {
4347
.into())
4448
})
4549
}
46-
DType::Extension(ext_dtype) => self.cast(ext_dtype.storage_dtype()),
4750
_ => {
48-
vortex_bail!("Cannot cast BoolVector to type {}", dtype);
51+
vortex_bail!("Cannot cast BoolVector to {}", target_dtype);
52+
}
53+
}
54+
}
55+
}
56+
57+
impl Cast for BoolScalar {
58+
type Output = Scalar;
59+
60+
/// Casts to Bool (identity) or Primitive (as 0/1).
61+
fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
62+
if let Some(result) = try_cast_scalar_common(self, target_dtype)? {
63+
return Ok(result);
64+
}
65+
match target_dtype {
66+
DType::Bool(n) if n.is_nullable() || self.is_valid() => Ok(self.clone().into()),
67+
DType::Primitive(ptype, n) if n.is_nullable() || self.is_valid() => {
68+
match_each_native_ptype!(ptype, |T| {
69+
let value = self.value().map(|b| if b { T::one() } else { T::zero() });
70+
Ok(PScalar::<T>::new(value).into())
71+
})
72+
}
73+
_ => {
74+
vortex_bail!("Cannot cast BoolScalar to {}", target_dtype);
4975
}
5076
}
5177
}

vortex-compute/src/cast/decimal.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::DType;
5+
use vortex_error::VortexResult;
6+
use vortex_vector::Scalar;
7+
use vortex_vector::Vector;
8+
use vortex_vector::decimal::DecimalScalar;
9+
use vortex_vector::decimal::DecimalVector;
10+
use vortex_vector::match_each_dscalar;
11+
use vortex_vector::match_each_dvector;
12+
13+
use crate::cast::Cast;
14+
15+
impl Cast for DecimalVector {
16+
type Output = Vector;
17+
18+
/// Dispatches to the underlying [`DVector<D>`](vortex_vector::decimal::DVector) implementation.
19+
fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
20+
match_each_dvector!(self, |v| { Cast::cast(v, target_dtype) })
21+
}
22+
}
23+
24+
impl Cast for DecimalScalar {
25+
type Output = Scalar;
26+
27+
/// Dispatches to the underlying [`DScalar<D>`](vortex_vector::decimal::DScalar) implementation.
28+
fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
29+
match_each_dscalar!(self, |s| { Cast::cast(s, target_dtype) })
30+
}
31+
}

vortex-compute/src/cast/dvector.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::DType;
5+
use vortex_dtype::NativeDecimalType;
6+
use vortex_error::VortexResult;
7+
use vortex_error::vortex_bail;
8+
use vortex_vector::Scalar;
9+
use vortex_vector::ScalarOps;
10+
use vortex_vector::Vector;
11+
use vortex_vector::VectorOps;
12+
use vortex_vector::decimal::DScalar;
13+
use vortex_vector::decimal::DVector;
14+
15+
use crate::cast::Cast;
16+
use crate::cast::try_cast_scalar_common;
17+
use crate::cast::try_cast_vector_common;
18+
19+
impl<D: NativeDecimalType> Cast for DVector<D> {
20+
type Output = Vector;
21+
22+
/// Casts to Decimal (identity with same precision/scale and compatible nullability).
23+
fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
24+
if let Some(result) = try_cast_vector_common(self, target_dtype)? {
25+
return Ok(result);
26+
}
27+
28+
match target_dtype {
29+
// Identity cast: same precision, scale, and compatible nullability.
30+
DType::Decimal(ddt, n)
31+
if ddt.precision() == self.precision()
32+
&& ddt.scale() == self.scale()
33+
&& (n.is_nullable() || self.validity().all_true()) =>
34+
{
35+
Ok(self.clone().into())
36+
}
37+
// TODO(connor): cast to different precision/scale
38+
DType::Decimal(..) => {
39+
vortex_bail!(
40+
"Casting DVector to {} with different precision/scale not yet implemented",
41+
target_dtype
42+
);
43+
}
44+
_ => {
45+
vortex_bail!("Cannot cast DVector to {}", target_dtype);
46+
}
47+
}
48+
}
49+
}
50+
51+
impl<D: NativeDecimalType> Cast for DScalar<D> {
52+
type Output = Scalar;
53+
54+
/// Casts to Decimal (identity with same precision/scale and compatible nullability).
55+
fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
56+
if let Some(result) = try_cast_scalar_common(self, target_dtype)? {
57+
return Ok(result);
58+
}
59+
60+
match target_dtype {
61+
// Identity cast: same precision, scale, and compatible nullability.
62+
DType::Decimal(ddt, n)
63+
if ddt.precision() == self.precision()
64+
&& ddt.scale() == self.scale()
65+
&& (n.is_nullable() || self.is_valid()) =>
66+
{
67+
Ok(self.clone().into())
68+
}
69+
// TODO(connor): cast to different precision/scale
70+
DType::Decimal(..) => {
71+
vortex_bail!(
72+
"Casting DScalar to {} with different precision/scale not yet implemented",
73+
target_dtype
74+
);
75+
}
76+
_ => {
77+
vortex_bail!("Cannot cast DScalar to {}", target_dtype);
78+
}
79+
}
80+
}
81+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::DType;
5+
use vortex_error::VortexResult;
6+
use vortex_error::vortex_bail;
7+
use vortex_vector::Scalar;
8+
use vortex_vector::ScalarOps;
9+
use vortex_vector::Vector;
10+
use vortex_vector::VectorOps;
11+
use vortex_vector::fixed_size_list::FixedSizeListScalar;
12+
use vortex_vector::fixed_size_list::FixedSizeListVector;
13+
use vortex_vector::vector_matches_dtype;
14+
15+
use crate::cast::Cast;
16+
use crate::cast::try_cast_scalar_common;
17+
use crate::cast::try_cast_vector_common;
18+
19+
impl Cast for FixedSizeListVector {
20+
type Output = Vector;
21+
22+
/// Casts to FixedSizeList (identity with same element dtype, size, and compatible nullability).
23+
fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
24+
if let Some(result) = try_cast_vector_common(self, target_dtype)? {
25+
return Ok(result);
26+
}
27+
28+
match target_dtype {
29+
// Identity cast: same element dtype, size, and compatible nullability.
30+
DType::FixedSizeList(element_dtype, size, n)
31+
if *size == self.list_size()
32+
&& vector_matches_dtype(self.elements(), element_dtype)
33+
&& (n.is_nullable() || self.validity().all_true()) =>
34+
{
35+
Ok(self.clone().into())
36+
}
37+
DType::FixedSizeList(..) => {
38+
vortex_bail!(
39+
"Cannot cast FixedSizeListVector to {} (incompatible element type or size)",
40+
target_dtype
41+
);
42+
}
43+
_ => {
44+
vortex_bail!("Cannot cast FixedSizeListVector to {}", target_dtype);
45+
}
46+
}
47+
}
48+
}
49+
50+
impl Cast for FixedSizeListScalar {
51+
type Output = Scalar;
52+
53+
/// Casts to FixedSizeList (identity with same element dtype, size, and compatible nullability).
54+
fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
55+
if let Some(result) = try_cast_scalar_common(self, target_dtype)? {
56+
return Ok(result);
57+
}
58+
59+
match target_dtype {
60+
// Identity cast: same element dtype, size, and compatible nullability.
61+
// We check by verifying the scalar's underlying vector matches the target dtype.
62+
DType::FixedSizeList(element_dtype, size, n)
63+
if *size == self.value().list_size()
64+
&& vector_matches_dtype(self.value().elements(), element_dtype)
65+
&& (n.is_nullable() || self.is_valid()) =>
66+
{
67+
Ok(self.clone().into())
68+
}
69+
DType::FixedSizeList(..) => {
70+
vortex_bail!(
71+
"Cannot cast FixedSizeListScalar to {} (incompatible element type, size, or nullability)",
72+
target_dtype
73+
);
74+
}
75+
_ => {
76+
vortex_bail!("Cannot cast FixedSizeListScalar to {}", target_dtype);
77+
}
78+
}
79+
}
80+
}

0 commit comments

Comments
 (0)