Skip to content

Commit 9ef1107

Browse files
authored
Refactor Vortex Mask (#2101)
`Mask` now explicitly exposes three variants: `AllTrue`, `AllFalse`, and `Values` (mixed). It replaces `LogicalValidity` as the return type of `logical_validity`.
1 parent e67f256 commit 9ef1107

File tree

65 files changed

+788
-1022
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

65 files changed

+788
-1022
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encodings/alp/src/alp/array.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use vortex_array::encoding::ids;
66
use vortex_array::patches::{Patches, PatchesMetadata};
77
use vortex_array::stats::StatisticsVTable;
88
use vortex_array::validate::ValidateVTable;
9-
use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable};
9+
use vortex_array::validity::{ArrayValidity, ValidityVTable};
1010
use vortex_array::variants::{PrimitiveArrayTrait, VariantsVTable};
1111
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
1212
use vortex_array::{
@@ -15,6 +15,7 @@ use vortex_array::{
1515
};
1616
use vortex_dtype::{DType, PType};
1717
use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult};
18+
use vortex_mask::Mask;
1819

1920
use crate::alp::{alp_encode, decompress, Exponents};
2021

@@ -124,7 +125,7 @@ impl ValidityVTable<ALPArray> for ALPEncoding {
124125
array.encoded().is_valid(index)
125126
}
126127

127-
fn logical_validity(&self, array: &ALPArray) -> VortexResult<LogicalValidity> {
128+
fn logical_validity(&self, array: &ALPArray) -> VortexResult<Mask> {
128129
array.encoded().logical_validity()
129130
}
130131
}

encodings/alp/src/alp_rd/array.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ use vortex_array::encoding::ids;
66
use vortex_array::patches::{Patches, PatchesMetadata};
77
use vortex_array::stats::{StatisticsVTable, StatsSet};
88
use vortex_array::validate::ValidateVTable;
9-
use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable};
9+
use vortex_array::validity::{ArrayValidity, Validity, ValidityVTable};
1010
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
1111
use vortex_array::{
1212
impl_encoding, ArrayDType, ArrayData, ArrayLen, Canonical, IntoCanonical, SerdeMetadata,
1313
};
1414
use vortex_dtype::{DType, Nullability, PType};
1515
use vortex_error::{vortex_bail, VortexExpect, VortexResult};
16+
use vortex_mask::Mask;
1617

1718
use crate::alp_rd::alp_rd_decode;
1819

@@ -210,8 +211,7 @@ impl IntoCanonical for ALPRDArray {
210211
right_parts.into_buffer_mut::<u32>(),
211212
self.left_parts_patches(),
212213
)?,
213-
self.logical_validity()?
214-
.into_validity(self.dtype().nullability()),
214+
Validity::from_mask(self.logical_validity()?, self.dtype().nullability()),
215215
)
216216
} else {
217217
PrimitiveArray::new(
@@ -222,8 +222,7 @@ impl IntoCanonical for ALPRDArray {
222222
right_parts.into_buffer_mut::<u64>(),
223223
self.left_parts_patches(),
224224
)?,
225-
self.logical_validity()?
226-
.into_validity(self.dtype().nullability()),
225+
Validity::from_mask(self.logical_validity()?, self.dtype().nullability()),
227226
)
228227
};
229228

@@ -237,7 +236,7 @@ impl ValidityVTable<ALPRDArray> for ALPRDEncoding {
237236
array.left_parts().is_valid(index)
238237
}
239238

240-
fn logical_validity(&self, array: &ALPRDArray) -> VortexResult<LogicalValidity> {
239+
fn logical_validity(&self, array: &ALPRDArray) -> VortexResult<Mask> {
241240
// Use validity from left_parts
242241
array.left_parts().logical_validity()
243242
}

encodings/bytebool/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ vortex-array = { workspace = true }
2424
vortex-buffer = { workspace = true }
2525
vortex-dtype = { workspace = true }
2626
vortex-error = { workspace = true }
27+
vortex-mask = { workspace = true }
2728
vortex-scalar = { workspace = true }
2829

2930
[dev-dependencies]

encodings/bytebool/src/array.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ use vortex_array::array::BoolArray;
66
use vortex_array::encoding::ids;
77
use vortex_array::stats::StatsSet;
88
use vortex_array::validate::ValidateVTable;
9-
use vortex_array::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable};
9+
use vortex_array::validity::{Validity, ValidityMetadata, ValidityVTable};
1010
use vortex_array::variants::{BoolArrayTrait, VariantsVTable};
1111
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
1212
use vortex_array::{impl_encoding, ArrayLen, Canonical, IntoCanonical, SerdeMetadata};
1313
use vortex_buffer::ByteBuffer;
1414
use vortex_dtype::DType;
1515
use vortex_error::{VortexExpect as _, VortexResult};
16+
use vortex_mask::Mask;
1617

