Skip to content

Commit 6be61e3

Browse files
authored
slice stats propagation (#2788)
1 parent d36de28 commit 6be61e3

File tree

6 files changed

+155
-63
lines changed

6 files changed

+155
-63
lines changed

vortex-array/src/compute/slice.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ fn derive_sliced_stats(arr: &dyn Array) -> StatsSet {
8989

9090
// an array that is not constant can become constant after slicing
9191
let is_constant = stats.get_as::<bool>(Stat::IsConstant);
92-
let is_sorted = stats.get_as::<bool>(Stat::IsConstant);
93-
let is_strict_sorted = stats.get_as::<bool>(Stat::IsConstant);
92+
let is_sorted = stats.get_as::<bool>(Stat::IsSorted);
93+
let is_strict_sorted = stats.get_as::<bool>(Stat::IsStrictSorted);
9494

9595
let mut stats = stats.keep_inexact_stats(&[
9696
Stat::Max,

vortex-array/src/stats/bound.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub struct LowerBound<T>(pub(crate) Precision<T>);
1313

1414
impl<T> LowerBound<T> {
1515
pub(crate) fn min_value(self) -> T {
16-
self.0.into_value()
16+
self.0.into_inner()
1717
}
1818
}
1919

@@ -108,6 +108,10 @@ impl<T: PartialOrd + Clone> StatBound<T> for LowerBound<T> {
108108
fn to_exact(&self) -> Option<&T> {
109109
self.0.to_exact()
110110
}
111+
112+
fn into_value(self) -> Precision<T> {
113+
self.0
114+
}
111115
}
112116

113117
impl<T: PartialOrd> PartialEq<T> for LowerBound<T> {
@@ -137,13 +141,7 @@ pub struct UpperBound<T>(pub(crate) Precision<T>);
137141

138142
impl<T> UpperBound<T> {
139143
pub(crate) fn max_value(self) -> T {
140-
self.0.into_value()
141-
}
142-
}
143-
144-
impl<T> UpperBound<T> {
145-
pub fn into_value(self) -> Precision<T> {
146-
self.0
144+
self.0.into_inner()
147145
}
148146
}
149147

@@ -209,6 +207,10 @@ impl<T: PartialOrd + Clone> StatBound<T> for UpperBound<T> {
209207
fn to_exact(&self) -> Option<&T> {
210208
self.0.to_exact()
211209
}
210+
211+
fn into_value(self) -> Precision<T> {
212+
self.0
213+
}
212214
}
213215

214216
impl<T: PartialOrd> PartialEq<T> for UpperBound<T> {

vortex-array/src/stats/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,14 @@ impl StatType<bool> for IsConstant {
9494
const STAT: Stat = Stat::IsConstant;
9595
}
9696

97-
impl<T: PartialOrd + Clone> StatType<T> for IsSorted {
98-
type Bound = Precision<T>;
97+
impl StatType<bool> for IsSorted {
98+
type Bound = Precision<bool>;
9999

100100
const STAT: Stat = Stat::IsSorted;
101101
}
102102

103-
impl<T: PartialOrd + Clone> StatType<T> for IsStrictSorted {
104-
type Bound = Precision<T>;
103+
impl StatType<bool> for IsStrictSorted {
104+
type Bound = Precision<bool>;
105105

106106
const STAT: Stat = Stat::IsStrictSorted;
107107
}

vortex-array/src/stats/precision.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl<T> Precision<T> {
129129
Ok(precision)
130130
}
131131

132-
pub(crate) fn into_value(self) -> T {
132+
pub(crate) fn into_inner(self) -> T {
133133
match self {
134134
Exact(val) | Inexact(val) => val,
135135
}

vortex-array/src/stats/stat_bound.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::cmp::Ordering;
22

33
use crate::partial_ord::partial_min;
44
use crate::stats::bound::IntersectionResult;
5-
use crate::stats::{LowerBound, Precision, Stat};
5+
use crate::stats::{Precision, Stat};
66

77
/// `StatType` define the bound of a given statistic. (e.g. `Max` is an upper bound),
88
/// this is used to extract the bound from a `Precision` value, (e.g. `p::bound<Max>()`).
@@ -13,11 +13,14 @@ pub trait StatType<T> {
1313
}
1414

1515
/// `StatBound` defines the operations that can be performed on a bound.
16-
/// The mains bounds are Upper (e.g. max) and Lower (e.g. min).
16+
/// The main bounds are Upper (e.g. max) and Lower (e.g. min).
1717
pub trait StatBound<T>: Sized {
1818
/// Creates a new bound from a Precision statistic.
1919
fn lift(value: Precision<T>) -> Self;
2020

21+
/// Converts `Self` back to `Precision<T>`, inverse of `lift`.
22+
fn into_value(self) -> Precision<T>;
23+
2124
/// Finds the smallest bound that covers both bounds.
2225
/// A.k.a. the `meet` of the bound.
2326
fn union(&self, other: &Self) -> Option<Self>;
@@ -40,12 +43,6 @@ impl<T> Precision<T> {
4043
}
4144
}
4245

43-
impl<T: PartialOrd + Clone> LowerBound<T> {
44-
pub fn into_value(self) -> Precision<T> {
45-
self.0
46-
}
47-
}
48-
4946
impl<T: PartialOrd + Clone> StatBound<T> for Precision<T> {
5047
fn lift(value: Precision<T>) -> Self {
5148
value
@@ -87,4 +84,8 @@ impl<T: PartialOrd + Clone> StatBound<T> for Precision<T> {
8784
_ => None,
8885
}
8986
}
87+
88+
fn into_value(self) -> Precision<T> {
89+
self
90+
}
9091
}

vortex-array/src/stats/stats_set.rs

Lines changed: 129 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
use std::fmt::Debug;
2+
13
use enum_iterator::{Sequence, all};
24
use num_traits::CheckedAdd;
35
use vortex_dtype::DType;
46
use vortex_error::{VortexExpect, VortexResult, vortex_err};
57
use vortex_scalar::{Scalar, ScalarValue};
68

79
use super::traits::StatsProvider;
10+
use super::{IsSorted, IsStrictSorted, NullCount, StatType, UncompressedSizeInBytes};
811
use crate::stats::{IsConstant, Max, Min, Precision, Stat, StatBound, StatsProviderExt, Sum};
912

1013
#[derive(Default, Debug, Clone)]
@@ -229,72 +232,76 @@ impl StatsSet {
229232

230233
// given two sets of stats (of differing precision) for the same array, combine them
231234
pub fn combine_sets(&mut self, other: &Self, dtype: &DType) -> VortexResult<()> {
232-
self.combine_max(other, dtype)?;
233-
self.combine_min(other, dtype)?;
234-
self.combine_is_constant(other)
235-
}
236-
237-
fn combine_min(&mut self, other: &Self, dtype: &DType) -> VortexResult<()> {
238-
match (
239-
self.get_scalar_bound::<Min>(dtype),
240-
other.get_scalar_bound::<Min>(dtype),
241-
) {
242-
(Some(m1), Some(m2)) => {
243-
let meet = m1
244-
.intersection(&m2)
245-
.vortex_expect("can always compare scalar")
246-
.ok_or_else(|| vortex_err!("Min bounds ({m1:?}, {m2:?}) do not overlap"))?;
247-
if meet != m1 {
248-
self.set(Stat::Min, meet.into_value().map(Scalar::into_value));
235+
let other_stats: Vec<_> = other.values.iter().map(|(stat, _)| *stat).collect();
236+
for s in other_stats {
237+
match s {
238+
Stat::Max => self.combine_bound::<Max>(other, dtype)?,
239+
Stat::Min => self.combine_bound::<Min>(other, dtype)?,
240+
Stat::UncompressedSizeInBytes => {
241+
self.combine_bound::<UncompressedSizeInBytes>(other, dtype)?
249242
}
243+
Stat::IsConstant => self.combine_bool_stat::<IsConstant>(other)?,
244+
Stat::IsSorted => self.combine_bool_stat::<IsSorted>(other)?,
245+
Stat::IsStrictSorted => self.combine_bool_stat::<IsStrictSorted>(other)?,
246+
Stat::NullCount => self.combine_bound::<NullCount>(other, dtype)?,
247+
Stat::Sum => self.combine_bound::<Sum>(other, dtype)?,
250248
}
251-
(None, Some(m)) => self.set(Stat::Min, m.into_value().map(Scalar::into_value)),
252-
(Some(_), _) => (),
253-
(None, None) => self.clear(Stat::Min),
254249
}
255250
Ok(())
256251
}
257252

258-
fn combine_max(&mut self, other: &Self, dtype: &DType) -> VortexResult<()> {
253+
fn combine_bound<S: StatType<Scalar>>(
254+
&mut self,
255+
other: &Self,
256+
dtype: &DType,
257+
) -> VortexResult<()>
258+
where
259+
S::Bound: StatBound<Scalar> + Debug + Eq + PartialEq,
260+
{
259261
match (
260-
self.get_scalar_bound::<Max>(dtype),
261-
other.get_scalar_bound::<Max>(dtype),
262+
self.get_scalar_bound::<S>(dtype),
263+
other.get_scalar_bound::<S>(dtype),
262264
) {
263265
(Some(m1), Some(m2)) => {
264266
let meet = m1
265267
.intersection(&m2)
266268
.vortex_expect("can always compare scalar")
267-
.ok_or_else(|| vortex_err!("Max bounds ({m1:?}, {m2:?}) do not overlap"))?;
269+
.ok_or_else(|| {
270+
vortex_err!("{:?} bounds ({m1:?}, {m2:?}) do not overlap", S::STAT)
271+
})?;
268272
if meet != m1 {
269-
self.set(Stat::Max, meet.into_value().map(Scalar::into_value));
273+
self.set(S::STAT, meet.into_value().map(Scalar::into_value));
270274
}
271275
}
272-
(None, Some(m)) => self.set(Stat::Max, m.into_value().map(Scalar::into_value)),
273-
(Some(_), None) => (),
274-
(None, None) => self.clear(Stat::Max),
276+
(None, Some(m)) => self.set(S::STAT, m.into_value().map(Scalar::into_value)),
277+
(Some(_), _) => (),
278+
(None, None) => self.clear(S::STAT),
275279
}
276280
Ok(())
277281
}
278282

279-
fn combine_is_constant(&mut self, other: &Self) -> VortexResult<()> {
283+
fn combine_bool_stat<S: StatType<bool>>(&mut self, other: &Self) -> VortexResult<()>
284+
where
285+
S::Bound: StatBound<bool> + Debug + Eq + PartialEq,
286+
{
280287
match (
281-
self.get_as_bound::<IsConstant, bool>(),
282-
other.get_as_bound::<IsConstant, bool>(),
288+
self.get_as_bound::<S, bool>(),
289+
other.get_as_bound::<S, bool>(),
283290
) {
284291
(Some(m1), Some(m2)) => {
285292
let intersection = m1
286293
.intersection(&m2)
287-
.vortex_expect("can always compare scalar")
294+
.vortex_expect("can always compare boolean")
288295
.ok_or_else(|| {
289-
vortex_err!("IsConstant bounds ({m1:?}, {m2:?}) do not overlap")
296+
vortex_err!("{:?} bounds ({m1:?}, {m2:?}) do not overlap", S::STAT)
290297
})?;
291298
if intersection != m1 {
292-
self.set(Stat::IsConstant, intersection.map(ScalarValue::from));
299+
self.set(S::STAT, intersection.into_value().map(ScalarValue::from));
293300
}
294301
}
295-
(None, Some(m)) => self.set(Stat::IsConstant, m.map(ScalarValue::from)),
302+
(None, Some(m)) => self.set(S::STAT, m.into_value().map(ScalarValue::from)),
296303
(Some(_), None) => (),
297-
(None, None) => self.clear(Stat::IsConstant),
304+
(None, None) => self.clear(S::STAT),
298305
}
299306
Ok(())
300307
}
@@ -460,7 +467,7 @@ mod test {
460467

461468
use crate::Array;
462469
use crate::arrays::PrimitiveArray;
463-
use crate::stats::{Precision, Stat, StatsProvider, StatsProviderExt, StatsSet};
470+
use crate::stats::{IsConstant, Precision, Stat, StatsProvider, StatsProviderExt, StatsSet};
464471

465472
#[test]
466473
fn test_iter() {
@@ -789,7 +796,7 @@ mod test {
789796
{
790797
let mut stats = StatsSet::of(Stat::IsConstant, Precision::exact(true));
791798
let stats2 = StatsSet::of(Stat::IsConstant, Precision::exact(true));
792-
stats.combine_is_constant(&stats2).unwrap();
799+
stats.combine_bool_stat::<IsConstant>(&stats2).unwrap();
793800
assert_eq!(
794801
stats.get_as::<bool>(Stat::IsConstant),
795802
Some(Precision::exact(true))
@@ -799,7 +806,7 @@ mod test {
799806
{
800807
let mut stats = StatsSet::of(Stat::IsConstant, Precision::exact(true));
801808
let stats2 = StatsSet::of(Stat::IsConstant, Precision::inexact(false));
802-
stats.combine_is_constant(&stats2).unwrap();
809+
stats.combine_bool_stat::<IsConstant>(&stats2).unwrap();
803810
assert_eq!(
804811
stats.get_as::<bool>(Stat::IsConstant),
805812
Some(Precision::exact(true))
@@ -809,11 +816,93 @@ mod test {
809816
{
810817
let mut stats = StatsSet::of(Stat::IsConstant, Precision::exact(false));
811818
let stats2 = StatsSet::of(Stat::IsConstant, Precision::inexact(false));
812-
stats.combine_is_constant(&stats2).unwrap();
819+
stats.combine_bool_stat::<IsConstant>(&stats2).unwrap();
813820
assert_eq!(
814821
stats.get_as::<bool>(Stat::IsConstant),
815822
Some(Precision::exact(false))
816823
);
817824
}
818825
}
826+
827+
#[test]
828+
fn test_combine_sets_boolean_conflict() {
829+
let mut stats1 = StatsSet::from_iter([
830+
(Stat::IsConstant, Precision::exact(true)),
831+
(Stat::IsSorted, Precision::exact(true)),
832+
]);
833+
834+
let stats2 = StatsSet::from_iter([
835+
(Stat::IsConstant, Precision::exact(false)),
836+
(Stat::IsSorted, Precision::exact(true)),
837+
]);
838+
839+
let result = stats1.combine_sets(
840+
&stats2,
841+
&DType::Primitive(PType::I32, Nullability::NonNullable),
842+
);
843+
assert!(result.is_err());
844+
}
845+
846+
#[test]
847+
fn test_combine_sets_with_missing_stats() {
848+
let mut stats1 = StatsSet::from_iter([
849+
(Stat::Min, Precision::exact(42)),
850+
(Stat::UncompressedSizeInBytes, Precision::exact(1000)),
851+
]);
852+
853+
let stats2 = StatsSet::from_iter([
854+
(Stat::Max, Precision::exact(100)),
855+
(Stat::IsStrictSorted, Precision::exact(true)),
856+
]);
857+
858+
stats1
859+
.combine_sets(
860+
&stats2,
861+
&DType::Primitive(PType::I32, Nullability::NonNullable),
862+
)
863+
.unwrap();
864+
865+
// Min should remain unchanged
866+
assert_eq!(stats1.get_as::<i32>(Stat::Min), Some(Precision::exact(42)));
867+
// Max should be added
868+
assert_eq!(stats1.get_as::<i32>(Stat::Max), Some(Precision::exact(100)));
869+
// IsStrictSorted should be added
870+
assert_eq!(
871+
stats1.get_as::<bool>(Stat::IsStrictSorted),
872+
Some(Precision::exact(true))
873+
);
874+
}
875+
876+
#[test]
877+
fn test_combine_sets_with_inexact() {
878+
let mut stats1 = StatsSet::from_iter([
879+
(Stat::Min, Precision::exact(42)),
880+
(Stat::Max, Precision::inexact(100)),
881+
(Stat::IsConstant, Precision::exact(false)),
882+
]);
883+
884+
let stats2 = StatsSet::from_iter([
885+
// Must ensure Min from stats2 is <= Min from stats1
886+
(Stat::Min, Precision::inexact(40)),
887+
(Stat::Max, Precision::exact(90)),
888+
(Stat::IsSorted, Precision::exact(true)),
889+
]);
890+
891+
stats1
892+
.combine_sets(
893+
&stats2,
894+
&DType::Primitive(PType::I32, Nullability::NonNullable),
895+
)
896+
.unwrap();
897+
898+
// Min should remain unchanged since it's more restrictive than the inexact value
899+
assert_eq!(stats1.get_as::<i32>(Stat::Min), Some(Precision::exact(42)));
900+
// Check that max was updated with the exact value
901+
assert_eq!(stats1.get_as::<i32>(Stat::Max), Some(Precision::exact(90)));
902+
// Check that IsSorted was added
903+
assert_eq!(
904+
stats1.get_as::<bool>(Stat::IsSorted),
905+
Some(Precision::exact(true))
906+
);
907+
}
819908
}

0 commit comments

Comments
 (0)