1+ use std:: fmt:: Debug ;
2+
13use enum_iterator:: { Sequence , all} ;
24use num_traits:: CheckedAdd ;
35use vortex_dtype:: DType ;
46use vortex_error:: { VortexExpect , VortexResult , vortex_err} ;
57use vortex_scalar:: { Scalar , ScalarValue } ;
68
79use super :: traits:: StatsProvider ;
10+ use super :: { IsSorted , IsStrictSorted , NullCount , StatType , UncompressedSizeInBytes } ;
811use crate :: stats:: { IsConstant , Max , Min , Precision , Stat , StatBound , StatsProviderExt , Sum } ;
912
1013#[ derive( Default , Debug , Clone ) ]
@@ -229,72 +232,76 @@ impl StatsSet {
229232
230233 // given two sets of stats (of differing precision) for the same array, combine them
231234 pub fn combine_sets ( & mut self , other : & Self , dtype : & DType ) -> VortexResult < ( ) > {
232- self . combine_max ( other, dtype) ?;
233- self . combine_min ( other, dtype) ?;
234- self . combine_is_constant ( other)
235- }
236-
237- fn combine_min ( & mut self , other : & Self , dtype : & DType ) -> VortexResult < ( ) > {
238- match (
239- self . get_scalar_bound :: < Min > ( dtype) ,
240- other. get_scalar_bound :: < Min > ( dtype) ,
241- ) {
242- ( Some ( m1) , Some ( m2) ) => {
243- let meet = m1
244- . intersection ( & m2)
245- . vortex_expect ( "can always compare scalar" )
246- . ok_or_else ( || vortex_err ! ( "Min bounds ({m1:?}, {m2:?}) do not overlap" ) ) ?;
247- if meet != m1 {
248- self . set ( Stat :: Min , meet. into_value ( ) . map ( Scalar :: into_value) ) ;
235+ let other_stats: Vec < _ > = other. values . iter ( ) . map ( |( stat, _) | * stat) . collect ( ) ;
236+ for s in other_stats {
237+ match s {
238+ Stat :: Max => self . combine_bound :: < Max > ( other, dtype) ?,
239+ Stat :: Min => self . combine_bound :: < Min > ( other, dtype) ?,
240+ Stat :: UncompressedSizeInBytes => {
241+ self . combine_bound :: < UncompressedSizeInBytes > ( other, dtype) ?
249242 }
243+ Stat :: IsConstant => self . combine_bool_stat :: < IsConstant > ( other) ?,
244+ Stat :: IsSorted => self . combine_bool_stat :: < IsSorted > ( other) ?,
245+ Stat :: IsStrictSorted => self . combine_bool_stat :: < IsStrictSorted > ( other) ?,
246+ Stat :: NullCount => self . combine_bound :: < NullCount > ( other, dtype) ?,
247+ Stat :: Sum => self . combine_bound :: < Sum > ( other, dtype) ?,
250248 }
251- ( None , Some ( m) ) => self . set ( Stat :: Min , m. into_value ( ) . map ( Scalar :: into_value) ) ,
252- ( Some ( _) , _) => ( ) ,
253- ( None , None ) => self . clear ( Stat :: Min ) ,
254249 }
255250 Ok ( ( ) )
256251 }
257252
258- fn combine_max ( & mut self , other : & Self , dtype : & DType ) -> VortexResult < ( ) > {
253+ fn combine_bound < S : StatType < Scalar > > (
254+ & mut self ,
255+ other : & Self ,
256+ dtype : & DType ,
257+ ) -> VortexResult < ( ) >
258+ where
259+ S :: Bound : StatBound < Scalar > + Debug + Eq + PartialEq ,
260+ {
259261 match (
260- self . get_scalar_bound :: < Max > ( dtype) ,
261- other. get_scalar_bound :: < Max > ( dtype) ,
262+ self . get_scalar_bound :: < S > ( dtype) ,
263+ other. get_scalar_bound :: < S > ( dtype) ,
262264 ) {
263265 ( Some ( m1) , Some ( m2) ) => {
264266 let meet = m1
265267 . intersection ( & m2)
266268 . vortex_expect ( "can always compare scalar" )
267- . ok_or_else ( || vortex_err ! ( "Max bounds ({m1:?}, {m2:?}) do not overlap" ) ) ?;
269+ . ok_or_else ( || {
270+ vortex_err ! ( "{:?} bounds ({m1:?}, {m2:?}) do not overlap" , S :: STAT )
271+ } ) ?;
268272 if meet != m1 {
269- self . set ( Stat :: Max , meet. into_value ( ) . map ( Scalar :: into_value) ) ;
273+ self . set ( S :: STAT , meet. into_value ( ) . map ( Scalar :: into_value) ) ;
270274 }
271275 }
272- ( None , Some ( m) ) => self . set ( Stat :: Max , m. into_value ( ) . map ( Scalar :: into_value) ) ,
273- ( Some ( _) , None ) => ( ) ,
274- ( None , None ) => self . clear ( Stat :: Max ) ,
276+ ( None , Some ( m) ) => self . set ( S :: STAT , m. into_value ( ) . map ( Scalar :: into_value) ) ,
277+ ( Some ( _) , _ ) => ( ) ,
278+ ( None , None ) => self . clear ( S :: STAT ) ,
275279 }
276280 Ok ( ( ) )
277281 }
278282
279- fn combine_is_constant ( & mut self , other : & Self ) -> VortexResult < ( ) > {
283+ fn combine_bool_stat < S : StatType < bool > > ( & mut self , other : & Self ) -> VortexResult < ( ) >
284+ where
285+ S :: Bound : StatBound < bool > + Debug + Eq + PartialEq ,
286+ {
280287 match (
281- self . get_as_bound :: < IsConstant , bool > ( ) ,
282- other. get_as_bound :: < IsConstant , bool > ( ) ,
288+ self . get_as_bound :: < S , bool > ( ) ,
289+ other. get_as_bound :: < S , bool > ( ) ,
283290 ) {
284291 ( Some ( m1) , Some ( m2) ) => {
285292 let intersection = m1
286293 . intersection ( & m2)
287- . vortex_expect ( "can always compare scalar " )
294+ . vortex_expect ( "can always compare boolean " )
288295 . ok_or_else ( || {
289- vortex_err ! ( "IsConstant bounds ({m1:?}, {m2:?}) do not overlap" )
296+ vortex_err ! ( "{:?} bounds ({m1:?}, {m2:?}) do not overlap" , S :: STAT )
290297 } ) ?;
291298 if intersection != m1 {
292- self . set ( Stat :: IsConstant , intersection. map ( ScalarValue :: from) ) ;
299+ self . set ( S :: STAT , intersection. into_value ( ) . map ( ScalarValue :: from) ) ;
293300 }
294301 }
295- ( None , Some ( m) ) => self . set ( Stat :: IsConstant , m. map ( ScalarValue :: from) ) ,
302+ ( None , Some ( m) ) => self . set ( S :: STAT , m. into_value ( ) . map ( ScalarValue :: from) ) ,
296303 ( Some ( _) , None ) => ( ) ,
297- ( None , None ) => self . clear ( Stat :: IsConstant ) ,
304+ ( None , None ) => self . clear ( S :: STAT ) ,
298305 }
299306 Ok ( ( ) )
300307 }
@@ -460,7 +467,7 @@ mod test {
460467
461468 use crate :: Array ;
462469 use crate :: arrays:: PrimitiveArray ;
463- use crate :: stats:: { Precision , Stat , StatsProvider , StatsProviderExt , StatsSet } ;
470+ use crate :: stats:: { IsConstant , Precision , Stat , StatsProvider , StatsProviderExt , StatsSet } ;
464471
465472 #[ test]
466473 fn test_iter ( ) {
@@ -789,7 +796,7 @@ mod test {
789796 {
790797 let mut stats = StatsSet :: of ( Stat :: IsConstant , Precision :: exact ( true ) ) ;
791798 let stats2 = StatsSet :: of ( Stat :: IsConstant , Precision :: exact ( true ) ) ;
792- stats. combine_is_constant ( & stats2) . unwrap ( ) ;
799+ stats. combine_bool_stat :: < IsConstant > ( & stats2) . unwrap ( ) ;
793800 assert_eq ! (
794801 stats. get_as:: <bool >( Stat :: IsConstant ) ,
795802 Some ( Precision :: exact( true ) )
@@ -799,7 +806,7 @@ mod test {
799806 {
800807 let mut stats = StatsSet :: of ( Stat :: IsConstant , Precision :: exact ( true ) ) ;
801808 let stats2 = StatsSet :: of ( Stat :: IsConstant , Precision :: inexact ( false ) ) ;
802- stats. combine_is_constant ( & stats2) . unwrap ( ) ;
809+ stats. combine_bool_stat :: < IsConstant > ( & stats2) . unwrap ( ) ;
803810 assert_eq ! (
804811 stats. get_as:: <bool >( Stat :: IsConstant ) ,
805812 Some ( Precision :: exact( true ) )
@@ -809,11 +816,93 @@ mod test {
809816 {
810817 let mut stats = StatsSet :: of ( Stat :: IsConstant , Precision :: exact ( false ) ) ;
811818 let stats2 = StatsSet :: of ( Stat :: IsConstant , Precision :: inexact ( false ) ) ;
812- stats. combine_is_constant ( & stats2) . unwrap ( ) ;
819+ stats. combine_bool_stat :: < IsConstant > ( & stats2) . unwrap ( ) ;
813820 assert_eq ! (
814821 stats. get_as:: <bool >( Stat :: IsConstant ) ,
815822 Some ( Precision :: exact( false ) )
816823 ) ;
817824 }
818825 }
826+
#[test]
fn test_combine_sets_boolean_conflict() {
    // The two sets agree on IsSorted but hold contradictory exact values for
    // IsConstant, so combining them is expected to be rejected.
    let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);

    let mut lhs = StatsSet::from_iter([
        (Stat::IsConstant, Precision::exact(true)),
        (Stat::IsSorted, Precision::exact(true)),
    ]);
    let rhs = StatsSet::from_iter([
        (Stat::IsConstant, Precision::exact(false)),
        (Stat::IsSorted, Precision::exact(true)),
    ]);

    let outcome = lhs.combine_sets(&rhs, &dtype);
    assert!(outcome.is_err());
}
845+
#[test]
fn test_combine_sets_with_missing_stats() {
    // The two sets carry disjoint stats: nothing overlaps, so the merge should
    // keep the left-hand entries and adopt the right-hand ones.
    let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);

    let mut lhs = StatsSet::from_iter([
        (Stat::Min, Precision::exact(42)),
        (Stat::UncompressedSizeInBytes, Precision::exact(1000)),
    ]);
    let rhs = StatsSet::from_iter([
        (Stat::Max, Precision::exact(100)),
        (Stat::IsStrictSorted, Precision::exact(true)),
    ]);

    lhs.combine_sets(&rhs, &dtype).unwrap();

    // Min was only on the left and must survive untouched.
    assert_eq!(lhs.get_as::<i32>(Stat::Min), Some(Precision::exact(42)));
    // Max was only on the right and must have been copied over.
    assert_eq!(lhs.get_as::<i32>(Stat::Max), Some(Precision::exact(100)));
    // IsStrictSorted was only on the right and must have been copied over.
    assert_eq!(
        lhs.get_as::<bool>(Stat::IsStrictSorted),
        Some(Precision::exact(true))
    );
}
875+
#[test]
fn test_combine_sets_with_inexact() {
    let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);

    let mut lhs = StatsSet::from_iter([
        (Stat::Min, Precision::exact(42)),
        (Stat::Max, Precision::inexact(100)),
        (Stat::IsConstant, Precision::exact(false)),
    ]);
    let rhs = StatsSet::from_iter([
        // The inexact Min on this side must be <= the exact Min on the other side.
        (Stat::Min, Precision::inexact(40)),
        (Stat::Max, Precision::exact(90)),
        (Stat::IsSorted, Precision::exact(true)),
    ]);

    lhs.combine_sets(&rhs, &dtype).unwrap();

    // The exact Min is tighter than the incoming inexact one, so it stays.
    assert_eq!(lhs.get_as::<i32>(Stat::Min), Some(Precision::exact(42)));
    // The incoming exact Max replaces the looser inexact one.
    assert_eq!(lhs.get_as::<i32>(Stat::Max), Some(Precision::exact(90)));
    // IsSorted existed only on the incoming side and is adopted.
    assert_eq!(
        lhs.get_as::<bool>(Stat::IsSorted),
        Some(Precision::exact(true))
    );
}
819908}
0 commit comments