|
1 | 1 | // SPDX-License-Identifier: Apache-2.0 |
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
4 | | -use vortex_array::compute::mask as mask_fn; |
5 | | -use vortex_array::{ArrayRef, Canonical}; |
6 | | -use vortex_error::VortexResult; |
7 | | -use vortex_mask::Mask; |
| 4 | +use std::ops::Not; |
| 5 | +use std::sync::Arc; |
| 6 | + |
| 7 | +use vortex_array::arrays::{ |
| 8 | + BoolArray, DecimalArray, ExtensionArray, FixedSizeListArray, ListViewArray, PrimitiveArray, |
| 9 | + StructArray, VarBinViewArray, |
| 10 | +}; |
| 11 | +use vortex_array::validity::Validity; |
| 12 | +use vortex_array::vtable::ValidityHelper; |
| 13 | +use vortex_array::{ArrayRef, Canonical, IntoArray, ToCanonical}; |
| 14 | +use vortex_dtype::ExtDType; |
| 15 | +use vortex_error::{VortexResult, VortexUnwrap}; |
| 16 | +use vortex_mask::{AllOr, Mask}; |
| 17 | +use vortex_scalar::match_each_decimal_value_type; |
8 | 18 |
|
9 | 19 | /// Apply mask on the canonical form of the array to get a consistent baseline. |
| 20 | +/// This implementation manually applies the mask to each canonical type |
| 21 | +/// without using the mask_fn method, to serve as an independent baseline for testing. |
10 | 22 | pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult<ArrayRef> { |
11 | | - // TODO(joe): replace with baseline not using canonical |
12 | | - mask_fn(canonical.as_ref(), mask) |
| 23 | + Ok(match canonical { |
| 24 | + Canonical::Null(array) => { |
| 25 | + // Null arrays are already all invalid, masking has no effect |
| 26 | + array.into_array() |
| 27 | + } |
| 28 | + Canonical::Bool(array) => { |
| 29 | + let new_validity = apply_mask_to_validity(array.validity(), mask); |
| 30 | + BoolArray::from_bit_buffer(array.bit_buffer().clone(), new_validity).into_array() |
| 31 | + } |
| 32 | + Canonical::Primitive(array) => { |
| 33 | + let new_validity = apply_mask_to_validity(array.validity(), mask); |
| 34 | + PrimitiveArray::from_byte_buffer( |
| 35 | + array.byte_buffer().clone(), |
| 36 | + array.ptype(), |
| 37 | + new_validity, |
| 38 | + ) |
| 39 | + .into_array() |
| 40 | + } |
| 41 | + Canonical::Decimal(array) => { |
| 42 | + let new_validity = apply_mask_to_validity(array.validity(), mask); |
| 43 | + match_each_decimal_value_type!(array.values_type(), |D| { |
| 44 | + DecimalArray::new(array.buffer::<D>(), array.decimal_dtype(), new_validity) |
| 45 | + .into_array() |
| 46 | + }) |
| 47 | + } |
| 48 | + Canonical::VarBinView(array) => { |
| 49 | + let new_validity = apply_mask_to_validity(array.validity(), mask); |
| 50 | + VarBinViewArray::new( |
| 51 | + array.views().clone(), |
| 52 | + array.buffers().clone(), |
| 53 | + array.dtype().as_nullable(), |
| 54 | + new_validity, |
| 55 | + ) |
| 56 | + .into_array() |
| 57 | + } |
| 58 | + Canonical::List(array) => { |
| 59 | + let new_validity = apply_mask_to_validity(array.validity(), mask); |
| 60 | + ListViewArray::try_new( |
| 61 | + array.elements().clone(), |
| 62 | + array.offsets().clone(), |
| 63 | + array.sizes().clone(), |
| 64 | + new_validity, |
| 65 | + ) |
| 66 | + .vortex_unwrap() |
| 67 | + .into_array() |
| 68 | + } |
| 69 | + Canonical::FixedSizeList(array) => { |
| 70 | + let new_validity = apply_mask_to_validity(array.validity(), mask); |
| 71 | + FixedSizeListArray::new( |
| 72 | + array.elements().clone(), |
| 73 | + array.list_size(), |
| 74 | + new_validity, |
| 75 | + array.len(), |
| 76 | + ) |
| 77 | + .into_array() |
| 78 | + } |
| 79 | + Canonical::Struct(array) => { |
| 80 | + let new_validity = apply_mask_to_validity(array.validity(), mask); |
| 81 | + StructArray::try_new_with_dtype( |
| 82 | + array.fields().clone(), |
| 83 | + array.struct_fields().clone(), |
| 84 | + array.len(), |
| 85 | + new_validity, |
| 86 | + ) |
| 87 | + .vortex_unwrap() |
| 88 | + .into_array() |
| 89 | + } |
| 90 | + Canonical::Extension(array) => { |
| 91 | + // Recursively mask the storage array |
| 92 | + let masked_storage = |
| 93 | + mask_canonical_array(array.storage().to_canonical(), mask).vortex_unwrap(); |
| 94 | + |
| 95 | + if masked_storage.dtype().nullability() |
| 96 | + == array.ext_dtype().storage_dtype().nullability() |
| 97 | + { |
| 98 | + ExtensionArray::new(array.ext_dtype().clone(), masked_storage).into_array() |
| 99 | + } else { |
| 100 | + // The storage dtype changed (i.e., became nullable due to masking) |
| 101 | + let ext_dtype = Arc::new(ExtDType::new( |
| 102 | + array.ext_dtype().id().clone(), |
| 103 | + Arc::new(masked_storage.dtype().clone()), |
| 104 | + array.ext_dtype().metadata().cloned(), |
| 105 | + )); |
| 106 | + ExtensionArray::new(ext_dtype, masked_storage).into_array() |
| 107 | + } |
| 108 | + } |
| 109 | + }) |
| 110 | +} |
| 111 | + |
| 112 | +fn apply_mask_to_validity(validity: &Validity, mask: &Mask) -> Validity { |
| 113 | + match mask.bit_buffer() { |
| 114 | + AllOr::All => Validity::AllInvalid, |
| 115 | + AllOr::None => validity.clone(), |
| 116 | + AllOr::Some(make_invalid) => match validity { |
| 117 | + Validity::NonNullable | Validity::AllValid => { |
| 118 | + Validity::Array(BoolArray::from(make_invalid.not()).into_array()) |
| 119 | + } |
| 120 | + Validity::AllInvalid => Validity::AllInvalid, |
| 121 | + Validity::Array(is_valid) => { |
| 122 | + let is_valid = is_valid.to_bool(); |
| 123 | + let keep_valid = make_invalid.not(); |
| 124 | + Validity::from(is_valid.bit_buffer() & &keep_valid) |
| 125 | + } |
| 126 | + }, |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +#[cfg(test)] |
| 131 | +mod tests { |
| 132 | + use vortex_array::arrays::{ |
| 133 | + BoolArray, DecimalArray, FixedSizeListArray, ListViewArray, NullArray, PrimitiveArray, |
| 134 | + StructArray, VarBinViewArray, |
| 135 | + }; |
| 136 | + use vortex_array::{Array, IntoArray}; |
| 137 | + use vortex_dtype::{DecimalDType, FieldNames, Nullability}; |
| 138 | + use vortex_mask::Mask; |
| 139 | + use vortex_scalar::Scalar; |
| 140 | + |
| 141 | + use super::mask_canonical_array; |
| 142 | + |
| 143 | + #[test] |
| 144 | + fn test_mask_null_array() { |
| 145 | + let array = NullArray::new(5); |
| 146 | + let mask = Mask::from_iter([true, false, true, false, true]); |
| 147 | + |
| 148 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 149 | + |
| 150 | + assert_eq!(result.len(), 5); |
| 151 | + // All values should still be null |
| 152 | + for i in 0..5 { |
| 153 | + assert!(!result.is_valid(i)); |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + #[test] |
| 158 | + fn test_mask_bool_array() { |
| 159 | + let array = BoolArray::from_iter([true, false, true, false, true]); |
| 160 | + let mask = Mask::from_iter([true, false, false, true, false]); |
| 161 | + |
| 162 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 163 | + |
| 164 | + assert_eq!(result.len(), 5); |
| 165 | + assert!(!result.is_valid(0)); |
| 166 | + assert_eq!(result.scalar_at(1), Scalar::from(Some(false))); |
| 167 | + assert_eq!(result.scalar_at(2), Scalar::from(Some(true))); |
| 168 | + assert!(!result.is_valid(3)); |
| 169 | + assert_eq!(result.scalar_at(4), Scalar::from(Some(true))); |
| 170 | + } |
| 171 | + |
| 172 | + #[test] |
| 173 | + fn test_mask_primitive_array() { |
| 174 | + let array = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]); |
| 175 | + let mask = Mask::from_iter([false, true, false, true, false]); |
| 176 | + |
| 177 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 178 | + |
| 179 | + assert_eq!(result.len(), 5); |
| 180 | + assert_eq!(result.scalar_at(0), Scalar::from(Some(1))); |
| 181 | + assert!(!result.is_valid(1)); |
| 182 | + assert_eq!(result.scalar_at(2), Scalar::from(Some(3))); |
| 183 | + assert!(!result.is_valid(3)); |
| 184 | + assert_eq!(result.scalar_at(4), Scalar::from(Some(5))); |
| 185 | + } |
| 186 | + |
| 187 | + #[test] |
| 188 | + fn test_mask_primitive_array_with_nulls() { |
| 189 | + let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]); |
| 190 | + let mask = Mask::from_iter([true, false, false, true, false]); |
| 191 | + |
| 192 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 193 | + |
| 194 | + assert_eq!(result.len(), 5); |
| 195 | + assert!(!result.is_valid(0)); |
| 196 | + assert!(!result.is_valid(1)); // was already null |
| 197 | + assert_eq!(result.scalar_at(2), Scalar::from(Some(3))); |
| 198 | + assert!(!result.is_valid(3)); |
| 199 | + assert!(!result.is_valid(4)); // was already null |
| 200 | + } |
| 201 | + |
| 202 | + #[test] |
| 203 | + fn test_mask_decimal_array() { |
| 204 | + let array = DecimalArray::from_option_iter( |
| 205 | + [Some(1i128), Some(2), Some(3), Some(4), Some(5)], |
| 206 | + DecimalDType::new(10, 2), |
| 207 | + ); |
| 208 | + let mask = Mask::from_iter([false, false, true, false, false]); |
| 209 | + |
| 210 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 211 | + |
| 212 | + assert_eq!(result.len(), 5); |
| 213 | + assert!(result.is_valid(0)); |
| 214 | + assert!(result.is_valid(1)); |
| 215 | + assert!(!result.is_valid(2)); |
| 216 | + assert!(result.is_valid(3)); |
| 217 | + assert!(result.is_valid(4)); |
| 218 | + } |
| 219 | + |
| 220 | + #[test] |
| 221 | + fn test_mask_varbinview_array() { |
| 222 | + let array = VarBinViewArray::from_iter_str(["one", "two", "three", "four", "five"]); |
| 223 | + let mask = Mask::from_iter([true, false, true, false, true]); |
| 224 | + |
| 225 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 226 | + |
| 227 | + assert_eq!(result.len(), 5); |
| 228 | + assert!(!result.is_valid(0)); |
| 229 | + assert_eq!( |
| 230 | + result.scalar_at(1), |
| 231 | + Scalar::utf8("two", Nullability::Nullable) |
| 232 | + ); |
| 233 | + assert!(!result.is_valid(2)); |
| 234 | + assert_eq!( |
| 235 | + result.scalar_at(3), |
| 236 | + Scalar::utf8("four", Nullability::Nullable) |
| 237 | + ); |
| 238 | + assert!(!result.is_valid(4)); |
| 239 | + } |
| 240 | + |
| 241 | + #[test] |
| 242 | + fn test_mask_list_array() { |
| 243 | + let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]).into_array(); |
| 244 | + let offsets = PrimitiveArray::from_iter([0i32, 2, 4]).into_array(); |
| 245 | + let sizes = PrimitiveArray::from_iter([2i32, 2, 2]).into_array(); |
| 246 | + let array = |
| 247 | + ListViewArray::try_new(elements, offsets, sizes, Nullability::NonNullable.into()) |
| 248 | + .unwrap(); |
| 249 | + |
| 250 | + let mask = Mask::from_iter([false, true, false]); |
| 251 | + |
| 252 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 253 | + |
| 254 | + assert_eq!(result.len(), 3); |
| 255 | + assert!(result.is_valid(0)); |
| 256 | + assert!(!result.is_valid(1)); |
| 257 | + assert!(result.is_valid(2)); |
| 258 | + } |
| 259 | + |
| 260 | + #[test] |
| 261 | + fn test_mask_fixed_size_list_array() { |
| 262 | + let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]).into_array(); |
| 263 | + let array = |
| 264 | + FixedSizeListArray::try_new(elements, 2, Nullability::NonNullable.into(), 3).unwrap(); |
| 265 | + |
| 266 | + let mask = Mask::from_iter([true, false, true]); |
| 267 | + |
| 268 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 269 | + |
| 270 | + assert_eq!(result.len(), 3); |
| 271 | + assert!(!result.is_valid(0)); |
| 272 | + assert!(result.is_valid(1)); |
| 273 | + assert!(!result.is_valid(2)); |
| 274 | + } |
| 275 | + |
| 276 | + #[test] |
| 277 | + fn test_mask_struct_array() { |
| 278 | + let field1 = PrimitiveArray::from_iter([1i32, 2, 3]).into_array(); |
| 279 | + let field2 = PrimitiveArray::from_iter([4i32, 5, 6]).into_array(); |
| 280 | + let fields = vec![field1, field2]; |
| 281 | + |
| 282 | + let array = StructArray::try_new( |
| 283 | + FieldNames::from(["a", "b"]), |
| 284 | + fields, |
| 285 | + 3, |
| 286 | + Nullability::NonNullable.into(), |
| 287 | + ) |
| 288 | + .unwrap(); |
| 289 | + |
| 290 | + let mask = Mask::from_iter([false, true, false]); |
| 291 | + |
| 292 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 293 | + |
| 294 | + assert_eq!(result.len(), 3); |
| 295 | + assert!(result.is_valid(0)); |
| 296 | + assert!(!result.is_valid(1)); |
| 297 | + assert!(result.is_valid(2)); |
| 298 | + } |
| 299 | + |
| 300 | + #[test] |
| 301 | + fn test_mask_all_true() { |
| 302 | + let array = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]); |
| 303 | + let mask = Mask::AllTrue(5); |
| 304 | + |
| 305 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 306 | + |
| 307 | + assert_eq!(result.len(), 5); |
| 308 | + // All values should be masked out (null) |
| 309 | + for i in 0..5 { |
| 310 | + assert!(!result.is_valid(i)); |
| 311 | + } |
| 312 | + } |
| 313 | + |
| 314 | + #[test] |
| 315 | + fn test_mask_all_false() { |
| 316 | + let array = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]); |
| 317 | + let mask = Mask::AllFalse(5); |
| 318 | + |
| 319 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 320 | + |
| 321 | + assert_eq!(result.len(), 5); |
| 322 | + // No values should be masked out |
| 323 | + for i in 0..5 { |
| 324 | + assert!(result.is_valid(i)); |
| 325 | + #[allow(clippy::cast_possible_truncation)] |
| 326 | + let expected = (i + 1) as i32; |
| 327 | + assert_eq!(result.scalar_at(i), Scalar::from(Some(expected))); |
| 328 | + } |
| 329 | + } |
| 330 | + |
| 331 | + #[test] |
| 332 | + fn test_mask_empty_array() { |
| 333 | + let array = PrimitiveArray::from_iter(Vec::<i32>::new()); |
| 334 | + let mask = Mask::AllFalse(0); |
| 335 | + |
| 336 | + let result = mask_canonical_array(array.to_canonical(), &mask).unwrap(); |
| 337 | + |
| 338 | + assert_eq!(result.len(), 0); |
| 339 | + } |
13 | 340 | } |
0 commit comments