Skip to content

Commit eadc1fe

Browse files
authored
feat: mask (#1900)
Mask sets entries of an array to null. I like the analogy to light: the array is a sequence of lights (each value might be a different wavelength). Null is represented by the absence oflight. Placing a mask (i.e. a piece of plastic with slits) over the array causes those values where the mask is present (i.e. "on", "true") to be dark. An example in pseudo-code: ```rust a = [1, 2, 3, 4, 5] a_mask = [t, f, f, t, f] mask(a, a_mask) == [null, 2, 3, null, 5] ``` Specializations --------------- I only fallback to Arrow for two of the core arrays: - Sparse. I was skeptical that I could do better than decompressing and applying it. - Constant. If the mask is sparse, SparseArray might be a good choice. I didn't investigate. For the non-core arrays, I'm missing the following. I'm not clear that I can beat decompression forrun end. The others are easy enough but some amount of typing and testing. - fastlanes - fsst - roaring - runend - runend-bool - zigzag Naming ------ Pandas also calls this operation [`mask`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.mask.html) but accepts an optional second argument which is an array of values to use instead of null (which makes Pandas' mask more like an `if_else`). Arrow-rs calls this [`nullif`](https://arrow.apache.org/rust/arrow/compute/fn.nullif.html). Arrow-cpp has [`if_else(condition, consequent, alternate)`](https://arrow.apache.org/docs/cpp/compute.html#cpp-compute-scalar-selections) and [`replace_with_mask(array, mask, replacements)`](https://arrow.apache.org/docs/cpp/compute.html#replace-functions) both of which can implement our `mask` by passing a `NullArray` as the third argument.
1 parent f420bca commit eadc1fe

File tree

32 files changed

+1220
-84
lines changed

32 files changed

+1220
-84
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use vortex_array::compute::{mask, MaskFn};
2+
use vortex_array::{Array, IntoArray};
3+
use vortex_error::VortexResult;
4+
use vortex_mask::Mask;
5+
6+
use crate::{ALPRDArray, ALPRDEncoding};
7+
8+
impl MaskFn<ALPRDArray> for ALPRDEncoding {
9+
fn mask(&self, array: &ALPRDArray, filter_mask: Mask) -> VortexResult<Array> {
10+
Ok(ALPRDArray::try_new(
11+
array.dtype().as_nullable(),
12+
mask(&array.left_parts(), filter_mask)?,
13+
array.left_parts_dict(),
14+
array.right_parts(),
15+
array.right_bit_width(),
16+
array.left_parts_patches(),
17+
)?
18+
.into_array())
19+
}
20+
}
21+
22+
#[cfg(test)]
23+
mod tests {
24+
use rstest::rstest;
25+
use vortex_array::array::PrimitiveArray;
26+
use vortex_array::compute::test_harness::test_mask;
27+
use vortex_array::IntoArray as _;
28+
29+
use crate::{ALPRDFloat, RDEncoder};
30+
31+
#[rstest]
32+
#[case(0.1f32, 0.2f32, 3e25f32)]
33+
#[case(0.1f64, 0.2f64, 3e100f64)]
34+
fn test_mask_simple<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
35+
test_mask(
36+
RDEncoder::new(&[a, b])
37+
.encode(&PrimitiveArray::from_iter([a, b, outlier, b, outlier]))
38+
.into_array(),
39+
);
40+
}
41+
42+
#[rstest]
43+
#[case(0.1f32, 3e25f32)]
44+
#[case(0.5f64, 1e100f64)]
45+
fn test_mask_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] outlier: T) {
46+
test_mask(
47+
RDEncoder::new(&[a])
48+
.encode(&PrimitiveArray::from_option_iter([
49+
Some(a),
50+
None,
51+
Some(outlier),
52+
Some(a),
53+
None,
54+
]))
55+
.into_array(),
56+
);
57+
}
58+
}

