Skip to content

Commit 46573e5

Browse files
committed
mask nulls into projection result
Signed-off-by: Andrew Duffy <[email protected]>
1 parent e3a59b7 commit 46573e5

File tree

1 file changed

+117
-7
lines changed

1 file changed

+117
-7
lines changed

vortex-layout/src/layouts/struct_/reader.rs

Lines changed: 117 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ use std::collections::BTreeSet;
55
use std::ops::Range;
66
use std::sync::Arc;
77

8+
use futures::try_join;
89
use itertools::Itertools;
9-
use vortex_array::MaskFuture;
10-
use vortex_dtype::{DType, FieldMask, FieldName, StructFields};
10+
use vortex_array::stats::Precision;
11+
use vortex_array::{MaskFuture, ToCanonical};
12+
use vortex_dtype::{DType, FieldMask, FieldName, Nullability, StructFields};
1113
use vortex_error::{VortexExpect, VortexResult, vortex_err};
1214
use vortex_expr::transform::immediate_access::annotate_scope_access;
1315
use vortex_expr::transform::{
@@ -90,13 +92,36 @@ impl StructReader {
9092

9193
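/// Return the child reader for the field, by index.
///
/// For a nullable struct the validity reader occupies child slot 0, so the
/// field at `idx` is looked up at child slot `idx + 1`.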
9294
fn child_by_idx(&self, idx: usize) -> VortexResult<&LayoutReaderRef> {
95+
let child_index = if self.dtype().is_nullable() {
96+
idx + 1
97+
} else {
98+
idx
99+
};
100+
93101
let field_dtype = self
94102
.struct_fields()
95103
.field_by_index(idx)
96104
.ok_or_else(|| vortex_err!("Missing field {idx}"))?;
97105
let name = &self.struct_fields().names()[idx];
98-
self.lazy_children
99-
.get(idx, &field_dtype, &format!("{}.{}", self.name, name).into())
106+
self.lazy_children.get(
107+
child_index,
108+
&field_dtype,
109+
&format!("{}.{}", self.name, name).into(),
110+
)
111+
}
112+
113+
/// Return the reader for the struct validity (child slot 0), or `None` if the struct is non-nullable.
114+
fn validity(&self) -> VortexResult<Option<&LayoutReaderRef>> {
115+
self.dtype()
116+
.is_nullable()
117+
.then(|| {
118+
self.lazy_children.get(
119+
0,
120+
&DType::Bool(Nullability::NonNullable),
121+
&"validity".into(),
122+
)
123+
})
124+
.transpose()
100125
}
101126

102127
/// Utility for partitioning an expression over the fields of a struct.
@@ -155,7 +180,9 @@ impl StructReader {
155180
/// some cost and just delegate to the child reader directly.
156181
#[derive(Clone)]
157182
enum Partitioned {
183+
/// An expression which only operates over a single field
158184
Single(FieldName, ExprRef),
185+
/// An expression which operates over multiple fields
159186
Multi(Arc<PartitionedExpr<FieldName>>),
160187
}
161188

@@ -234,8 +261,13 @@ impl LayoutReader for StructReader {
234261
expr: &ExprRef,
235262
mask: MaskFuture,
236263
) -> VortexResult<ArrayFuture> {
264+
let validity_fut = self
265+
.validity()?
266+
.map(|reader| reader.projection_evaluation(row_range, &root(), mask.clone()))
267+
.transpose()?;
268+
237269
// Partition the expression into expressions that can be evaluated over individual fields
238-
match &self.partition_expr(expr.clone()) {
270+
let projection_fut = match &self.partition_expr(expr.clone()) {
239271
Partitioned::Single(name, partition) => self
240272
.child(name)?
241273
.projection_evaluation(row_range, partition, mask),
@@ -247,7 +279,19 @@ impl LayoutReader for StructReader {
247279
.projection_evaluation(row_range, expr, mask)
248280
})
249281
}
250-
}
282+
}?;
283+
284+
Ok(Box::pin(async move {
285+
if let Some(validity_fut) = validity_fut {
286+
let (validity, projection) = try_join!(validity_fut, projection_fut)?;
287+
vortex_array::compute::mask(
288+
projection.as_ref(),
289+
&Mask::from_buffer(!validity.to_bool().boolean_buffer()),
290+
)
291+
} else {
292+
projection_fut.await
293+
}
294+
}))
251295
}
252296
}
253297

@@ -257,13 +301,16 @@ mod tests {
257301

258302
use itertools::Itertools;
259303
use rstest::{fixture, rstest};
260-
use vortex_array::arrays::StructArray;
304+
use vortex_array::arrays::{BoolArray, StructArray};
305+
use vortex_array::validity::Validity;
261306
use vortex_array::{Array, ArrayContext, IntoArray, MaskFuture, ToCanonical};
262307
use vortex_buffer::buffer;
263308
use vortex_dtype::Nullability::NonNullable;
309+
use vortex_dtype::{DType, Nullability, PType};
264310
use vortex_expr::{col, eq, get_item, gt, lit, or, pack, root};
265311
use vortex_io::runtime::single::block_on;
266312
use vortex_mask::Mask;
313+
use vortex_scalar::Scalar;
267314

268315
use crate::layouts::flat::writer::FlatLayoutStrategy;
269316
use crate::layouts::struct_::writer::StructStrategy;
@@ -304,6 +351,39 @@ mod tests {
304351
(segments, layout)
305352
}
306353

354+
#[fixture]
355+
/// Create a nullable struct layout with three fields (`a`, `b`, `c`) whose first row is null.
356+
fn null_struct_layout() -> (Arc<dyn SegmentSource>, LayoutRef) {
357+
let ctx = ArrayContext::empty();
358+
let segments = Arc::new(TestSegments::default());
359+
let (ptr, eof) = SequenceId::root().split();
360+
let strategy =
361+
StructStrategy::new(FlatLayoutStrategy::default(), FlatLayoutStrategy::default());
362+
let layout = block_on(|handle| {
363+
strategy.write_stream(
364+
ctx,
365+
segments.clone(),
366+
StructArray::try_from_iter_with_validity(
367+
[
368+
("a", buffer![7, 2, 3].into_array()),
369+
("b", buffer![4, 5, 6].into_array()),
370+
("c", buffer![4, 5, 6].into_array()),
371+
],
372+
Validity::Array(BoolArray::from_iter([false, true, true]).into_array()),
373+
)
374+
.unwrap()
375+
.into_array()
376+
.to_array_stream()
377+
.sequenced(ptr),
378+
eof,
379+
handle,
380+
)
381+
})
382+
.unwrap();
383+
384+
(segments, layout)
385+
}
386+
307387
#[rstest]
308388
fn test_struct_layout_or(
309389
#[from(struct_layout)] (segments, layout): (Arc<dyn SegmentSource>, LayoutRef),
@@ -411,4 +491,34 @@ mod tests {
411491
[4, 5].as_slice()
412492
);
413493
}
494+
495+
#[rstest]
496+
fn test_struct_layout_nulls(
497+
#[from(null_struct_layout)] (segments, layout): (Arc<dyn SegmentSource>, LayoutRef),
498+
) {
499+
// Read the layout source from the top.
500+
let reader = layout.new_reader("".into(), segments).unwrap();
501+
let expr = get_item("a", root());
502+
let project = reader
503+
.projection_evaluation(&(0..3), &expr, MaskFuture::new_true(3))
504+
.unwrap();
505+
506+
let result = block_on(move |_| project).unwrap();
507+
// The result should be a nullable primitive array.
508+
assert_eq!(
509+
result.dtype(),
510+
&DType::Primitive(PType::I32, Nullability::Nullable)
511+
);
512+
513+
assert_eq!(result.scalar_at(0), Scalar::null(result.dtype().clone()),);
514+
515+
assert_eq!(
516+
result.scalar_at(1),
517+
Scalar::primitive(2i32, Nullability::Nullable),
518+
);
519+
assert_eq!(
520+
result.scalar_at(2),
521+
Scalar::primitive(3i32, Nullability::Nullable),
522+
);
523+
}
414524
}

0 commit comments

Comments
 (0)