|
| 1 | +use std::any::Any; |
| 2 | +use std::fmt::Display; |
| 3 | +use std::sync::Arc; |
| 4 | + |
| 5 | +use itertools::Itertools as _; |
| 6 | +use vortex_array::array::StructArray; |
| 7 | +use vortex_array::validity::Validity; |
| 8 | +use vortex_array::{ArrayData, IntoArrayData}; |
| 9 | +use vortex_dtype::FieldNames; |
| 10 | +use vortex_error::{vortex_bail, VortexExpect as _, VortexResult}; |
| 11 | + |
| 12 | +use crate::{ExprRef, VortexExpr}; |
| 13 | + |
| 14 | +/// Pack zero or more expressions into a structure with named fields. |
| 15 | +/// |
| 16 | +/// # Examples |
| 17 | +/// |
| 18 | +/// ``` |
| 19 | +/// use vortex_array::IntoArrayData; |
| 20 | +/// use vortex_array::compute::scalar_at; |
| 21 | +/// use vortex_buffer::buffer; |
| 22 | +/// use vortex_expr::{Pack, Identity, VortexExpr}; |
| 23 | +/// use vortex_scalar::Scalar; |
| 24 | +/// |
| 25 | +/// let example = Pack::try_new_expr( |
| 26 | +/// ["x".into(), "x copy".into(), "second x copy".into()].into(), |
| 27 | +/// vec![Identity::new_expr(), Identity::new_expr(), Identity::new_expr()], |
| 28 | +/// ).unwrap(); |
| 29 | +/// let packed = example.evaluate(&buffer![100, 110, 200].into_array()).unwrap(); |
| 30 | +/// let x_copy = packed |
| 31 | +/// .as_struct_array() |
| 32 | +/// .unwrap() |
| 33 | +/// .field_by_name("x copy") |
| 34 | +/// .unwrap(); |
| 35 | +/// assert_eq!(scalar_at(&x_copy, 0).unwrap(), Scalar::from(100)); |
| 36 | +/// assert_eq!(scalar_at(&x_copy, 1).unwrap(), Scalar::from(110)); |
| 37 | +/// assert_eq!(scalar_at(&x_copy, 2).unwrap(), Scalar::from(200)); |
| 38 | +/// ``` |
| 39 | +/// |
| 40 | +#[derive(Debug, Clone)] |
| 41 | +pub struct Pack { |
| 42 | + names: FieldNames, |
| 43 | + values: Vec<ExprRef>, |
| 44 | +} |
| 45 | + |
| 46 | +impl Pack { |
| 47 | + pub fn try_new_expr(names: FieldNames, values: Vec<ExprRef>) -> VortexResult<Arc<Self>> { |
| 48 | + if names.len() != values.len() { |
| 49 | + vortex_bail!("length mismatch {} {}", names.len(), values.len()); |
| 50 | + } |
| 51 | + Ok(Arc::new(Pack { names, values })) |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +impl PartialEq<dyn Any> for Pack { |
| 56 | + fn eq(&self, other: &dyn Any) -> bool { |
| 57 | + other.downcast_ref::<Pack>().is_some_and(|other_pack| { |
| 58 | + self.names == other_pack.names |
| 59 | + && self |
| 60 | + .values |
| 61 | + .iter() |
| 62 | + .zip(other_pack.values.iter()) |
| 63 | + .all(|(x, y)| x.eq(y)) |
| 64 | + }) |
| 65 | + } |
| 66 | +} |
| 67 | + |
| 68 | +impl Display for Pack { |
| 69 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 70 | + let mut f = f.debug_struct("Pack"); |
| 71 | + for (name, value) in self.names.iter().zip_eq(self.values.iter()) { |
| 72 | + f.field(name, value); |
| 73 | + } |
| 74 | + f.finish() |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +impl VortexExpr for Pack { |
| 79 | + fn as_any(&self) -> &dyn Any { |
| 80 | + self |
| 81 | + } |
| 82 | + |
| 83 | + fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> { |
| 84 | + let len = batch.len(); |
| 85 | + let value_arrays = self |
| 86 | + .values |
| 87 | + .iter() |
| 88 | + .map(|value_expr| value_expr.evaluate(batch)) |
| 89 | + .process_results(|it| it.collect::<Vec<_>>())?; |
| 90 | + StructArray::try_new(self.names.clone(), value_arrays, len, Validity::NonNullable) |
| 91 | + .map(IntoArrayData::into_array) |
| 92 | + } |
| 93 | + |
| 94 | + fn children(&self) -> Vec<&ExprRef> { |
| 95 | + self.values.iter().collect() |
| 96 | + } |
| 97 | + |
| 98 | + fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef { |
| 99 | + assert_eq!(children.len(), self.values.len()); |
| 100 | + Self::try_new_expr(self.names.clone(), children) |
| 101 | + .vortex_expect("children are known to have the same length as names") |
| 102 | + } |
| 103 | +} |
| 104 | + |
| 105 | +impl PartialEq<Pack> for Pack { |
| 106 | + fn eq(&self, other: &Pack) -> bool { |
| 107 | + self.names == other.names && self.values == other.values |
| 108 | + } |
| 109 | +} |
| 110 | + |
| 111 | +impl Eq for Pack {} |
| 112 | + |
| 113 | +#[cfg(test)] |
| 114 | +mod tests { |
| 115 | + use std::sync::Arc; |
| 116 | + |
| 117 | + use vortex_array::array::{PrimitiveArray, StructArray}; |
| 118 | + use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant as _}; |
| 119 | + use vortex_buffer::buffer; |
| 120 | + use vortex_dtype::{Field, FieldNames}; |
| 121 | + use vortex_error::{vortex_bail, vortex_err, VortexResult}; |
| 122 | + |
| 123 | + use crate::{col, Column, Pack, VortexExpr}; |
| 124 | + |
| 125 | + fn test_array() -> StructArray { |
| 126 | + StructArray::from_fields(&[ |
| 127 | + ("a", buffer![0, 1, 2].into_array()), |
| 128 | + ("b", buffer![4, 5, 6].into_array()), |
| 129 | + ]) |
| 130 | + .unwrap() |
| 131 | + } |
| 132 | + |
| 133 | + fn primitive_field(array: &ArrayData, field_path: &[&str]) -> VortexResult<PrimitiveArray> { |
| 134 | + let mut field_path = field_path.iter(); |
| 135 | + |
| 136 | + let Some(field) = field_path.next() else { |
| 137 | + vortex_bail!("empty field path"); |
| 138 | + }; |
| 139 | + |
| 140 | + let mut array = array |
| 141 | + .as_struct_array() |
| 142 | + .ok_or_else(|| vortex_err!("expected a struct"))? |
| 143 | + .field_by_name(field) |
| 144 | + .ok_or_else(|| vortex_err!("expected field to exist: {}", field))?; |
| 145 | + |
| 146 | + for field in field_path { |
| 147 | + array = array |
| 148 | + .as_struct_array() |
| 149 | + .ok_or_else(|| vortex_err!("expected a struct"))? |
| 150 | + .field_by_name(field) |
| 151 | + .ok_or_else(|| vortex_err!("expected field to exist: {}", field))?; |
| 152 | + } |
| 153 | + Ok(array.into_primitive().unwrap()) |
| 154 | + } |
| 155 | + |
| 156 | + #[test] |
| 157 | + pub fn test_empty_pack() { |
| 158 | + let expr = Pack::try_new_expr(Arc::new([]), Vec::new()).unwrap(); |
| 159 | + |
| 160 | + let test_array = test_array().into_array(); |
| 161 | + let actual_array = expr.evaluate(&test_array).unwrap(); |
| 162 | + assert_eq!(actual_array.len(), test_array.len()); |
| 163 | + assert!(actual_array.as_struct_array().unwrap().nfields() == 0); |
| 164 | + } |
| 165 | + |
| 166 | + #[test] |
| 167 | + pub fn test_simple_pack() { |
| 168 | + let expr = Pack::try_new_expr( |
| 169 | + ["one".into(), "two".into(), "three".into()].into(), |
| 170 | + vec![col("a"), col("b"), col("a")], |
| 171 | + ) |
| 172 | + .unwrap(); |
| 173 | + |
| 174 | + let actual_array = expr.evaluate(test_array().as_ref()).unwrap(); |
| 175 | + let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into(); |
| 176 | + assert_eq!( |
| 177 | + actual_array.as_struct_array().unwrap().names(), |
| 178 | + &expected_names |
| 179 | + ); |
| 180 | + |
| 181 | + assert_eq!( |
| 182 | + primitive_field(&actual_array, &["one"]) |
| 183 | + .unwrap() |
| 184 | + .as_slice::<i32>(), |
| 185 | + [0, 1, 2] |
| 186 | + ); |
| 187 | + assert_eq!( |
| 188 | + primitive_field(&actual_array, &["two"]) |
| 189 | + .unwrap() |
| 190 | + .as_slice::<i32>(), |
| 191 | + [4, 5, 6] |
| 192 | + ); |
| 193 | + assert_eq!( |
| 194 | + primitive_field(&actual_array, &["three"]) |
| 195 | + .unwrap() |
| 196 | + .as_slice::<i32>(), |
| 197 | + [0, 1, 2] |
| 198 | + ); |
| 199 | + } |
| 200 | + |
| 201 | + #[test] |
| 202 | + pub fn test_nested_pack() { |
| 203 | + let expr = Pack::try_new_expr( |
| 204 | + ["one".into(), "two".into(), "three".into()].into(), |
| 205 | + vec![ |
| 206 | + Column::new_expr(Field::from("a")), |
| 207 | + Pack::try_new_expr( |
| 208 | + ["two_one".into(), "two_two".into()].into(), |
| 209 | + vec![ |
| 210 | + Column::new_expr(Field::from("b")), |
| 211 | + Column::new_expr(Field::from("b")), |
| 212 | + ], |
| 213 | + ) |
| 214 | + .unwrap(), |
| 215 | + Column::new_expr(Field::from("a")), |
| 216 | + ], |
| 217 | + ) |
| 218 | + .unwrap(); |
| 219 | + |
| 220 | + let actual_array = expr.evaluate(test_array().as_ref()).unwrap(); |
| 221 | + let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into(); |
| 222 | + assert_eq!( |
| 223 | + actual_array.as_struct_array().unwrap().names(), |
| 224 | + &expected_names |
| 225 | + ); |
| 226 | + |
| 227 | + assert_eq!( |
| 228 | + primitive_field(&actual_array, &["one"]) |
| 229 | + .unwrap() |
| 230 | + .as_slice::<i32>(), |
| 231 | + [0, 1, 2] |
| 232 | + ); |
| 233 | + assert_eq!( |
| 234 | + primitive_field(&actual_array, &["two", "two_one"]) |
| 235 | + .unwrap() |
| 236 | + .as_slice::<i32>(), |
| 237 | + [4, 5, 6] |
| 238 | + ); |
| 239 | + assert_eq!( |
| 240 | + primitive_field(&actual_array, &["two", "two_two"]) |
| 241 | + .unwrap() |
| 242 | + .as_slice::<i32>(), |
| 243 | + [4, 5, 6] |
| 244 | + ); |
| 245 | + assert_eq!( |
| 246 | + primitive_field(&actual_array, &["three"]) |
| 247 | + .unwrap() |
| 248 | + .as_slice::<i32>(), |
| 249 | + [0, 1, 2] |
| 250 | + ); |
| 251 | + } |
| 252 | +} |
0 commit comments