|
1 | 1 | use async_trait::async_trait; |
2 | | -use vortex_array::ArrayData; |
| 2 | +use futures::future::try_join_all; |
| 3 | +use itertools::Itertools; |
| 4 | +use vortex_array::array::StructArray; |
| 5 | +use vortex_array::validity::Validity; |
| 6 | +use vortex_array::{ArrayData, IntoArrayData}; |
3 | 7 | use vortex_error::VortexResult; |
| 8 | +use vortex_expr::transform::partition::partition; |
4 | 9 | use vortex_expr::ExprRef; |
5 | 10 | use vortex_scan::RowMask; |
6 | 11 |
|
7 | | -use crate::layouts::struct_::reader::StructScan; |
| 12 | +use crate::layouts::struct_::reader::StructReader; |
8 | 13 | use crate::ExprEvaluator; |
9 | 14 |
|
10 | 15 | #[async_trait(?Send)] |
11 | | -impl ExprEvaluator for StructScan { |
12 | | - async fn evaluate_expr(&self, _row_mask: RowMask, _expr: ExprRef) -> VortexResult<ArrayData> { |
13 | | - todo!() |
| 16 | +impl ExprEvaluator for StructReader { |
| 17 | + async fn evaluate_expr(&self, row_mask: RowMask, expr: ExprRef) -> VortexResult<ArrayData> { |
| 18 | + // Partition the expression into expressions that can be evaluated over individual fields |
| 19 | + let partitioned = partition(expr, self.struct_dtype())?; |
| 20 | + let field_readers: Vec<_> = partitioned |
| 21 | + .partitions |
| 22 | + .iter() |
| 23 | + .map(|partition| self.child(&partition.field)) |
| 24 | + .try_collect()?; |
| 25 | + |
| 26 | + let arrays = try_join_all( |
| 27 | + field_readers |
| 28 | + .iter() |
| 29 | + .zip_eq(partitioned.partitions.iter()) |
| 30 | + .map(|(reader, partition)| { |
| 31 | + reader.evaluate_expr(row_mask.clone(), partition.expr.clone()) |
| 32 | + }), |
| 33 | + ) |
| 34 | + .await?; |
| 35 | + |
| 36 | + let row_count = row_mask.true_count(); |
| 37 | + debug_assert!(arrays.iter().all(|a| a.len() == row_count)); |
| 38 | + |
| 39 | + let root_scope = StructArray::try_new( |
| 40 | + partitioned |
| 41 | + .partitions |
| 42 | + .iter() |
| 43 | + .map(|p| p.name.clone()) |
| 44 | + .collect::<Vec<_>>() |
| 45 | + .into(), |
| 46 | + arrays, |
| 47 | + row_count, |
| 48 | + Validity::NonNullable, |
| 49 | + )? |
| 50 | + .into_array(); |
| 51 | + |
| 52 | + // Recombine the partitioned expressions into a single expression |
| 53 | + partitioned.root.evaluate(&root_scope) |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +#[cfg(test)] |
| 58 | +mod tests { |
| 59 | + use std::sync::Arc; |
| 60 | + |
| 61 | + use futures::executor::block_on; |
| 62 | + use vortex_array::array::StructArray; |
| 63 | + use vortex_array::compute::FilterMask; |
| 64 | + use vortex_array::{IntoArrayData, IntoArrayVariant}; |
| 65 | + use vortex_buffer::buffer; |
| 66 | + use vortex_dtype::PType::I32; |
| 67 | + use vortex_dtype::{DType, Nullability, StructDType}; |
| 68 | + use vortex_expr::{get_item, gt, ident}; |
| 69 | + use vortex_scan::RowMask; |
| 70 | + |
| 71 | + use crate::layouts::flat::writer::FlatLayoutWriter; |
| 72 | + use crate::layouts::struct_::writer::StructLayoutWriter; |
| 73 | + use crate::segments::test::TestSegments; |
| 74 | + use crate::strategies::LayoutWriterExt; |
| 75 | + use crate::LayoutData; |
| 76 | + |
| 77 | + /// Create a chunked layout with three chunks of primitive arrays. |
| 78 | + fn struct_layout() -> (Arc<TestSegments>, LayoutData) { |
| 79 | + let mut segments = TestSegments::default(); |
| 80 | + |
| 81 | + let layout = StructLayoutWriter::new( |
| 82 | + DType::Struct( |
| 83 | + StructDType::new( |
| 84 | + vec!["a".into(), "b".into(), "c".into()].into(), |
| 85 | + vec![I32.into(), I32.into(), I32.into()], |
| 86 | + ), |
| 87 | + Nullability::NonNullable, |
| 88 | + ), |
| 89 | + vec![ |
| 90 | + Box::new(FlatLayoutWriter::new(I32.into())), |
| 91 | + Box::new(FlatLayoutWriter::new(I32.into())), |
| 92 | + Box::new(FlatLayoutWriter::new(I32.into())), |
| 93 | + ], |
| 94 | + ) |
| 95 | + .push_all( |
| 96 | + &mut segments, |
| 97 | + [StructArray::from_fields( |
| 98 | + [ |
| 99 | + ("a", buffer![7, 2, 3].into_array()), |
| 100 | + ("b", buffer![4, 5, 6].into_array()), |
| 101 | + ("c", buffer![4, 5, 6].into_array()), |
| 102 | + ] |
| 103 | + .as_slice(), |
| 104 | + ) |
| 105 | + .map(IntoArrayData::into_array)], |
| 106 | + ) |
| 107 | + .unwrap(); |
| 108 | + (Arc::new(segments), layout) |
| 109 | + } |
| 110 | + |
| 111 | + #[test] |
| 112 | + fn test_struct_layout() { |
| 113 | + let (segments, layout) = struct_layout(); |
| 114 | + |
| 115 | + let reader = layout.reader(segments, Default::default()).unwrap(); |
| 116 | + let expr = gt(get_item("a", ident()), get_item("b", ident())); |
| 117 | + let result = |
| 118 | + block_on(reader.evaluate_expr(RowMask::new_valid_between(0, 3), expr)).unwrap(); |
| 119 | + assert_eq!( |
| 120 | + vec![true, false, false], |
| 121 | + result |
| 122 | + .into_bool() |
| 123 | + .unwrap() |
| 124 | + .boolean_buffer() |
| 125 | + .iter() |
| 126 | + .collect::<Vec<_>>() |
| 127 | + ); |
| 128 | + } |
| 129 | + |
| 130 | + #[test] |
| 131 | + fn test_struct_layout_row_mask() { |
| 132 | + let (segments, layout) = struct_layout(); |
| 133 | + |
| 134 | + let reader = layout.reader(segments, Default::default()).unwrap(); |
| 135 | + let expr = gt(get_item("a", ident()), get_item("b", ident())); |
| 136 | + let result = block_on(reader.evaluate_expr( |
| 137 | + // Take rows 0 and 1, skip row 2, and anything after that |
| 138 | + RowMask::new(FilterMask::from_iter([true, true, false]), 0), |
| 139 | + expr, |
| 140 | + )) |
| 141 | + .unwrap(); |
| 142 | + |
| 143 | + assert_eq!(result.len(), 2); |
| 144 | + |
| 145 | + assert_eq!( |
| 146 | + vec![true, false], |
| 147 | + result |
| 148 | + .into_bool() |
| 149 | + .unwrap() |
| 150 | + .boolean_buffer() |
| 151 | + .iter() |
| 152 | + .collect::<Vec<_>>() |
| 153 | + ); |
14 | 154 | } |
15 | 155 | } |
0 commit comments