Skip to content

Commit 2502010

Browse files
committed
add MaskExpr
Signed-off-by: Andrew Duffy <[email protected]>
1 parent 1c7b746 commit 2502010

File tree

2 files changed

+183
-0
lines changed

2 files changed

+183
-0
lines changed

vortex-expr/src/exprs/mask.rs

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::fmt::Formatter;
5+
6+
use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata, ToCanonical};
7+
use vortex_dtype::DType;
8+
use vortex_error::{VortexResult, vortex_ensure};
9+
use vortex_mask::Mask;
10+
11+
use crate::display::{DisplayAs, DisplayFormat};
12+
use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, Scope, VTable, vtable};
13+
14+
// Invokes the crate's `vtable!` macro for `Mask`. Presumably this is what declares the
// `MaskVTable` type implemented further down in this file — confirm against the macro definition.
vtable!(Mask);
15+
16+
#[allow(clippy::derived_hash_with_manual_eq)]
17+
#[derive(Clone, Debug, Hash, Eq)]
18+
pub struct MaskExpr {
19+
/// The target array to mask
20+
pub target: ExprRef,
21+
/// An expression that yields a boolean array for the masking operation.
22+
/// True values will be set to null, false values will be set to non-null
23+
pub mask: ExprRef,
24+
}
25+
26+
impl PartialEq for MaskExpr {
27+
fn eq(&self, other: &Self) -> bool {
28+
self.target.eq(&other.target) && self.mask.eq(&other.mask)
29+
}
30+
}
31+
32+
impl MaskExpr {
33+
/// Create a new `MaskExpr` against the provided `target` and `mask`.
34+
///
35+
/// The `target` is an expression that evaluates to any array type.
36+
///
37+
/// `mask` must evaluate to a non-nullable `Bool` array that is the same length as the `target`.
38+
/// All `true` values will set the result to `null`, and `false` values will preserve the
39+
/// corresponding value from `target`.
40+
pub fn new(target: ExprRef, mask: ExprRef) -> Self {
41+
Self { target, mask }
42+
}
43+
44+
pub fn target(&self) -> &ExprRef {
45+
&self.target
46+
}
47+
48+
pub fn mask(&self) -> &ExprRef {
49+
&self.mask
50+
}
51+
}
52+
53+
impl DisplayAs for MaskExpr {
54+
fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
55+
match df {
56+
DisplayFormat::Compact => {
57+
write!(f, "mask({}, {})", self.target, self.mask)
58+
}
59+
DisplayFormat::Tree => {
60+
write!(f, "Mask")
61+
}
62+
}
63+
}
64+
65+
fn child_names(&self) -> Option<Vec<String>> {
66+
Some(vec!["target".to_string(), "mask".to_string()])
67+
}
68+
}
69+
70+
/// Unit type named as the `Encoding` associated type in the `VTable` impl for `MaskVTable`.
pub struct MaskExprEncoding;
71+
72+
impl VTable for MaskVTable {
73+
type Expr = MaskExpr;
74+
type Encoding = MaskExprEncoding;
75+
type Metadata = EmptyMetadata;
76+
77+
fn id(_encoding: &Self::Encoding) -> ExprId {
78+
ExprId::new_ref("vortex.mask")
79+
}
80+
81+
fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
82+
ExprEncodingRef::new_ref(MaskExprEncoding.as_ref())
83+
}
84+
85+
fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
86+
Some(EmptyMetadata)
87+
}
88+
89+
fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
90+
vec![&expr.target, &expr.target]
91+
}
92+
93+
fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
94+
vortex_ensure!(
95+
children.len() == 2,
96+
"cannot build MaskExpr: expected children [target, mask], received {} children",
97+
children.len()
98+
);
99+
100+
let target = children[0].clone();
101+
let mask = children[1].clone();
102+
103+
Ok(MaskExpr { target, mask })
104+
}
105+
106+
fn build(
107+
_encoding: &Self::Encoding,
108+
_: &<Self::Metadata as DeserializeMetadata>::Output,
109+
children: Vec<ExprRef>,
110+
) -> VortexResult<Self::Expr> {
111+
vortex_ensure!(
112+
children.len() == 2,
113+
"MaskExpr expected children [target, mask], received {} children",
114+
children.len()
115+
);
116+
117+
let target = children[0].clone();
118+
let mask = children[1].clone();
119+
120+
Ok(MaskExpr { target, mask })
121+
}
122+
123+
fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
124+
let target = expr.target.evaluate(scope)?;
125+
let mask = expr.mask.evaluate(scope)?;
126+
vortex_array::compute::mask(
127+
target.as_ref(),
128+
&Mask::from_buffer(mask.to_bool().bit_buffer().clone()),
129+
)
130+
}
131+
132+
fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
133+
// Mask operation always returns a nullable result
134+
Ok(expr.target.return_dtype(scope)?.as_nullable())
135+
}
136+
}
137+
138+
// No methods are overridden, so `MaskExpr` relies entirely on `AnalysisExpr`'s default behavior.
impl AnalysisExpr for MaskExpr {}
139+
140+
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, PrimitiveArray, StructArray};
    use vortex_array::{ArrayEq, IntoArray, Precision};

    use super::MaskExpr;
    use crate::{Scope, get_item, is_null, root};

    /// Masking an array by its own null positions must leave it unchanged.
    #[test]
    fn test_mask_primitive() {
        let input =
            PrimitiveArray::from_option_iter([Some(1), Some(2), None, Some(4)]).into_array();
        let scope = Scope::new(input.clone());

        // Mask(root, IsNull(root)) should match root
        let mask_by_nulls = MaskExpr::new(root(), is_null(root()));
        let masked = mask_by_nulls.evaluate(&scope).unwrap();

        assert!(masked.array_eq(&input, Precision::Value));
    }

    /// Masks one struct field (`a`) using a sibling boolean field (`b`).
    #[test]
    fn test_mask_struct() {
        let values = PrimitiveArray::from_option_iter([Some(1), Some(2), None, Some(4)]).into_array();
        let flags = BoolArray::from_iter([false, true, false, true]).into_array();

        let fields = [("a", values), ("b", flags)];
        let struct_array = StructArray::from_fields(&fields).unwrap().into_array();
        let scope = Scope::new(struct_array.clone());

        // mask a using the b array.
        let mask_a_by_b = MaskExpr::new(get_item("a", root()), get_item("b", root()));
        let masked = mask_a_by_b.evaluate(&scope).unwrap();

        // Row 0 survives; rows 1 and 3 are masked out; row 2 was already null in `a`.
        let expected: [Option<i32>; 4] = [Some(1), None, None, None];
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(masked.scalar_at(row), want.into());
        }
    }
}

vortex-expr/src/exprs/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub(crate) mod is_null;
1111
pub(crate) mod like;
1212
pub(crate) mod list_contains;
1313
pub(crate) mod literal;
14+
mod mask;
1415
pub(crate) mod merge;
1516
pub(crate) mod not;
1617
pub(crate) mod operators;

0 commit comments

Comments
 (0)