Skip to content

Commit 51401ca

Browse files
authored
Add MaskFn (#5606)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent a09481c commit 51401ca

File tree

36 files changed

+324
-150
lines changed

36 files changed

+324
-150
lines changed

vortex-array/src/arrays/masked/vtable/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ mod operations;
77
mod validity;
88

99
use vortex_buffer::BufferHandle;
10-
use vortex_compute::mask::MaskValidity;
1110
use vortex_dtype::DType;
1211
use vortex_error::VortexResult;
1312
use vortex_error::vortex_bail;
1413
use vortex_vector::Vector;
14+
use vortex_vector::VectorOps;
1515

1616
use crate::ArrayBufferVisitor;
1717
use crate::ArrayChildVisitor;
@@ -106,8 +106,9 @@ impl VTable for MaskedVTable {
106106
}
107107

108108
fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
109-
let vector = array.child().batch_execute(ctx)?;
110-
Ok(MaskValidity::mask_validity(vector, &array.validity_mask()))
109+
let mut vector = array.child().batch_execute(ctx)?;
110+
vector.mask_validity(&array.validity_mask());
111+
Ok(vector)
111112
}
112113
}
113114

vortex-array/src/expr/exprs/get_item/mod.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use std::fmt::Formatter;
77
use std::ops::Not;
88

99
use prost::Message;
10-
use vortex_compute::mask::MaskValidity;
1110
use vortex_dtype::DType;
1211
use vortex_dtype::FieldName;
1312
use vortex_dtype::FieldPath;
@@ -145,8 +144,8 @@ impl VTable for GetItem {
145144
.into_struct();
146145

147146
// We must intersect the validity with that of the parent struct
148-
let field = struct_vector.fields()[field_idx].clone();
149-
let field = MaskValidity::mask_validity(field, struct_vector.validity());
147+
let mut field = struct_vector.fields()[field_idx].clone();
148+
field.mask_validity(struct_vector.validity());
150149

151150
Ok(field)
152151
}

vortex-array/src/expr/functions/vtable.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,15 @@ pub trait VTable: 'static + Send + Sync + Sized {
110110
/// The arity (number of arguments) of a function.
111111
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
112112
pub enum Arity {
113-
Fixed(usize),
113+
Exact(usize),
114114
Variadic { min: usize, max: Option<usize> },
115115
}
116116

