|
1 | 1 | // SPDX-License-Identifier: Apache-2.0 |
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
4 | | -use vortex_array::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult, min_max, take}; |
5 | | -use vortex_array::register_kernel; |
| 4 | +use vortex_array::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult, mask, min_max}; |
| 5 | +use vortex_array::{Array as _, ToCanonical, register_kernel}; |
| 6 | +use vortex_buffer::BitBufferMut; |
| 7 | +use vortex_dtype::match_each_unsigned_integer_ptype; |
6 | 8 | use vortex_error::VortexResult; |
| 9 | +use vortex_mask::Mask; |
7 | 10 |
|
8 | 11 | use crate::{DictArray, DictVTable}; |
9 | 12 |
|
10 | 13 | impl MinMaxKernel for DictVTable { |
11 | 14 | fn min_max(&self, array: &DictArray) -> VortexResult<Option<MinMaxResult>> { |
12 | | - min_max(&take(array.values(), array.codes())?) |
| 15 | + let codes_validity = array.codes().validity_mask(); |
| 16 | + if codes_validity.all_false() { |
| 17 | + return Ok(None); |
| 18 | + } |
| 19 | + |
| 20 | + let codes_primitive = array.codes().to_primitive(); |
| 21 | + let values_len = array.values().len(); |
| 22 | + match_each_unsigned_integer_ptype!(codes_primitive.ptype(), |P| { |
| 23 | + codes_validity.iter_bools(|validity_iter| { |
| 24 | + // mask() sets values to null where the mask is true, so we start |
| 25 | + // with a fully-set bit buffer. |
| 26 | + let mut unreferenced = BitBufferMut::new_set(values_len); |
| 27 | + for (&code, is_valid) in codes_primitive.as_slice::<P>().iter().zip(validity_iter) { |
| 28 | + if is_valid { |
| 29 | + // SAFETY: code is valid, so it must be in range. |
| 30 | + #[allow(clippy::cast_possible_truncation)] |
| 31 | + unsafe { |
| 32 | + unreferenced.unset_unchecked(code as usize); |
| 33 | + } |
| 34 | + } |
| 35 | + } |
| 36 | + |
| 37 | + let unreferenced_mask = Mask::from_buffer(unreferenced.freeze()); |
| 38 | + min_max(&mask(array.values(), &unreferenced_mask)?) |
| 39 | + }) |
| 40 | + }) |
13 | 41 | } |
14 | 42 | } |
15 | 43 |
|
16 | 44 | register_kernel!(MinMaxKernelAdapter(DictVTable).lift()); |
| 45 | + |
| 46 | +#[cfg(test)] |
| 47 | +mod tests { |
| 48 | + use rstest::rstest; |
| 49 | + use vortex_array::arrays::PrimitiveArray; |
| 50 | + use vortex_array::compute::min_max; |
| 51 | + use vortex_array::{Array, IntoArray}; |
| 52 | + use vortex_buffer::buffer; |
| 53 | + |
| 54 | + use crate::DictArray; |
| 55 | + use crate::builders::dict_encode; |
| 56 | + |
| 57 | + fn assert_min_max(array: &dyn Array, expected: Option<(i32, i32)>) { |
| 58 | + match (min_max(array).unwrap(), expected) { |
| 59 | + (Some(result), Some((expected_min, expected_max))) => { |
| 60 | + assert_eq!(i32::try_from(result.min).unwrap(), expected_min); |
| 61 | + assert_eq!(i32::try_from(result.max).unwrap(), expected_max); |
| 62 | + } |
| 63 | + (None, None) => {} |
| 64 | + (got, expected) => panic!( |
| 65 | + "min_max mismatch: expected {:?}, got {:?}", |
| 66 | + expected, |
| 67 | + got.as_ref().map(|r| ( |
| 68 | + i32::try_from(r.min.clone()).ok(), |
| 69 | + i32::try_from(r.max.clone()).ok() |
| 70 | + )) |
| 71 | + ), |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + #[rstest] |
| 76 | + #[case::covering( |
| 77 | + DictArray::try_new( |
| 78 | + buffer![0u32, 1, 2, 3, 0, 1].into_array(), |
| 79 | + buffer![10i32, 20, 30, 40].into_array(), |
| 80 | + ).unwrap(), |
| 81 | + (10, 40) |
| 82 | + )] |
| 83 | + #[case::non_covering_duplicates( |
| 84 | + DictArray::try_new( |
| 85 | + buffer![1u32, 1, 1, 3, 3].into_array(), |
| 86 | + buffer![1i32, 2, 3, 4, 5].into_array(), |
| 87 | + ).unwrap(), |
| 88 | + (2, 4) |
| 89 | + )] |
| 90 | + // Non-covering: codes with gaps |
| 91 | + #[case::non_covering_gaps( |
| 92 | + DictArray::try_new( |
| 93 | + buffer![0u32, 2, 4].into_array(), |
| 94 | + buffer![1i32, 2, 3, 4, 5].into_array(), |
| 95 | + ).unwrap(), |
| 96 | + (1, 5) |
| 97 | + )] |
| 98 | + #[case::single(dict_encode(&buffer![42i32].into_array()).unwrap(), (42, 42))] |
| 99 | + #[case::nullable_codes( |
| 100 | + DictArray::try_new( |
| 101 | + PrimitiveArray::from_option_iter([Some(0u32), None, Some(1), Some(2)]).into_array(), |
| 102 | + buffer![10i32, 20, 30].into_array(), |
| 103 | + ).unwrap(), |
| 104 | + (10, 30) |
| 105 | + )] |
| 106 | + #[case::nullable_values( |
| 107 | + dict_encode( |
| 108 | + PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref() |
| 109 | + ).unwrap(), |
| 110 | + (1, 2) |
| 111 | + )] |
| 112 | + fn test_min_max(#[case] dict: DictArray, #[case] expected: (i32, i32)) { |
| 113 | + assert_min_max(dict.as_ref(), Some(expected)); |
| 114 | + } |
| 115 | + |
| 116 | + #[test] |
| 117 | + fn test_sliced_dict() { |
| 118 | + let reference = PrimitiveArray::from_iter([1, 5, 10, 50, 100]); |
| 119 | + let dict = dict_encode(reference.as_ref()).unwrap(); |
| 120 | + let sliced = dict.slice(1..3); |
| 121 | + assert_min_max(&sliced, Some((5, 10))); |
| 122 | + } |
| 123 | + |
| 124 | + #[rstest] |
| 125 | + #[case::empty( |
| 126 | + DictArray::try_new( |
| 127 | + PrimitiveArray::from_iter(Vec::<u32>::new()).into_array(), |
| 128 | + buffer![10i32, 20, 30].into_array(), |
| 129 | + ).unwrap() |
| 130 | + )] |
| 131 | + #[case::all_null_codes( |
| 132 | + DictArray::try_new( |
| 133 | + PrimitiveArray::from_option_iter([Option::<u32>::None, None, None]).into_array(), |
| 134 | + buffer![10i32, 20, 30].into_array(), |
| 135 | + ).unwrap() |
| 136 | + )] |
| 137 | + fn test_min_max_none(#[case] dict: DictArray) { |
| 138 | + assert_min_max(dict.as_ref(), None); |
| 139 | + } |
| 140 | +} |
0 commit comments