encodings/alp/src/alp_rd/compute/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
use vortex_array::compute::{FilterFn, ScalarAtFn, SliceFn, TakeFn};
1+
use vortex_array::compute::{FilterFn, MaskFn, ScalarAtFn, SliceFn, TakeFn};
22
use vortex_array::vtable::ComputeVTable;
33
use vortex_array::Array;
44

55
use crate::ALPRDEncoding;
66

77
mod filter;
8+
mod mask;
89
mod scalar_at;
910
mod slice;
1011
mod take;
@@ -14,6 +15,10 @@ impl ComputeVTable for ALPRDEncoding {
1415
Some(self)
1516
}
1617

18+
fn mask_fn(&self) -> Option<&dyn MaskFn<Array>> {
19+
Some(self)
20+
}
21+
1722
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<Array>> {
1823
Some(self)
1924
}

encodings/bytebool/src/compute.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use num_traits::AsPrimitive;
2-
use vortex_array::compute::{FillForwardFn, ScalarAtFn, SliceFn, TakeFn};
2+
use vortex_array::compute::{FillForwardFn, MaskFn, ScalarAtFn, SliceFn, TakeFn};
33
use vortex_array::validity::Validity;
44
use vortex_array::variants::PrimitiveArrayTrait;
55
use vortex_array::vtable::ComputeVTable;
@@ -16,6 +16,10 @@ impl ComputeVTable for ByteBoolEncoding {
1616
None
1717
}
1818

19+
fn mask_fn(&self) -> Option<&dyn MaskFn<Array>> {
20+
Some(self)
21+
}
22+
1923
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<Array>> {
2024
Some(self)
2125
}
@@ -29,6 +33,13 @@ impl ComputeVTable for ByteBoolEncoding {
2933
}
3034
}
3135

