|
1 | 1 | use std::fmt::{Debug, Display}; |
2 | 2 | use std::sync::Arc; |
3 | 3 |
|
| 4 | +use enum_iterator::all; |
4 | 5 | use serde::{Deserialize, Serialize}; |
5 | 6 | use vortex_dtype::{DType, ExtDType, ExtID}; |
6 | 7 | use vortex_error::{VortexExpect as _, VortexResult}; |
7 | 8 |
|
8 | 9 | use crate::array::visitor::{AcceptArrayVisitor, ArrayVisitor}; |
9 | 10 | use crate::encoding::ids; |
10 | | -use crate::stats::ArrayStatisticsCompute; |
| 11 | +use crate::stats::{ArrayStatistics as _, ArrayStatisticsCompute, Stat, StatsSet}; |
11 | 12 | use crate::validity::{ArrayValidity, LogicalValidity}; |
12 | 13 | use crate::variants::{ArrayVariants, ExtensionArrayTrait}; |
13 | 14 | use crate::{impl_encoding, Array, ArrayDType, ArrayTrait, Canonical, IntoCanonical}; |
@@ -93,5 +94,69 @@ impl AcceptArrayVisitor for ExtensionArray { |
93 | 94 | } |
94 | 95 |
|
95 | 96 | impl ArrayStatisticsCompute for ExtensionArray { |
96 | | - // TODO(ngates): pass through stats to the underlying and cast the scalars. |
| 97 | + fn compute_statistics(&self, stat: Stat) -> VortexResult<StatsSet> { |
| 98 | + let mut stats = self.storage().statistics().compute_all(&[stat])?; |
| 99 | + |
| 100 | + // for e.g., min/max, we want to cast to the extension array's dtype |
| 101 | + // for other stats, we don't need to change anything |
| 102 | + for stat in all::<Stat>().filter(|s| s.has_same_dtype_as_array()) { |
| 103 | + if let Some(value) = stats.get(stat) { |
| 104 | + stats.set(stat, value.cast(self.dtype())?); |
| 105 | + } |
| 106 | + } |
| 107 | + |
| 108 | + Ok(stats) |
| 109 | + } |
| 110 | +} |
| 111 | + |
| 112 | +#[cfg(test)] |
| 113 | +mod tests { |
| 114 | + use itertools::Itertools; |
| 115 | + use vortex_dtype::PType; |
| 116 | + use vortex_scalar::{PValue, Scalar, ScalarValue}; |
| 117 | + |
| 118 | + use super::*; |
| 119 | + use crate::array::PrimitiveArray; |
| 120 | + use crate::validity::Validity; |
| 121 | + use crate::IntoArray as _; |
| 122 | + |
| 123 | + #[test] |
| 124 | + fn compute_statistics() { |
| 125 | + let ext_dtype = Arc::new(ExtDType::new( |
| 126 | + ExtID::new("timestamp".into()), |
| 127 | + DType::from(PType::I64).into(), |
| 128 | + None, |
| 129 | + )); |
| 130 | + let array = ExtensionArray::new( |
| 131 | + ext_dtype.clone(), |
| 132 | + PrimitiveArray::from_vec(vec![1i64, 2, 3, 4, 5], Validity::NonNullable).into_array(), |
| 133 | + ); |
| 134 | + |
| 135 | + let stats = array |
| 136 | + .statistics() |
| 137 | + .compute_all(&[Stat::Min, Stat::Max, Stat::NullCount]) |
| 138 | + .unwrap(); |
| 139 | + let num_stats = stats.clone().into_iter().try_len().unwrap(); |
| 140 | + assert!( |
| 141 | + num_stats >= 3, |
| 142 | + "Expected at least 3 stats, got {}", |
| 143 | + num_stats |
| 144 | + ); |
| 145 | + |
| 146 | + assert_eq!( |
| 147 | + stats.get(Stat::Min), |
| 148 | + Some(&Scalar::extension( |
| 149 | + ext_dtype.clone(), |
| 150 | + ScalarValue::Primitive(PValue::I64(1)) |
| 151 | + )) |
| 152 | + ); |
| 153 | + assert_eq!( |
| 154 | + stats.get(Stat::Max), |
| 155 | + Some(&Scalar::extension( |
| 156 | + ext_dtype.clone(), |
| 157 | + ScalarValue::Primitive(PValue::I64(5)) |
| 158 | + )) |
| 159 | + ); |
| 160 | + assert_eq!(stats.get(Stat::NullCount), Some(&0u64.into())); |
| 161 | + } |
97 | 162 | } |
0 commit comments