|
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
4 | 4 | use std::collections::BTreeSet; |
5 | | -use std::ops::Range; |
| 5 | +use std::ops::{Not, Range}; |
6 | 6 | use std::sync::Arc; |
7 | 7 |
|
8 | 8 | use futures::try_join; |
9 | 9 | use itertools::Itertools; |
| 10 | +use vortex_array::arrays::StructArray; |
10 | 11 | use vortex_array::stats::Precision; |
11 | | -use vortex_array::{MaskFuture, ToCanonical}; |
| 12 | +use vortex_array::vtable::ValidityHelper; |
| 13 | +use vortex_array::{ArrayRef, IntoArray, MaskFuture, ToCanonical}; |
12 | 14 | use vortex_dtype::{DType, FieldMask, FieldName, Nullability, StructFields}; |
13 | 15 | use vortex_error::{VortexExpect, VortexResult, vortex_err}; |
14 | 16 | use vortex_expr::transform::immediate_access::annotate_scope_access; |
15 | 17 | use vortex_expr::transform::{ |
16 | 18 | PartitionedExpr, partition, replace, replace_root_fields, simplify_typed, |
17 | 19 | }; |
18 | | -use vortex_expr::{ExactExpr, ExprRef, col, root}; |
| 20 | +use vortex_expr::{ExactExpr, ExprRef, MergeVTable, PackVTable, col, root}; |
19 | 21 | use vortex_mask::Mask; |
20 | 22 | use vortex_utils::aliases::dash_map::DashMap; |
21 | 23 | use vortex_utils::aliases::hash_map::HashMap; |
@@ -264,45 +266,63 @@ impl LayoutReader for StructReader { |
264 | 266 | &self, |
265 | 267 | row_range: &Range<u64>, |
266 | 268 | expr: &ExprRef, |
267 | | - mask: MaskFuture, |
| 269 | + mask_fut: MaskFuture, |
268 | 270 | ) -> VortexResult<ArrayFuture> { |
269 | 271 | let validity_fut = self |
270 | 272 | .validity()? |
271 | | - .map(|reader| reader.projection_evaluation(row_range, &root(), mask.clone())) |
| 273 | + .map(|reader| reader.projection_evaluation(row_range, &root(), mask_fut.clone())) |
272 | 274 | .transpose()?; |
273 | 275 |
|
274 | | - println!( |
275 | | - "StructReader::projection_eval on layout\n\t\tvalidity: {}\n\t\tdtype: {}\n\t\texpr: {expr}", |
276 | | - validity_fut.is_some(), |
277 | | - self.dtype(), |
278 | | - ); |
279 | | - |
280 | 276 | // Partition the expression into expressions that can be evaluated over individual fields |
281 | | - let array_future = match &self.partition_expr(expr.clone()) { |
282 | | - Partitioned::Single(name, partition) => self |
283 | | - .field_reader(name)? |
284 | | - .projection_evaluation(row_range, partition, mask)?, |
| 277 | + let (projected, is_pack_merge) = match &self.partition_expr(expr.clone()) { |
| 278 | + Partitioned::Single(name, partition) => ( |
| 279 | + self.field_reader(name)? |
| 280 | + .projection_evaluation(row_range, partition, mask_fut)?, |
| 281 | + partition.is::<PackVTable>() || partition.is::<MergeVTable>(), |
| 282 | + ), |
285 | 283 |
|
286 | 284 | Partitioned::Multi(partitioned) => { |
287 | 285 | // Apply the validity to each internal field instead. |
288 | | - partitioned |
289 | | - .clone() |
290 | | - .into_array_future(mask, |name, expr, mask| { |
291 | | - self.field_reader(name)? |
292 | | - .projection_evaluation(row_range, expr, mask) |
293 | | - })? |
| 286 | + ( |
| 287 | + partitioned |
| 288 | + .clone() |
| 289 | + .into_array_future(mask_fut, |name, expr, mask| { |
| 290 | + self.field_reader(name)? |
| 291 | + .projection_evaluation(row_range, expr, mask) |
| 292 | + })?, |
| 293 | + partitioned.root.is::<PackVTable>() || partitioned.root.is::<MergeVTable>(), |
| 294 | + ) |
294 | 295 | } |
295 | 296 | }; |
296 | 297 |
|
297 | 298 | Ok(Box::pin(async move { |
298 | 299 | if let Some(validity_fut) = validity_fut { |
299 | | - let (validity, array) = try_join!(validity_fut, array_future)?; |
300 | | - vortex_array::compute::mask( |
301 | | - array.as_ref(), |
302 | | - &Mask::from_buffer(!validity.to_bool().boolean_buffer()), |
303 | | - ) |
| 300 | + let (array, validity) = try_join!(projected, validity_fut)?; |
| 301 | + let mask = Mask::from_buffer(validity.to_bool().bit_buffer().not()); |
| 302 | + |
| 303 | + // If root expression was a pack, then we apply the validity to each child field |
| 304 | + if is_pack_merge { |
| 305 | + let struct_array = array.to_struct(); |
| 306 | + let masked_fields: Vec<ArrayRef> = struct_array |
| 307 | + .fields() |
| 308 | + .iter() |
| 309 | + .map(|a| vortex_array::compute::mask(a.as_ref(), &mask)) |
| 310 | + .try_collect()?; |
| 311 | + |
| 312 | + Ok(StructArray::try_new( |
| 313 | + struct_array.names().clone(), |
| 314 | + masked_fields, |
| 315 | + struct_array.len(), |
| 316 | + struct_array.validity().clone(), |
| 317 | + )? |
| 318 | + .into_array()) |
| 319 | + } else { |
| 320 | + // If the root expression was not a pack or merge, e.g. if it's something like |
| 321 | + // a get_item, then we apply the validity directly to the result |
| 322 | + vortex_array::compute::mask(array.as_ref(), &mask) |
| 323 | + } |
304 | 324 | } else { |
305 | | - array_future.await |
| 325 | + projected.await |
306 | 326 | } |
307 | 327 | })) |
308 | 328 | } |
@@ -604,7 +624,14 @@ mod tests { |
604 | 624 | result.scalar_at(0).as_struct().field_by_idx(0).unwrap(), |
605 | 625 | Scalar::primitive(4, Nullability::Nullable) |
606 | 626 | ); |
607 | | - assert_eq!(result.scalar_at(1), Scalar::null(result.dtype().clone())); |
| 627 | + assert!( |
| 628 | + result |
| 629 | + .scalar_at(1) |
| 630 | + .as_struct() |
| 631 | + .field_by_idx(0) |
| 632 | + .unwrap() |
| 633 | + .is_null(), |
| 634 | + ); |
608 | 635 | assert_eq!( |
609 | 636 | result.scalar_at(2).as_struct().field_by_idx(0).unwrap(), |
610 | 637 | Scalar::primitive(6, Nullability::Nullable) |
|
0 commit comments