117117
impl Arity {
118118
/// Whether the given argument count matches this arity.
119119
pub fn matches(&self, arg_count: usize) -> bool {
120120
match self {
121-
Arity::Fixed(m) => *m == arg_count,
121+
Arity::Exact(m) => *m == arg_count,
122122
Arity::Variadic { min, max } => {
123123
if arg_count < *min {
124124
return false;

vortex-array/src/scalar_fns/cast/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ mod test {
4141
use crate::arrays::ConstantArray;
4242
use crate::arrays::ConstantVTable;
4343
use crate::optimizer::ArrayOptimizer;
44-
use crate::scalar_fns::BuiltinScalarFns;
44+
use crate::scalar_fns::ArrayBuiltins;
4545

4646
#[test]
4747
fn test_same_dtype() -> VortexResult<()> {

vortex-array/src/scalar_fns/cast/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::expr::functions::FunctionId;
2121
use crate::expr::functions::NullHandling;
2222
use crate::expr::functions::VTable;
2323
use crate::expr::stats::Stat;
24-
use crate::scalar_fns::BuiltinScalarFns;
24+
use crate::scalar_fns::ExprBuiltins;
2525

2626
pub struct CastFn;
2727
impl VTable for CastFn {
@@ -49,7 +49,7 @@ impl VTable for CastFn {
4949
}
5050

5151
fn arity(&self, _options: &DType) -> Arity {
52-
Arity::Fixed(1)
52+
Arity::Exact(1)
5353
}
5454

5555
fn null_handling(&self, _options: &DType) -> NullHandling {

vortex-array/src/scalar_fns/is_null/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ impl VTable for IsNullFn {
3030
}
3131

3232
fn arity(&self, _: &Self::Options) -> Arity {
33-
Arity::Fixed(1)
33+
Arity::Exact(1)
3434
}
3535

3636
fn arg_name(&self, _: &Self::Options, arg_idx: usize) -> ArgName {
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::DType;
5+
use vortex_dtype::Nullability;
6+
use vortex_error::VortexExpect;
7+
use vortex_error::VortexResult;
8+
use vortex_error::vortex_ensure;
9+
use vortex_mask::Mask;
10+
use vortex_vector::BoolDatum;
11+
use vortex_vector::Datum;
12+
use vortex_vector::ScalarOps;
13+
use vortex_vector::VectorMutOps;
14+
use vortex_vector::VectorOps;
15+
16+
use crate::expr::functions::ArgName;
17+
use crate::expr::functions::Arity;
18+
use crate::expr::functions::EmptyOptions;
19+
use crate::expr::functions::ExecutionArgs;
20+
use crate::expr::functions::FunctionId;
21+
use crate::expr::functions::VTable;
22+
23+
/// A function that intersects the validity of an array using another array as a mask.
24+
///
25+
/// Where the `mask` array is true, the corresponding v
26+
pub struct MaskFn;
27+
impl VTable for MaskFn {
28+
type Options = EmptyOptions;
29+
30+
fn id(&self) -> FunctionId {
31+
FunctionId::from("vortex.mask")
32+
}
33+
34+
fn arity(&self, _options: &Self::Options) -> Arity {
35+
Arity::Exact(2)
36+
}
37+
38+
fn arg_name(&self, _options: &Self::Options, arg_idx: usize) -> ArgName {
39+
match arg_idx {
40+
0 => ArgName::from("input"),
41+
1 => ArgName::from("mask"),
42+
_ => unreachable!("unknown"),
43+
}
44+
}
45+
46+
fn return_dtype(&self, _options: &Self::Options, arg_types: &[DType]) -> VortexResult<DType> {
47+
vortex_ensure!(
48+
arg_types[1] == DType::Bool(Nullability::NonNullable),
49+
"The mask argument to 'mask' must be a non-nullable boolean array, got {}",
50+
arg_types[1]
51+
);
52+
Ok(arg_types[0].as_nullable())
53+
}
54+
55+
fn execute(&self, _options: &Self::Options, args: &ExecutionArgs) -> VortexResult<Datum> {
56+
let input = args.input_datums(0).clone();
57+
let mask = args.input_datums(1).clone().into_bool();
58+
match (input, mask) {
59+
(Datum::Scalar(input), BoolDatum::Scalar(mask)) => {
60+
let mut result = input;
61+
result.mask_validity(mask.value().vortex_expect("mask is non-nullable"));
62+
Ok(Datum::Scalar(result))
63+
}
64+
(Datum::Scalar(input), BoolDatum::Vector(mask)) => {
65+
let mut result = input.repeat(args.row_count()).freeze();
66+
result.mask_validity(&Mask::from(mask.into_bits()));
67+
Ok(Datum::Vector(result))
68+
}
69+
(Datum::Vector(input_array), BoolDatum::Scalar(mask)) => {
70+
let mut result = input_array;
71+
result.mask_validity(&Mask::new(
72+
args.row_count(),
73+
mask.value().vortex_expect("mask is non-nullable"),
74+
));
75+
Ok(Datum::Vector(result))
76+
}
77+
(Datum::Vector(input_array), BoolDatum::Vector(mask)) => {
78+
let mut result = input_array;
79+
result.mask_validity(&Mask::from(mask.into_bits()));
80+
Ok(Datum::Vector(result))
81+
}
82+
}
83+
}
84+
}

vortex-array/src/scalar_fns/mod.rs

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,78 @@ use crate::ArrayRef;
1717
use crate::arrays::ScalarFnArrayExt;
1818
use crate::expr::Expression;
1919
use crate::expr::ScalarFnExprExt;
20+
use crate::expr::functions::EmptyOptions;
2021

2122
pub mod cast;
2223
pub mod is_null;
24+
pub mod mask;
2325
pub mod not;
2426

2527
/// A collection of built-in scalar functions that can be applied to expressions or arrays.
26-
pub trait BuiltinScalarFns: Sized {
28+
pub trait ExprBuiltins: Sized {
2729
/// Cast to the given data type.
28-
fn cast(&self, dtype: DType) -> VortexResult<Self>;
30+
fn cast(&self, dtype: DType) -> VortexResult<Expression>;
31+
32+
/// Is null check.
33+
fn is_null(&self) -> VortexResult<Expression>;
34+
35+
/// Boolean negation.
36+
fn not(&self) -> VortexResult<Expression>;
37+
38+
/// Mask the expression using the given boolean mask.
39+
/// The resulting expression's validity is the intersection of the original expression's
40+
/// validity.
41+
fn mask(&self, mask: Expression) -> VortexResult<Expression>;
2942
}
3043

31-
impl BuiltinScalarFns for Expression {
44+
impl ExprBuiltins for Expression {
3245
fn cast(&self, dtype: DType) -> VortexResult<Expression> {
3346
cast::CastFn.try_new_expr(dtype, [self.clone()])
3447
}
48+
49+
fn is_null(&self) -> VortexResult<Expression> {
50+
is_null::IsNullFn.try_new_expr(EmptyOptions, [self.clone()])
51+
}
52+
53+
fn not(&self) -> VortexResult<Expression> {
54+
not::NotFn.try_new_expr(EmptyOptions, [self.clone()])
55+
}
56+
57+
fn mask(&self, mask: Expression) -> VortexResult<Expression> {
58+
mask::MaskFn.try_new_expr(EmptyOptions, [self.clone(), mask])
59+
}
3560
}
3661

37-
impl BuiltinScalarFns for ArrayRef {
38-
fn cast(&self, dtype: DType) -> VortexResult<Self> {
62+
pub trait ArrayBuiltins: Sized {
63+
/// Cast to the given data type.
64+
fn cast(&self, dtype: DType) -> VortexResult<ArrayRef>;
65+
66+
/// Is null check.
67+
fn is_null(&self) -> VortexResult<ArrayRef>;
68+
69+
/// Boolean negation.
70+
fn not(&self) -> VortexResult<ArrayRef>;
71+
72+
/// Mask the array using the given boolean mask.
73+
/// The resulting array's validity is the intersection of the original array's validity
74+
/// and the mask's validity.
75+
fn mask(&self, mask: &ArrayRef) -> VortexResult<ArrayRef>;
76+
}
77+
78+
impl ArrayBuiltins for ArrayRef {
79+
fn cast(&self, dtype: DType) -> VortexResult<ArrayRef> {
3980
cast::CastFn.try_new_array(self.len(), dtype, [self.clone()])
4081
}
82+
83+
fn is_null(&self) -> VortexResult<ArrayRef> {
84+
is_null::IsNullFn.try_new_array(self.len(), EmptyOptions, [self.clone()])
85+
}
86+
87+
fn not(&self) -> VortexResult<ArrayRef> {
88+
not::NotFn.try_new_array(self.len(), EmptyOptions, [self.clone()])
89+
}
90+
91+
fn mask(&self, mask: &ArrayRef) -> VortexResult<ArrayRef> {
92+
mask::MaskFn.try_new_array(self.len(), EmptyOptions, [self.clone(), mask.clone()])
93+
}
4194
}

vortex-array/src/scalar_fns/not/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl VTable for NotFn {
2828
}
2929

3030
fn arity(&self, _: &Self::Options) -> Arity {
31-
Arity::Fixed(1)
31+
Arity::Exact(1)
3232
}
3333

3434
fn null_handling(&self, _options: &Self::Options) -> NullHandling {

vortex-compute/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,4 @@ pub mod comparison;
1515
pub mod expand;
1616
pub mod filter;
1717
pub mod logical;
18-
pub mod mask;
1918
pub mod take;

0 commit comments

Comments
 (0)