Skip to content

Commit 64b6087

Browse files
authored
chore: all scalars support cast (#1965)
1 parent 961cbf3 commit 64b6087

File tree

12 files changed

+339
-109
lines changed

12 files changed

+339
-109
lines changed

encodings/bytebool/src/stats.rs

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,6 @@ impl StatisticsVTable<ByteBoolArray> for ByteBoolEncoding {
2323
#[cfg(test)]
2424
mod tests {
2525
use vortex_array::stats::ArrayStatistics;
26-
use vortex_dtype::{DType, Nullability};
27-
use vortex_scalar::Scalar;
2826

2927
use super::*;
3028

@@ -90,14 +88,8 @@ mod tests {
9088
assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap());
9189
assert!(bool_arr.statistics().compute_is_sorted().unwrap());
9290
assert!(bool_arr.statistics().compute_is_constant().unwrap());
93-
assert_eq!(
94-
bool_arr.statistics().compute(Stat::Min).unwrap(),
95-
Scalar::null(DType::Bool(Nullability::Nullable))
96-
);
97-
assert_eq!(
98-
bool_arr.statistics().compute(Stat::Max).unwrap(),
99-
Scalar::null(DType::Bool(Nullability::Nullable))
100-
);
91+
assert_eq!(bool_arr.statistics().compute(Stat::Min), None);
92+
assert_eq!(bool_arr.statistics().compute(Stat::Max), None);
10193
assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 1);
10294
assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 0);
10395
}

vortex-array/src/array/bool/stats.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,7 @@ impl BoolStatsAccumulator {
169169
#[cfg(test)]
170170
mod test {
171171
use arrow_buffer::BooleanBuffer;
172-
use vortex_dtype::Nullability::Nullable;
173-
use vortex_dtype::{DType, Nullability};
174-
use vortex_scalar::Scalar;
172+
use vortex_dtype::Nullability;
175173

176174
use crate::array::BoolArray;
177175
use crate::stats::{ArrayStatistics, Stat};
@@ -278,14 +276,8 @@ mod test {
278276
assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap());
279277
assert!(bool_arr.statistics().compute_is_sorted().unwrap());
280278
assert!(bool_arr.statistics().compute_is_constant().unwrap());
281-
assert_eq!(
282-
bool_arr.statistics().compute(Stat::Min).unwrap(),
283-
Scalar::null(DType::Bool(Nullable))
284-
);
285-
assert_eq!(
286-
bool_arr.statistics().compute(Stat::Max).unwrap(),
287-
Scalar::null(DType::Bool(Nullable))
288-
);
279+
assert_eq!(bool_arr.statistics().compute(Stat::Min), None);
280+
assert_eq!(bool_arr.statistics().compute(Stat::Max), None);
289281
assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 1);
290282
assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 0);
291283
assert_eq!(bool_arr.statistics().compute_null_count().unwrap(), 5);

vortex-array/src/array/primitive/stats.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,6 @@ impl<T: PStatsType> BitWidthAccumulator<T> {
334334

335335
#[cfg(test)]
336336
mod test {
337-
use vortex_dtype::{DType, Nullability, PType};
338337
use vortex_scalar::Scalar;
339338

340339
use crate::array::primitive::PrimitiveArray;
@@ -402,8 +401,7 @@ mod test {
402401
let arr = PrimitiveArray::from_option_iter([Option::<i32>::None, None, None]);
403402
let min: Option<Scalar> = arr.statistics().compute(Stat::Min);
404403
let max: Option<Scalar> = arr.statistics().compute(Stat::Max);
405-
let null_i32 = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
406-
assert_eq!(min, Some(null_i32.clone()));
407-
assert_eq!(max, Some(null_i32));
404+
assert_eq!(min, None);
405+
assert_eq!(max, None);
408406
}
409407
}

vortex-array/src/data/statistics.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::sync::Arc;
33
use enum_iterator::all;
44
use itertools::Itertools;
55
use vortex_dtype::{DType, Nullability, PType};
6-
use vortex_error::vortex_panic;
6+
use vortex_error::{vortex_panic, VortexExpect as _};
77
use vortex_scalar::{Scalar, ScalarValue};
88

99
use crate::data::InnerArrayData;
@@ -118,7 +118,7 @@ impl Statistics for ArrayData {
118118
let s = self
119119
.encoding()
120120
.compute_statistics(self, stat)
121-
.ok()?
121+
.vortex_expect("compute_statistics must not fail")
122122
.get(stat)
123123
.cloned();
124124

vortex-array/src/stats/statsset.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ impl StatsSet {
3232
/// an array consisting entirely of [null](vortex_dtype::DType::Null) values.
3333
pub fn nulls(len: usize, dtype: &DType) -> Self {
3434
let mut stats = Self::new_unchecked(vec![
35-
(Stat::Min, Scalar::null(dtype.clone())),
36-
(Stat::Max, Scalar::null(dtype.clone())),
3735
(Stat::RunCount, 1.into()),
3836
(Stat::NullCount, len.into()),
3937
]);
@@ -85,8 +83,10 @@ impl StatsSet {
8583
stats.set(Stat::TrueCount, true_count);
8684
}
8785

88-
stats.set(Stat::Min, scalar.clone());
89-
stats.set(Stat::Max, scalar.clone());
86+
if !scalar.is_null() {
87+
stats.set(Stat::Min, scalar.clone());
88+
stats.set(Stat::Max, scalar.clone());
89+
}
9090

9191
stats
9292
}

vortex-scalar/src/binary.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use vortex_buffer::ByteBuffer;
22
use vortex_dtype::{DType, Nullability};
3-
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
3+
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult};
44

55
use crate::value::{InnerScalarValue, ScalarValue};
66
use crate::Scalar;
@@ -20,8 +20,19 @@ impl<'a> BinaryScalar<'a> {
2020
self.value.as_ref().cloned()
2121
}
2222

23-
pub fn cast(&self, _dtype: &DType) -> VortexResult<Scalar> {
24-
todo!()
23+
pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
24+
if !matches!(dtype, DType::Binary(..)) {
25+
vortex_bail!("Can't cast binary to {}", dtype)
26+
}
27+
Ok(Scalar::new(
28+
dtype.clone(),
29+
ScalarValue(InnerScalarValue::Buffer(
30+
self.value
31+
.as_ref()
32+
.vortex_expect("nullness handled in Scalar::cast")
33+
.clone(),
34+
)),
35+
))
2536
}
2637
}
2738

vortex-scalar/src/bool.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use vortex_dtype::Nullability::NonNullable;
22
use vortex_dtype::{DType, Nullability};
3-
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
3+
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult};
44

