|
3 | 3 |
|
4 | 4 | use std::fmt::Debug; |
5 | 5 |
|
| 6 | +use itertools::Itertools as _; |
| 7 | +use num_traits::NumCast; |
6 | 8 | use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray}; |
7 | 9 | use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar}; |
8 | 10 | use vortex_array::patches::Patches; |
9 | 11 | use vortex_array::stats::{ArrayStats, StatsSetRef}; |
10 | 12 | use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable}; |
11 | 13 | use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable}; |
12 | 14 | use vortex_buffer::Buffer; |
13 | | -use vortex_dtype::{DType, Nullability, match_each_integer_ptype}; |
| 15 | +use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype}; |
14 | 16 | use vortex_error::{VortexExpect as _, VortexResult, vortex_bail}; |
15 | 17 | use vortex_mask::{AllOr, Mask}; |
16 | 18 | use vortex_scalar::Scalar; |
@@ -268,48 +270,60 @@ impl ValidityVTable<SparseVTable> for SparseVTable { |
268 | 270 |
|
269 | 271 | #[allow(clippy::unnecessary_fallible_conversions)] |
270 | 272 | fn validity_mask(array: &SparseArray) -> VortexResult<Mask> { |
271 | | - let indices = array.patches().indices().to_primitive()?; |
| 273 | + let fill_is_valid = array.fill_scalar().is_valid(); |
| 274 | + let values_validity = array.patches().values().validity_mask()?; |
| 275 | + let len = array.len(); |
272 | 276 |
|
273 | | - if array.fill_scalar().is_null() { |
274 | | - // If we have a null fill value, then we set each patch value to true. |
275 | | - let mut buffer = BooleanBufferBuilder::new(array.len()); |
276 | | - // TODO(ngates): use vortex-buffer::BitBufferMut when it exists. |
277 | | - buffer.append_n(array.len(), false); |
278 | | - |
279 | | - match_each_integer_ptype!(indices.ptype(), |I| { |
280 | | - indices.as_slice::<I>().iter().for_each(|&index| { |
281 | | - buffer.set_bit( |
282 | | - usize::try_from(index).vortex_expect("Failed to cast to usize") |
283 | | - - array.patches().offset(), |
284 | | - true, |
285 | | - ); |
286 | | - }); |
287 | | - }); |
288 | | - |
289 | | - return Ok(Mask::from_buffer(buffer.finish())); |
| 277 | + if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid { |
| 278 | + return Ok(Mask::AllTrue(len)); |
| 279 | + } |
| 280 | + if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid { |
| 281 | + return Ok(Mask::AllFalse(len)); |
290 | 282 | } |
291 | 283 |
|
292 | | - // If the fill_value is non-null, then the validity is based on the validity of the |
293 | | - // patch values. |
294 | | - let mut buffer = BooleanBufferBuilder::new(array.len()); |
295 | | - buffer.append_n(array.len(), true); |
| 284 | + // TODO(ngates): use vortex-buffer::BitBufferMut when it exists. |
| 285 | + let mut is_valid_buffer = BooleanBufferBuilder::new(len); |
| 286 | + is_valid_buffer.append_n(len, fill_is_valid); |
| 287 | + |
| 288 | + let indices = array.patches().indices().to_primitive()?; |
| 289 | + let index_offset = array.patches().offset(); |
296 | 290 |
|
297 | | - let values_validity = array.patches().values().validity_mask()?; |
298 | 291 | match_each_integer_ptype!(indices.ptype(), |I| { |
299 | | - indices |
300 | | - .as_slice::<I>() |
301 | | - .iter() |
302 | | - .enumerate() |
303 | | - .for_each(|(patch_idx, &index)| { |
304 | | - buffer.set_bit( |
305 | | - usize::try_from(index).vortex_expect("Failed to cast to usize") |
306 | | - - array.patches().offset(), |
307 | | - values_validity.value(patch_idx), |
308 | | - ); |
309 | | - }) |
| 292 | + let indices = indices.as_slice::<I>(); |
| 293 | + patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity); |
310 | 294 | }); |
311 | 295 |
|
312 | | - Ok(Mask::from_buffer(buffer.finish())) |
| 296 | + Ok(Mask::from_buffer(is_valid_buffer.finish())) |
| 297 | + } |
| 298 | +} |
| 299 | + |
| 300 | +fn patch_validity<I: NativePType>( |
| 301 | + is_valid_buffer: &mut BooleanBufferBuilder, |
| 302 | + indices: &[I], |
| 303 | + index_offset: usize, |
| 304 | + values_validity: Mask, |
| 305 | +) { |
| 306 | + let indices = indices.iter().map(|index| { |
| 307 | + let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize"); |
| 308 | + index - index_offset |
| 309 | + }); |
| 310 | + match values_validity { |
| 311 | + Mask::AllTrue(_) => { |
| 312 | + for index in indices { |
| 313 | + is_valid_buffer.set_bit(index, true); |
| 314 | + } |
| 315 | + } |
| 316 | + Mask::AllFalse(_) => { |
| 317 | + for index in indices { |
| 318 | + is_valid_buffer.set_bit(index, false); |
| 319 | + } |
| 320 | + } |
| 321 | + Mask::Values(mask_values) => { |
| 322 | + let is_valid = mask_values.boolean_buffer().iter(); |
| 323 | + for (index, is_valid) in indices.zip_eq(is_valid) { |
| 324 | + is_valid_buffer.set_bit(index, is_valid); |
| 325 | + } |
| 326 | + } |
313 | 327 | } |
314 | 328 | } |
315 | 329 |
|
@@ -519,4 +533,18 @@ mod test { |
519 | 533 | vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4] |
520 | 534 | ); |
521 | 535 | } |
| 536 | + |
| 537 | + #[test] |
| 538 | + fn validity_mask_includes_null_values_when_fill_is_null() { |
| 539 | + let indices = buffer![0u8, 2, 4, 6, 8].into_array(); |
| 540 | + let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)]) |
| 541 | + .into_array(); |
| 542 | + let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap(); |
| 543 | + let actual = array.validity_mask().unwrap(); |
| 544 | + let expected = Mask::from_iter([ |
| 545 | + true, false, true, false, false, false, false, false, true, false, |
| 546 | + ]); |
| 547 | + |
| 548 | + assert_eq!(actual, expected); |
| 549 | + } |
522 | 550 | } |
0 commit comments