36+
impl MaskFn<ByteBoolArray> for ByteBoolEncoding {
37+
fn mask(&self, array: &ByteBoolArray, mask: Mask) -> VortexResult<Array> {
38+
ByteBoolArray::try_new(array.buffer().clone(), array.validity().mask(&mask)?)
39+
.map(IntoArray::into_array)
40+
}
41+
}
42+
3243
impl ScalarAtFn<ByteBoolArray> for ByteBoolEncoding {
3344
fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
3445
Ok(Scalar::bool(
@@ -139,6 +150,7 @@ impl FillForwardFn<ByteBoolArray> for ByteBoolEncoding {
139150

140151
#[cfg(test)]
141152
mod tests {
153+
use vortex_array::compute::test_harness::test_mask;
142154
use vortex_array::compute::{compare, scalar_at, slice, Operator};
143155

144156
use super::*;
@@ -211,4 +223,12 @@ mod tests {
211223
let s = scalar_at(&arr, 4).unwrap();
212224
assert!(s.is_null());
213225
}
226+
227+
#[test]
228+
fn test_mask_byte_bool() {
229+
test_mask(ByteBoolArray::from(vec![true, false, true, true, false]).into_array());
230+
test_mask(
231+
ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
232+
);
233+
}
214234
}

encodings/dict/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ name = "dict_compare"
5050
harness = false
5151
required-features = ["test-harness"]
5252

53+
[[bench]]
54+
name = "dict_mask"
55+
harness = false
56+
5357
[[bench]]
5458
name = "chunked_dict_array_builder"
5559
harness = false
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#![allow(clippy::unwrap_used)]
2+
3+
use divan::Bencher;
4+
use rand::rngs::StdRng;
5+
use rand::{Rng, SeedableRng as _};
6+
use vortex_array::array::PrimitiveArray;
7+
use vortex_array::compute::mask;
8+
use vortex_array::IntoArray as _;
9+
use vortex_dict::DictArray;
10+
use vortex_mask::Mask;
11+
12+
fn main() {
13+
divan::main();
14+
}
15+
16+
fn filter_mask(len: usize, fraction_masked: f64, rng: &mut StdRng) -> Mask {
17+
let indices = (0..len)
18+
.filter(|_| rng.gen_bool(fraction_masked))
19+
.collect::<Vec<usize>>();
20+
Mask::from_indices(len, indices)
21+
}
22+
23+
#[divan::bench(args = [
24+
(0.9, 0.9),
25+
(0.9, 0.5),
26+
(0.9, 0.1),
27+
(0.9, 0.01),
28+
(0.5, 0.9),
29+
(0.5, 0.5),
30+
(0.5, 0.1),
31+
(0.5, 0.01),
32+
(0.1, 0.9),
33+
(0.1, 0.5),
34+
(0.1, 0.1),
35+
(0.1, 0.01),
36+
(0.01, 0.9),
37+
(0.01, 0.5),
38+
(0.01, 0.1),
39+
(0.01, 0.01),
40+
])]
41+
fn bench_dict_mask(bencher: Bencher, (fraction_valid, fraction_masked): (f64, f64)) {
42+
let mut rng = StdRng::seed_from_u64(0);
43+
44+
let len = 65_535;
45+
let codes = PrimitiveArray::from_iter((0..len).map(|_| {
46+
if rng.gen_bool(fraction_valid) {
47+
1u64
48+
} else {
49+
0u64
50+
}
51+
}))
52+
.into_array();
53+
let values = PrimitiveArray::from_option_iter([None, Some(42i32)]).into_array();
54+
let array = DictArray::try_new(codes, values).unwrap().into_array();
55+
let filter_mask = filter_mask(len, fraction_masked, &mut rng);
56+
bencher
57+
.with_inputs(|| (&array, filter_mask.clone()))
58+
.bench_values(|(array, filter_mask)| mask(array, filter_mask).unwrap());
59+
}

encodings/dict/src/compute/mod.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ impl SliceFn<DictArray> for DictEncoding {
7979
#[cfg(test)]
8080
mod test {
8181
use vortex_array::accessor::ArrayAccessor;
82-
use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinViewArray};
82+
use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
83+
use vortex_array::compute::test_harness::test_mask;
8384
use vortex_array::compute::{compare, scalar_at, slice, Operator};
8485
use vortex_array::{Array, IntoArray, IntoArrayVariant};
8586
use vortex_dtype::{DType, Nullability};
@@ -198,4 +199,37 @@ mod test {
198199
Scalar::bool(true, Nullability::Nullable)
199200
);
200201
}
202+
203+
#[test]
204+
fn test_mask_dict_array() {
205+
let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array())
206+
.unwrap()
207+
.into_array();
208+
test_mask(array);
209+
210+
let array = dict_encode(
211+
&PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
212+
.into_array(),
213+
)
214+
.unwrap()
215+
.into_array();
216+
test_mask(array);
217+
218+
let array = dict_encode(
219+
&VarBinArray::from_iter(
220+
[
221+
Some("hello"),
222+
None,
223+
Some("hello"),
224+
Some("good"),
225+
Some("good"),
226+
],
227+
DType::Utf8(Nullability::Nullable),
228+
)
229+
.into_array(),
230+
)
231+
.unwrap()
232+
.into_array();
233+
test_mask(array);
234+
}
201235
}