55
use crate::value::ScalarValue;
66
use crate::{InnerScalarValue, Scalar};
@@ -20,14 +20,14 @@ impl<'a> BoolScalar<'a> {
2020
self.value
2121
}
2222

23-
pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
24-
match dtype {
25-
DType::Bool(_) => Ok(Scalar::bool(
26-
self.value().ok_or_else(|| vortex_err!("not a bool"))?,
27-
dtype.nullability(),
28-
)),
29-
_ => vortex_bail!("Can't cast {} to bool", dtype),
23+
pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
24+
if !matches!(dtype, DType::Bool(..)) {
25+
vortex_bail!("Can't cast bool to {}", dtype)
3026
}
27+
Ok(Scalar::bool(
28+
self.value.vortex_expect("nullness handled in Scalar::cast"),
29+
dtype.nullability(),
30+
))
3131
}
3232

3333
pub fn invert(self) -> BoolScalar<'a> {

vortex-scalar/src/extension.rs

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,57 @@
11
use std::sync::Arc;
22

33
use vortex_dtype::{DType, ExtDType};
4-
use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexResult};
4+
use vortex_error::{vortex_bail, VortexError, VortexResult};
55

66
use crate::value::ScalarValue;
77
use crate::Scalar;
88

99
pub struct ExtScalar<'a> {
10-
dtype: &'a DType,
10+
ext_dtype: &'a ExtDType,
1111
value: &'a ScalarValue,
1212
}
1313

1414
impl<'a> ExtScalar<'a> {
1515
pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
16-
if !matches!(dtype, DType::Extension(..)) {
16+
let DType::Extension(ext_dtype) = dtype else {
1717
vortex_bail!("Expected extension scalar, found {}", dtype)
18-
}
19-
20-
Ok(Self { dtype, value })
21-
}
18+
};
2219

23-
#[inline]
24-
pub fn dtype(&self) -> &'a DType {
25-
self.dtype
20+
Ok(Self { ext_dtype, value })
2621
}
2722

2823
/// Returns the storage scalar of the extension scalar.
2924
pub fn storage(&self) -> Scalar {
30-
let storage_dtype = if let DType::Extension(ext_dtype) = self.dtype() {
31-
ext_dtype.storage_dtype().clone()
32-
} else {
33-
vortex_panic!("Expected extension DType: {}", self.dtype());
34-
};
35-
Scalar::new(storage_dtype, self.value.clone())
25+
Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone())
3626
}
3727

38-
pub fn cast(&self, _dtype: &DType) -> VortexResult<Scalar> {
39-
todo!()
28+
pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
29+
if self.value.is_null() && !dtype.is_nullable() {
30+
vortex_bail!(
31+
"cannot cast extension dtype with id {} and storage type {} to {}",
32+
self.ext_dtype.id(),
33+
self.ext_dtype.storage_dtype(),
34+
dtype
35+
);
36+
}
37+
38+
if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) {
39+
// Casting from an extension type to the underlying storage type is OK.
40+
return Ok(Scalar::new(dtype.clone(), self.value.clone()));
41+
}
42+
43+
if let DType::Extension(ext_dtype) = dtype {
44+
if self.ext_dtype.eq_ignore_nullability(ext_dtype) {
45+
return Ok(Scalar::new(dtype.clone(), self.value.clone()));
46+
}
47+
}
48+
49+
vortex_bail!(
50+
"cannot cast extension dtype with id {} and storage type {} to {}",
51+
self.ext_dtype.id(),
52+
self.ext_dtype.storage_dtype(),
53+
dtype
54+
);
4055
}
4156
}
4257

0 commit comments

Comments
 (0)