|
1 | 1 | // SPDX-License-Identifier: Apache-2.0 |
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
| 4 | +use itertools::Itertools; |
| 5 | +use num_traits::Zero; |
4 | 6 | use vortex_array::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult, min_max, take}; |
5 | | -use vortex_array::register_kernel; |
| 7 | +use vortex_array::{Array as _, ToCanonical, register_kernel}; |
| 8 | +use vortex_dtype::match_each_unsigned_integer_ptype; |
6 | 9 | use vortex_error::VortexResult; |
| 10 | +use vortex_mask::Mask; |
7 | 11 |
|
8 | 12 | use crate::{DictArray, DictVTable}; |
9 | 13 |
|
10 | 14 | impl MinMaxKernel for DictVTable { |
11 | 15 | fn min_max(&self, array: &DictArray) -> VortexResult<Option<MinMaxResult>> { |
12 | | - min_max(&take(array.values(), array.codes())?) |
| 16 | + let Mask::AllTrue(_) = array.codes().validity_mask() else { |
| 17 | + return min_max(&take(array.values(), array.codes().as_ref())?); |
| 18 | + }; |
| 19 | + |
| 20 | + let codes_primitive = array.codes().to_primitive(); |
| 21 | + let values_len = array.values().len(); |
| 22 | + match_each_unsigned_integer_ptype!(codes_primitive.ptype(), |P| { |
| 23 | + let mut codes = codes_primitive |
| 24 | + .as_slice::<P>() |
| 25 | + .iter() |
| 26 | + // First dedup of consecutive codes before collecting should |
| 27 | + // make this operation a little cheaper. |
| 28 | + .dedup() |
| 29 | + .copied() |
| 30 | + .collect::<Vec<_>>(); |
| 31 | + codes.sort_unstable(); |
| 32 | + codes.dedup(); |
| 33 | + |
| 34 | + let max_expected = P::try_from(values_len - 1).ok(); |
| 35 | + if codes.len() == values_len |
| 36 | + && codes.first() == Some(&P::zero()) |
| 37 | + && codes.last() == max_expected.as_ref() |
| 38 | + { |
| 39 | + // Codes fully cover all values, compute min_max directly on |
| 40 | + // values. |
| 41 | + return min_max(array.values()); |
| 42 | + } |
| 43 | + |
| 44 | + // Codes do not fully cover the values, we need to take only |
| 45 | + // referenced values, but can do so using only the unique codes, |
| 46 | + // avoiding a full materialization. |
| 47 | + min_max(&take(array.values(), array.codes().as_ref())?) |
| 48 | + }) |
13 | 49 | } |
14 | 50 | } |
15 | 51 |
|
16 | 52 | register_kernel!(MinMaxKernelAdapter(DictVTable).lift()); |
| 53 | + |
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::min_max;
    use vortex_array::{Array, IntoArray};
    use vortex_buffer::buffer;

    use crate::DictArray;
    use crate::builders::dict_encode;

    /// Runs `min_max` on `array` and compares the outcome with `expected`.
    ///
    /// `Some((min, max))` requires a result whose bounds convert to exactly
    /// those `i32` values; `None` requires that no result is produced. Any
    /// other combination panics with both sides in the message.
    fn assert_min_max(array: &dyn Array, expected: Option<(i32, i32)>) {
        let actual = min_max(array).unwrap();
        match (actual, expected) {
            (None, None) => {}
            (Some(found), Some((want_min, want_max))) => {
                assert_eq!(i32::try_from(found.min).unwrap(), want_min);
                assert_eq!(i32::try_from(found.max).unwrap(), want_max);
            }
            (found, wanted) => {
                // Convert leniently so the report still prints when a bound
                // is not representable as `i32`.
                let found_bounds = found.as_ref().map(|r| {
                    (
                        i32::try_from(r.min.clone()).ok(),
                        i32::try_from(r.max.clone()).ok(),
                    )
                });
                panic!(
                    "min_max mismatch: expected {:?}, got {:?}",
                    wanted, found_bounds
                );
            }
        }
    }

    #[rstest]
    // Covering: every dictionary value is referenced by at least one code
    #[case::covering(
        DictArray::try_new(
            buffer![0u32, 1, 2, 3, 0, 1].into_array(),
            buffer![10i32, 20, 30, 40].into_array(),
        ).unwrap(),
        (10, 40)
    )]
    // Non-covering: repeated codes that reference only some values
    #[case::non_covering_duplicates(
        DictArray::try_new(
            buffer![1u32, 1, 1, 3, 3].into_array(),
            buffer![1i32, 2, 3, 4, 5].into_array(),
        ).unwrap(),
        (2, 4)
    )]
    // Non-covering: codes with gaps
    #[case::non_covering_gaps(
        DictArray::try_new(
            buffer![0u32, 2, 4].into_array(),
            buffer![1i32, 2, 3, 4, 5].into_array(),
        ).unwrap(),
        (1, 5)
    )]
    #[case::single(dict_encode(&buffer![42i32].into_array()).unwrap(), (42, 42))]
    #[case::nullable_codes(
        DictArray::try_new(
            PrimitiveArray::from_option_iter([Some(0u32), None, Some(1), Some(2)]).into_array(),
            buffer![10i32, 20, 30].into_array(),
        ).unwrap(),
        (10, 30)
    )]
    #[case::nullable_values(
        dict_encode(
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
        ).unwrap(),
        (1, 2)
    )]
    fn test_min_max(#[case] dict: DictArray, #[case] expected: (i32, i32)) {
        assert_min_max(dict.as_ref(), Some(expected));
    }

    /// A slice of a dictionary array must report bounds for the slice only.
    #[test]
    fn test_sliced_dict() {
        let reference = PrimitiveArray::from_iter([1, 5, 10, 50, 100]);
        let dict = dict_encode(reference.as_ref()).unwrap();
        let sliced = dict.slice(1..3);
        assert_min_max(&sliced, Some((5, 10)));
    }

    #[rstest]
    #[case::empty(
        DictArray::try_new(
            PrimitiveArray::from_iter(Vec::<u32>::new()).into_array(),
            buffer![10i32, 20, 30].into_array(),
        ).unwrap()
    )]
    #[case::all_null_codes(
        DictArray::try_new(
            PrimitiveArray::from_option_iter([Option::<u32>::None, None, None]).into_array(),
            buffer![10i32, 20, 30].into_array(),
        ).unwrap()
    )]
    fn test_min_max_none(#[case] dict: DictArray) {
        assert_min_max(dict.as_ref(), None);
    }
}
0 commit comments