Skip to content

Commit 383e97a

Browse files
robert3005a10y
andauthored
Add helper function to unpack constant scalar out of array (#1373)
Co-authored-by: Andrew Duffy <[email protected]>
1 parent d2904cb commit 383e97a

File tree

14 files changed

+122
-97
lines changed

14 files changed

+122
-97
lines changed

encodings/alp/src/alp/compute.rs

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
use vortex_array::array::ConstantArray;
2-
use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn};
2+
use vortex_array::compute::unary::{scalar_at_unchecked, ScalarAtFn};
33
use vortex_array::compute::{
44
compare, filter, slice, take, ArrayCompute, FilterFn, FilterMask, MaybeCompareFn, Operator,
55
SliceFn, TakeFn, TakeOptions,
66
};
7-
use vortex_array::stats::{ArrayStatistics, Stat};
87
use vortex_array::variants::PrimitiveArrayTrait;
98
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
109
use vortex_dtype::Nullability;
@@ -14,23 +13,23 @@ use vortex_scalar::{PValue, Scalar};
1413
use crate::{match_each_alp_float_ptype, ALPArray, ALPFloat};
1514

1615
impl ArrayCompute for ALPArray {
17-
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
18-
Some(self)
16+
fn compare(&self, other: &ArrayData, operator: Operator) -> Option<VortexResult<ArrayData>> {
17+
MaybeCompareFn::maybe_compare(self, other, operator)
1918
}
2019

21-
fn slice(&self) -> Option<&dyn SliceFn> {
20+
fn filter(&self) -> Option<&dyn FilterFn> {
2221
Some(self)
2322
}
2423

25-
fn take(&self) -> Option<&dyn TakeFn> {
24+
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
2625
Some(self)
2726
}
2827

29-
fn compare(&self, other: &ArrayData, operator: Operator) -> Option<VortexResult<ArrayData>> {
30-
MaybeCompareFn::maybe_compare(self, other, operator)
28+
fn slice(&self) -> Option<&dyn SliceFn> {
29+
Some(self)
3130
}
3231

33-
fn filter(&self) -> Option<&dyn FilterFn> {
32+
fn take(&self) -> Option<&dyn TakeFn> {
3433
Some(self)
3534
}
3635
}
@@ -102,14 +101,8 @@ impl MaybeCompareFn for ALPArray {
102101
array: &ArrayData,
103102
operator: Operator,
104103
) -> Option<VortexResult<ArrayData>> {
105-
if ConstantArray::try_from(array).is_ok()
106-
|| array
107-
.statistics()
108-
.get_as::<bool>(Stat::IsConstant)
109-
.unwrap_or_default()
110-
{
111-
let rhs = scalar_at(array, 0).vortex_expect("should be scalar");
112-
let pvalue = rhs
104+
if let Some(const_scalar) = array.as_constant() {
105+
let pvalue = const_scalar
113106
.value()
114107
.as_pvalue()
115108
.vortex_expect("Expected primitive value");

encodings/dict/src/compute.rs

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
use vortex_array::array::ConstantArray;
12
use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn};
23
use vortex_array::compute::{
34
compare, filter, slice, take, ArrayCompute, FilterFn, FilterMask, MaybeCompareFn, Operator,
45
SliceFn, TakeFn, TakeOptions,
56
};
6-
use vortex_array::stats::{ArrayStatistics, Stat};
77
use vortex_array::{ArrayData, IntoArrayData};
88
use vortex_error::{VortexExpect, VortexResult};
99
use vortex_scalar::Scalar;
@@ -39,17 +39,16 @@ impl MaybeCompareFn for DictArray {
3939
operator: Operator,
4040
) -> Option<VortexResult<ArrayData>> {
4141
// If the RHS is constant, then we just need to compare against our encoded values.
42-
if other
43-
.statistics()
44-
.get_as::<bool>(Stat::IsConstant)
45-
.unwrap_or_default()
46-
{
42+
if let Some(const_scalar) = other.as_constant() {
4743
return Some(
4844
// Ensure the other is the same length as the dictionary
49-
slice(other, 0, self.values().len())
50-
.and_then(|other| compare(self.values(), other, operator))
51-
.and_then(|values| Self::try_new(self.codes(), values))
52-
.map(|a| a.into_array()),
45+
compare(
46+
self.values(),
47+
ConstantArray::new(const_scalar, self.values().len()),
48+
operator,
49+
)
50+
.and_then(|values| Self::try_new(self.codes(), values))
51+
.map(|a| a.into_array()),
5352
);
5453
}
5554

@@ -102,14 +101,19 @@ impl SliceFn for DictArray {
102101
#[cfg(test)]
103102
mod test {
104103
use vortex_array::accessor::ArrayAccessor;
105-
use vortex_array::array::{PrimitiveArray, VarBinViewArray};
104+
use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinViewArray};
105+
use vortex_array::compute::unary::scalar_at;
106+
use vortex_array::compute::{compare, slice, Operator};
106107
use vortex_array::{IntoArrayData, IntoArrayVariant, ToArrayData};
107108
use vortex_dtype::{DType, Nullability};
109+
use vortex_scalar::Scalar;
108110

109-
use crate::{dict_encode_typed_primitive, dict_encode_varbinview, DictArray};
111+
use crate::{
112+
dict_encode_primitive, dict_encode_typed_primitive, dict_encode_varbinview, DictArray,
113+
};
110114

111115
#[test]
112-
fn flatten_nullable_primitive() {
116+
fn canonicalise_nullable_primitive() {
113117
let reference = PrimitiveArray::from_nullable_vec(vec![
114118
Some(42),
115119
Some(-9),
@@ -125,7 +129,7 @@ mod test {
125129
}
126130

127131
#[test]
128-
fn flatten_nullable_varbin() {
132+
fn canonicalise_nullable_varbin() {
129133
let reference = VarBinViewArray::from_iter(
130134
vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
131135
DType::Utf8(Nullability::Nullable),
@@ -147,4 +151,32 @@ mod test {
147151
.unwrap(),
148152
);
149153
}
154+
155+
#[test]
156+
fn compare_sliced_dict() {
157+
let reference = PrimitiveArray::from_nullable_vec(vec![
158+
Some(42),
159+
Some(-9),
160+
None,
161+
Some(42),
162+
Some(1),
163+
Some(5),
164+
]);
165+
let (codes, values) = dict_encode_primitive(&reference);
166+
let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();
167+
let sliced = slice(dict, 1, 4).unwrap();
168+
let compared = compare(sliced, ConstantArray::new(42, 3), Operator::Eq).unwrap();
169+
assert_eq!(
170+
scalar_at(&compared, 0).unwrap(),
171+
Scalar::bool(false, Nullability::Nullable)
172+
);
173+
assert_eq!(
174+
scalar_at(&compared, 1).unwrap(),
175+
Scalar::null(DType::Bool(Nullability::Nullable))
176+
);
177+
assert_eq!(
178+
scalar_at(compared, 2).unwrap(),
179+
Scalar::bool(true, Nullability::Nullable)
180+
);
181+
}
150182
}

encodings/fastlanes/src/for/compress.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ mod test {
140140
assert!(array.statistics().to_set().into_iter().next().is_none());
141141

142142
let compressed = for_compress(&array).unwrap();
143-
let constant = ConstantArray::try_from(compressed).unwrap();
144-
assert_eq!(constant.scalar_value(), &ScalarValue::from(0i32));
143+
let constant = compressed.as_constant().unwrap();
144+
assert_eq!(constant.value(), &ScalarValue::from(0i32));
145145
}
146146

147147
#[test]

encodings/fsst/src/compute.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ impl MaybeCompareFn for FSSTArray {
4141
other: &ArrayData,
4242
operator: Operator,
4343
) -> Option<VortexResult<ArrayData>> {
44-
match (ConstantArray::try_from(other), operator) {
45-
(Ok(constant_array), Operator::Eq | Operator::NotEq) => Some(compare_fsst_constant(
44+
match (other.as_constant(), operator) {
45+
(Some(constant_array), Operator::Eq | Operator::NotEq) => Some(compare_fsst_constant(
4646
self,
47-
&constant_array,
47+
&ConstantArray::new(constant_array, self.len()),
4848
operator == Operator::Eq,
4949
)),
5050
_ => None,

encodings/runend/src/compute.rs

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use vortex_array::compute::{
77
compare, filter, slice, take, ArrayCompute, FilterFn, FilterMask, MaybeCompareFn, Operator,
88
SliceFn, TakeFn, TakeOptions,
99
};
10-
use vortex_array::stats::{ArrayStatistics, Stat};
1110
use vortex_array::validity::Validity;
1211
use vortex_array::variants::PrimitiveArrayTrait;
1312
use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant};
@@ -46,22 +45,15 @@ impl MaybeCompareFn for RunEndArray {
4645
operator: Operator,
4746
) -> Option<VortexResult<ArrayData>> {
4847
// If the RHS is constant, then we just need to compare against our encoded values.
49-
if other
50-
.statistics()
51-
.get_as::<bool>(Stat::IsConstant)
52-
.unwrap_or_default()
53-
{
54-
return Some(
55-
slice(other, 0, self.values().len())
56-
.and_then(|other| compare(self.values(), other, operator))
57-
.and_then(|values| {
58-
Self::try_new(self.ends(), values, self.validity().into_nullable())
59-
})
60-
.map(|a| a.into_array()),
61-
);
62-
}
63-
64-
None
48+
other.as_constant().map(|const_scalar| {
49+
compare(
50+
self.values(),
51+
ConstantArray::new(const_scalar, self.values().len()),
52+
operator,
53+
)
54+
.and_then(|values| Self::try_new(self.ends(), values, self.validity().into_nullable()))
55+
.map(|a| a.into_array())
56+
})
6557
}
6658
}
6759

vortex-array/src/array/constant/compute.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::cmp::Ordering;
22

33
use vortex_dtype::Nullability;
4-
use vortex_error::{vortex_bail, VortexExpect, VortexResult};
4+
use vortex_error::{vortex_bail, VortexResult};
55
use vortex_scalar::Scalar;
66

77
use crate::array::constant::ConstantArray;
@@ -98,15 +98,9 @@ impl MaybeCompareFn for ConstantArray {
9898
other: &ArrayData,
9999
operator: Operator,
100100
) -> Option<VortexResult<ArrayData>> {
101-
(ConstantArray::try_from(other).is_ok()
102-
|| other
103-
.statistics()
104-
.get_as::<bool>(Stat::IsConstant)
105-
.unwrap_or_default())
106-
.then(|| {
101+
other.as_constant().map(|const_scalar| {
107102
let lhs = self.owned_scalar();
108-
let rhs = scalar_at(other, 0).vortex_expect("Expected scalar");
109-
let scalar = scalar_cmp(&lhs, &rhs, operator);
103+
let scalar = scalar_cmp(&lhs, &const_scalar, operator);
110104
Ok(ConstantArray::new(scalar, self.len()).into_array())
111105
})
112106
}

vortex-array/src/array/extension/compute.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ impl MaybeCompareFn for ExtensionArray {
4141
other: &ArrayData,
4242
operator: Operator,
4343
) -> Option<VortexResult<ArrayData>> {
44-
if let Ok(const_ext) = ConstantArray::try_from(other) {
45-
let scalar_ext = ExtScalar::try_new(const_ext.dtype(), const_ext.scalar_value())
44+
if let Some(const_ext) = other.as_constant() {
45+
let scalar_ext = ExtScalar::try_new(const_ext.dtype(), const_ext.value())
4646
.vortex_expect("Expected ExtScalar");
4747
let const_storage = ConstantArray::new(
4848
Scalar::new(self.storage().dtype().clone(), scalar_ext.value().clone()),
49-
const_ext.len(),
49+
self.len(),
5050
);
5151

5252
return Some(compare(self.storage(), const_storage, operator));

vortex-array/src/array/varbin/compute/compare.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@ impl MaybeCompareFn for VarBinArray {
1919
other: &ArrayData,
2020
operator: Operator,
2121
) -> Option<VortexResult<ArrayData>> {
22-
if let Ok(rhs_const) = ConstantArray::try_from(other) {
23-
Some(compare_constant(self, &rhs_const, operator))
24-
} else {
25-
None
26-
}
22+
other.as_constant().map(|rhs_const| {
23+
compare_constant(self, &ConstantArray::new(rhs_const, self.len()), operator)
24+
})
2725
}
2826
}
2927

vortex-array/src/array/varbinview/compute.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,9 @@ impl MaybeCompareFn for VarBinViewArray {
131131
other: &ArrayData,
132132
operator: Operator,
133133
) -> Option<VortexResult<ArrayData>> {
134-
if let Ok(rhs_const) = ConstantArray::try_from(other) {
135-
Some(compare_constant(self, &rhs_const, operator))
136-
} else {
137-
None
138-
}
134+
other.as_constant().map(|rhs_const| {
135+
compare_constant(self, &ConstantArray::new(rhs_const, self.len()), operator)
136+
})
139137
}
140138
}
141139

vortex-array/src/compute/compare.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,9 @@ mod tests {
258258
let left = ConstantArray::new(Scalar::from(2u32), 10);
259259
let right = ConstantArray::new(Scalar::from(10u32), 10);
260260

261-
let res = ConstantArray::try_from(compare(left, right, Operator::Gt).unwrap()).unwrap();
262-
assert_eq!(res.scalar_value(), &ScalarValue::Bool(false));
263-
assert_eq!(res.len(), 10);
261+
let compare = compare(left, right, Operator::Gt).unwrap();
262+
let res = compare.as_constant().unwrap();
263+
assert_eq!(res.value(), &ScalarValue::Bool(false));
264+
assert_eq!(compare.len(), 10);
264265
}
265266
}

0 commit comments

Comments
 (0)