encodings/sparse/src/compute/mod.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,14 @@ impl FilterFn<SparseArray> for SparseEncoding {
104104
mod test {
105105
use rstest::{fixture, rstest};
106106
use vortex_array::array::PrimitiveArray;
107-
use vortex_array::compute::test_harness::test_binary_numeric;
108-
use vortex_array::compute::{filter, search_sorted, slice, SearchResult, SearchSortedSide};
107+
use vortex_array::compute::test_harness::{test_binary_numeric, test_mask};
108+
use vortex_array::compute::{
109+
filter, search_sorted, slice, try_cast, SearchResult, SearchSortedSide,
110+
};
109111
use vortex_array::validity::Validity;
110112
use vortex_array::{Array, IntoArray, IntoArrayVariant};
111113
use vortex_buffer::buffer;
114+
use vortex_dtype::{DType, Nullability, PType};
112115
use vortex_mask::Mask;
113116
use vortex_scalar::Scalar;
114117

@@ -223,4 +226,35 @@ mod test {
223226
fn test_sparse_binary_numeric(array: Array) {
224227
test_binary_numeric::<i32>(array)
225228
}
229+
230+
#[test]
231+
fn test_mask_sparse_array() {
232+
let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
233+
test_mask(
234+
SparseArray::try_new(
235+
buffer![1u64, 2, 4].into_array(),
236+
try_cast(
237+
buffer![100i32, 200, 300].into_array(),
238+
null_fill_value.dtype(),
239+
)
240+
.unwrap(),
241+
5,
242+
null_fill_value,
243+
)
244+
.unwrap()
245+
.into_array(),
246+
);
247+
248+
let ten_fill_value = Scalar::from(10i32);
249+
test_mask(
250+
SparseArray::try_new(
251+
buffer![1u64, 2, 4].into_array(),
252+
buffer![100i32, 200, 300].into_array(),
253+
5,
254+
ten_fill_value,
255+
)
256+
.unwrap()
257+
.into_array(),
258+
)
259+
}
226260
}

vortex-array/src/array/bool/compute/cast.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,9 @@ impl CastFn<BoolArray> for BoolEncoding {
1111
vortex_bail!("Cannot cast {} to {}", array.dtype(), dtype);
1212
}
1313

14-
// If the types are the same, return the array,
15-
// otherwise set the array nullability as the dtype nullability.
16-
if dtype.is_nullable() || array.all_valid()? {
17-
Ok(BoolArray::new(array.boolean_buffer(), dtype.nullability()).into_array())
18-
} else {
19-
vortex_bail!("Cannot cast null array to non-nullable type");
20-
}
14+
let new_nullability = dtype.nullability();
15+
let new_validity = array.validity().cast_nullability(new_nullability)?;
16+
BoolArray::try_new(array.boolean_buffer(), new_validity).map(IntoArray::into_array)
2117
}
2218
}
2319

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use vortex_error::VortexResult;
2+
use vortex_mask::Mask;
3+
4+
use crate::array::{BoolArray, BoolEncoding};
5+
use crate::compute::MaskFn;
6+
use crate::{Array, IntoArray};
7+
8+
impl MaskFn<BoolArray> for BoolEncoding {
9+
fn mask(&self, array: &BoolArray, mask: Mask) -> VortexResult<Array> {
10+
BoolArray::try_new(array.boolean_buffer(), array.validity().mask(&mask)?)
11+
.map(IntoArray::into_array)
12+
}
13+
}

vortex-array/src/array/bool/compute/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::array::BoolEncoding;
22
use crate::compute::{
3-
BinaryBooleanFn, CastFn, FillForwardFn, FillNullFn, FilterFn, InvertFn, MinMaxFn, ScalarAtFn,
4-
SliceFn, TakeFn, ToArrowFn,
3+
BinaryBooleanFn, CastFn, FillForwardFn, FillNullFn, FilterFn, InvertFn, MaskFn, MinMaxFn,
4+
ScalarAtFn, SliceFn, TakeFn, ToArrowFn,
55
};
66
use crate::vtable::ComputeVTable;
77
use crate::Array;
@@ -12,6 +12,7 @@ mod fill_null;
1212
pub mod filter;
1313
mod flatten;
1414
mod invert;
15+
mod mask;
1516
mod min_max;
1617
mod scalar_at;
1718
mod slice;
@@ -43,6 +44,10 @@ impl ComputeVTable for BoolEncoding {
4344
Some(self)
4445
}
4546

47+
fn mask_fn(&self) -> Option<&dyn MaskFn<Array>> {
48+
Some(self)
49+
}
50+
4651
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<Array>> {
4752
Some(self)
4853
}

0 commit comments

Comments
 (0)