Skip to content

Commit dcbdb4e

Browse files
committed
MaskFn
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 4d21588 commit dcbdb4e

File tree

6 files changed

+69
-17
lines changed

6 files changed

+69
-17
lines changed

vortex-array/src/arrays/masked/vtable/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ mod operations;
77
mod validity;
88

99
use vortex_buffer::BufferHandle;
10-
use vortex_compute::mask::MaskValidity;
1110
use vortex_dtype::DType;
1211
use vortex_error::VortexResult;
1312
use vortex_error::vortex_bail;
1413
use vortex_vector::Vector;
14+
use vortex_vector::VectorOps;
1515

1616
use crate::ArrayBufferVisitor;
1717
use crate::ArrayChildVisitor;
@@ -106,8 +106,9 @@ impl VTable for MaskedVTable {
106106
}
107107

108108
fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
109-
let vector = array.child().batch_execute(ctx)?;
110-
Ok(MaskValidity::mask_validity(vector, &array.validity_mask()))
109+
let mut vector = array.child().batch_execute(ctx)?;
110+
vector.mask_validity(&array.validity_mask());
111+
Ok(vector)
111112
}
112113
}
113114

vortex-array/src/expr/exprs/get_item/mod.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use std::fmt::Formatter;
77
use std::ops::Not;
88

99
use prost::Message;
10-
use vortex_compute::mask::MaskValidity;
1110
use vortex_dtype::DType;
1211
use vortex_dtype::FieldName;
1312
use vortex_dtype::FieldPath;
@@ -145,8 +144,8 @@ impl VTable for GetItem {
145144
.into_struct();
146145

147146
// We must intersect the validity with that of the parent struct
148-
let field = struct_vector.fields()[field_idx].clone();
149-
let field = MaskValidity::mask_validity(field, struct_vector.validity());
147+
let mut field = struct_vector.fields()[field_idx].clone();
148+
field.mask_validity(struct_vector.validity());
150149

151150
Ok(field)
152151
}