1718
impl_encoding!(
1819
"vortex.bytebool",
@@ -116,7 +117,7 @@ impl ValidityVTable<ByteBoolArray> for ByteBoolEncoding {
116117
array.validity().is_valid(index)
117118
}
118119

119-
fn logical_validity(&self, array: &ByteBoolArray) -> VortexResult<LogicalValidity> {
120+
fn logical_validity(&self, array: &ByteBoolArray) -> VortexResult<Mask> {
120121
array.validity().to_logical(array.len())
121122
}
122123
}

encodings/bytebool/src/compute.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use num_traits::AsPrimitive;
22
use vortex_array::compute::{ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn};
3-
use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity};
3+
use vortex_array::validity::{ArrayValidity, Validity};
44
use vortex_array::variants::PrimitiveArrayTrait;
55
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData};
66
use vortex_dtype::{match_each_integer_ptype, Nullability};
77
use vortex_error::{vortex_err, VortexResult};
8+
use vortex_mask::Mask;
89
use vortex_scalar::Scalar;
910

1011
use super::{ByteBoolArray, ByteBoolEncoding};
@@ -55,7 +56,7 @@ impl TakeFn<ByteBoolArray> for ByteBoolEncoding {
5556
// FIXME(ngates): we should be operating over canonical validity, which doesn't
5657
// have fallible is_valid function.
5758
let arr = match validity {
58-
LogicalValidity::AllValid(_) => {
59+
Mask::AllTrue(_) => {
5960
let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
6061
indices.as_slice::<$I>()
6162
.iter()
@@ -68,16 +69,14 @@ impl TakeFn<ByteBoolArray> for ByteBoolEncoding {
6869

6970
ByteBoolArray::from(bools).into_array()
7071
}
71-
LogicalValidity::AllInvalid(_) => {
72-
ByteBoolArray::from(vec![None; indices.len()]).into_array()
73-
}
74-
LogicalValidity::Mask(mask) => {
72+
Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
73+
Mask::Values(values) => {
7574
let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
7675
indices.as_slice::<$I>()
7776
.iter()
7877
.map(|&idx| {
7978
let idx = idx.as_();
80-
if mask.value(idx) {
79+
if values.value(idx) {
8180
Some(bools[idx])
8281
} else {
8382
None
@@ -101,21 +100,21 @@ impl FillForwardFn<ByteBoolArray> for ByteBoolEncoding {
101100
return Ok(array.to_array());
102101
}
103102
// all valid, but we need to convert to non-nullable
104-
if validity.all_valid() {
103+
if validity.all_true() {
105104
return Ok(
106105
ByteBoolArray::try_new(array.buffer().clone(), Validity::AllValid)?.into_array(),
107106
);
108107
}
109108
// all invalid => fill with default value (false)
110-
if validity.all_invalid() {
109+
if validity.all_false() {
111110
return Ok(
112111
ByteBoolArray::try_from_vec(vec![false; array.len()], Validity::AllValid)?
113112
.into_array(),
114113
);
115114
}
116115

117116
let validity = validity
118-
.to_null_buffer()?
117+
.to_null_buffer()
119118
.ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?;
120119

121120
let bools = array.as_slice();

encodings/datetime-parts/src/array.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ use vortex_array::compute::try_cast;
66
use vortex_array::encoding::ids;
77
use vortex_array::stats::StatsSet;
88
use vortex_array::validate::ValidateVTable;
9-
use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable};
9+
use vortex_array::validity::{ArrayValidity, Validity, ValidityVTable};
1010
use vortex_array::variants::{ExtensionArrayTrait, VariantsVTable};
1111
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
1212
use vortex_array::{impl_encoding, ArrayDType, ArrayData, ArrayLen, IntoArrayData, SerdeMetadata};
1313
use vortex_dtype::{DType, PType};
1414
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult, VortexUnwrap};
15+
use vortex_mask::Mask;
1516

1617
impl_encoding!(
1718
"vortex.datetimeparts",
@@ -100,10 +101,10 @@ impl DateTimePartsArray {
100101

101102
pub fn validity(&self) -> VortexResult<Validity> {
102103
// FIXME(ngates): this function is weird... can we just use logical validity?
103-
Ok(self
104-
.days()
105-
.logical_validity()?
106-
.into_validity(self.dtype().nullability()))
104+
Ok(Validity::from_mask(
105+
self.days().logical_validity()?,
106+
self.dtype().nullability(),
107+
))
107108
}
108109
}
109110

@@ -140,7 +141,7 @@ impl ValidityVTable<DateTimePartsArray> for DateTimePartsEncoding {
140141
array.days().is_valid(index)
141142
}
142143

143-
fn logical_validity(&self, array: &DateTimePartsArray) -> VortexResult<LogicalValidity> {
144+
fn logical_validity(&self, array: &DateTimePartsArray) -> VortexResult<Mask> {
144145
array.days().logical_validity()
145146
}
146147
}

encodings/dict/src/array.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use vortex_array::compute::{scalar_at, take};
66
use vortex_array::encoding::ids;
77
use vortex_array::stats::StatsSet;
88
use vortex_array::validate::ValidateVTable;
9-
use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable};
9+
use vortex_array::validity::{ArrayValidity, ValidityVTable};
1010
use vortex_array::variants::PrimitiveArrayTrait;
1111
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
1212
use vortex_array::{
@@ -90,7 +90,7 @@ impl ValidityVTable<DictArray> for DictEncoding {
9090
array.values().is_valid(values_index)
9191
}
9292

93-
fn logical_validity(&self, array: &DictArray) -> VortexResult<LogicalValidity> {
93+
fn logical_validity(&self, array: &DictArray) -> VortexResult<Mask> {
9494
if array.dtype().is_nullable() {
9595
let primitive_codes = array.codes().into_primitive()?;
9696
match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
@@ -99,10 +99,10 @@ impl ValidityVTable<DictArray> for DictEncoding {
9999
let is_valid_buffer = BooleanBuffer::collect_bool(is_valid.len(), |idx| {
100100
is_valid[idx] != 0
101101
});
102-
Ok(LogicalValidity::Mask(Mask::from_buffer(is_valid_buffer)))
102+
Ok(Mask::from_buffer(is_valid_buffer))
103103
})
104104
} else {
105-
Ok(LogicalValidity::AllValid(array.len()))
105+
Ok(Mask::AllTrue(array.len()))
106106
}
107107
}
108108
}

encodings/fastlanes/src/bitpacking/compress.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,6 @@ mod test {
413413
.unwrap()
414414
.to_null_buffer()
415415
.unwrap()
416-
.unwrap()
417416
.into_inner()
418417
.set_indices()
419418
.collect::<Vec<_>>()

encodings/fastlanes/src/bitpacking/compute/filter.rs

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use vortex_array::variants::PrimitiveArrayTrait;
66
use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant};
77
use vortex_buffer::{Buffer, BufferMut};
88
use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType};
9-
use vortex_error::VortexResult;
10-
use vortex_mask::{Mask, MaskIter};
9+
use vortex_error::{VortexExpect, VortexResult};
10+
use vortex_mask::Mask;
1111

1212
use super::chunked_indices;
1313
use crate::bitpacking::compute::take::UNPACK_CHUNK_THRESHOLD;
@@ -43,17 +43,20 @@ fn filter_primitive<T: NativePType + BitPacking + ArrowNativeType>(
4343
.flatten();
4444

4545
// Short-circuit if the mask density is high enough.
46-
if mask.selectivity() > 0.8 {
46+
if mask.density() > 0.8 {
4747
return filter(array.clone().into_primitive()?.as_ref(), mask)
4848
.and_then(|a| a.into_primitive());
4949
}
5050

51-
let values: Buffer<T> = match mask.iter() {
52-
MaskIter::Indices(indices) => {
53-
filter_indices(array, mask.true_count(), indices.iter().copied())
54-
}
55-
MaskIter::Slices(slices) => filter_slices(array, mask.true_count(), slices.iter().copied()),
56-
};
51+
let values: Buffer<T> = filter_indices(
52+
array,
53+
mask.true_count(),
54+
mask.values()
55+
.vortex_expect("AllTrue and AllFalse handled by filter fn")
56+
.indices()
57+
.iter()
58+
.copied(),
59+
);
5760

5861
let mut values = PrimitiveArray::new(values, validity).reinterpret_cast(array.ptype());
5962
if let Some(patches) = patches {
@@ -111,19 +114,6 @@ fn filter_indices<T: NativePType + BitPacking + ArrowNativeType>(
111114
values.freeze()
112115
}
113116

114-
fn filter_slices<T: NativePType + BitPacking + ArrowNativeType>(
115-
array: &BitPackedArray,
116-
indices_len: usize,
117-
slices: impl Iterator<Item = (usize, usize)>,
118-
) -> Buffer<T> {
119-
// TODO(ngates): do this more efficiently.
120-
filter_indices(
121-
array,
122-
indices_len,
123-
slices.into_iter().flat_map(|(start, end)| start..end),
124-
)
125-
}
126-
127117
#[cfg(test)]
128118
mod test {
129119
use vortex_array::array::PrimitiveArray;

0 commit comments

Comments
 (0)