vortex-array/src/scalar_fns/cast/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ mod test {
4141
use crate::arrays::ConstantArray;
4242
use crate::arrays::ConstantVTable;
4343
use crate::optimizer::ArrayOptimizer;
44-
use crate::scalar_fns::BuiltinScalarFns;
44+
use crate::scalar_fns::ArrayBuiltins;
4545

4646
#[test]
4747
fn test_same_dtype() -> VortexResult<()> {

vortex-array/src/scalar_fns/cast/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::expr::functions::FunctionId;
2121
use crate::expr::functions::NullHandling;
2222
use crate::expr::functions::VTable;
2323
use crate::expr::stats::Stat;
24-
use crate::scalar_fns::BuiltinScalarFns;
24+
use crate::scalar_fns::ExprBuiltins;
2525

2626
pub struct CastFn;
2727
impl VTable for CastFn {
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl VTable for MaskFn {
6060
let mask = args.input_datums(1).clone().into_bool();
6161
match (input, mask) {
6262
(Datum::Scalar(input), BoolDatum::Scalar(mask)) => {
63-
let mut result = input.clone();
63+
let mut result = input;
6464
result.mask_validity(mask.value().vortex_expect("mask is non-nullable"));
6565
Ok(Datum::Scalar(result))
6666
}
@@ -70,15 +70,15 @@ impl VTable for MaskFn {
7070
Ok(Datum::Vector(result))
7171
}
7272
(Datum::Vector(input_array), BoolDatum::Scalar(mask)) => {
73-
let mut result = input_array.clone();
73+
let mut result = input_array;
7474
result.mask_validity(&Mask::new(
7575
args.row_count(),
7676
mask.value().vortex_expect("mask is non-nullable"),
7777
));
7878
Ok(Datum::Vector(result))
7979
}
8080
(Datum::Vector(input_array), BoolDatum::Vector(mask)) => {
81-
let mut result = input_array.clone();
81+
let mut result = input_array;
8282
result.mask_validity(&Mask::from(mask.into_bits()));
8383
Ok(Datum::Vector(result))
8484
}

vortex-array/src/scalar_fns/mod.rs

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,78 @@ use crate::ArrayRef;
1717
use crate::arrays::ScalarFnArrayExt;
1818
use crate::expr::Expression;
1919
use crate::expr::ScalarFnExprExt;
20+
use crate::expr::functions::EmptyOptions;
2021

2122
pub mod cast;
2223
pub mod is_null;
23-
mod mask;
24+
pub mod mask;
2425
pub mod not;
2526

2627
/// A collection of built-in scalar functions that can be applied to expressions or arrays.
27-
pub trait BuiltinScalarFns: Sized {
28+
pub trait ExprBuiltins: Sized {
2829
/// Cast to the given data type.
29-
fn cast(&self, dtype: DType) -> VortexResult<Self>;
30+
fn cast(&self, dtype: DType) -> VortexResult<Expression>;
31+
32+
/// Is null check.
33+
fn is_null(&self) -> VortexResult<Expression>;
34+
35+
/// Boolean negation.
36+
fn not(&self) -> VortexResult<Expression>;
37+
38+
/// Mask the expression using the given boolean mask.
39+
/// The resulting expression's validity is the intersection of the original expression's
40+
/// validity.
41+
fn mask(&self, mask: Expression) -> VortexResult<Expression>;
3042
}
3143

32-
impl BuiltinScalarFns for Expression {
44+
impl ExprBuiltins for Expression {
3345
fn cast(&self, dtype: DType) -> VortexResult<Expression> {
3446
cast::CastFn.try_new_expr(dtype, [self.clone()])
3547
}
48+
49+
fn is_null(&self) -> VortexResult<Expression> {
50+
is_null::IsNullFn.try_new_expr(EmptyOptions, [self.clone()])
51+
}
52+
53+
fn not(&self) -> VortexResult<Expression> {
54+
not::NotFn.try_new_expr(EmptyOptions, [self.clone()])
55+
}
56+
57+
fn mask(&self, mask: Expression) -> VortexResult<Expression> {
58+
mask::MaskFn.try_new_expr(EmptyOptions, [self.clone(), mask])
59+
}
3660
}
3761

38-
impl BuiltinScalarFns for ArrayRef {
39-
fn cast(&self, dtype: DType) -> VortexResult<Self> {
62+
pub trait ArrayBuiltins: Sized {
63+
/// Cast to the given data type.
64+
fn cast(&self, dtype: DType) -> VortexResult<ArrayRef>;
65+
66+
/// Is null check.
67+
fn is_null(&self) -> VortexResult<ArrayRef>;
68+
69+
/// Boolean negation.
70+
fn not(&self) -> VortexResult<ArrayRef>;
71+
72+
/// Mask the array using the given boolean mask.
73+
/// The resulting array's validity is the intersection of the original array's validity
74+
/// and the mask's validity.
75+
fn mask(&self, mask: &ArrayRef) -> VortexResult<ArrayRef>;
76+
}
77+
78+
impl ArrayBuiltins for ArrayRef {
79+
fn cast(&self, dtype: DType) -> VortexResult<ArrayRef> {
4080
cast::CastFn.try_new_array(self.len(), dtype, [self.clone()])
4181
}
82+
83+
fn is_null(&self) -> VortexResult<ArrayRef> {
84+
is_null::IsNullFn.try_new_array(self.len(), EmptyOptions, [self.clone()])
85+
}
86+
87+
fn not(&self) -> VortexResult<ArrayRef> {
88+
not::NotFn.try_new_array(self.len(), EmptyOptions, [self.clone()])
89+
}
90+
91+
fn mask(&self, mask: &ArrayRef) -> VortexResult<ArrayRef> {
92+
mask::MaskFn.try_new_array(self.len(), EmptyOptions, [self.clone(), mask.clone()])
93+
}
4294
}

0 commit comments

Comments
 (0)