44use std:: any:: type_name;
55use std:: fmt:: Display ;
66use std:: marker:: PhantomData ;
7+ use std:: num:: NonZero ;
78
8- use vortex_error:: { VortexExpect , VortexResult , vortex_bail} ;
9+ use vortex_error:: { VortexExpect , VortexResult , vortex_bail, vortex_err } ;
910
1011use crate :: { DecimalDType , NativeDecimalType } ;
1112
1213/// A struct representing the precision and scale of a decimal type, to be represented
1314/// by the native type `D`.
1415#[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
1516pub struct PrecisionScale < D > {
16- precision : u8 ,
17+ precision : NonZero < u8 > ,
1718 scale : i8 ,
1819 phantom : PhantomData < D > ,
1920}
@@ -42,13 +43,14 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
4243
4344 /// Try to create a new [`PrecisionScale`] with the given precision and scale.
4445 pub fn try_new ( precision : u8 , scale : i8 ) -> VortexResult < Self > {
45- if precision == 0 {
46- vortex_bail ! (
46+ let precision = NonZero :: new ( precision ) . ok_or_else ( || {
47+ vortex_err ! (
4748 "precision cannot be 0, has to be between [1, {}]" ,
4849 D :: MAX_PRECISION
49- ) ;
50- }
51- if precision > D :: MAX_PRECISION {
50+ )
51+ } ) ?;
52+
53+ if precision. get ( ) > D :: MAX_PRECISION {
5254 vortex_bail ! (
5355 "Precision {} is greater than max {}" ,
5456 precision,
@@ -58,7 +60,7 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
5860 if scale > D :: MAX_SCALE {
5961 vortex_bail ! ( "Scale {} is greater than max {}" , scale, D :: MAX_SCALE ) ;
6062 }
61- if scale > 0 && scale as u8 > precision {
63+ if scale > 0 && scale as u8 > precision. get ( ) {
6264 vortex_bail ! ( "Scale {} is greater than precision {}" , scale, precision) ;
6365 }
6466 Ok ( Self {
@@ -78,7 +80,8 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
7880 Self :: new ( precision, scale)
7981 } else {
8082 Self {
81- precision,
83+ // SAFETY: Caller guarantees precision is non-zero
84+ precision : unsafe { NonZero :: new_unchecked ( precision) } ,
8285 scale,
8386 phantom : Default :: default ( ) ,
8487 }
@@ -88,7 +91,7 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
8891 /// The precision is the number of significant figures that the decimal tracks.
8992 #[ inline( always) ]
9093 pub fn precision ( & self ) -> u8 {
91- self . precision
94+ self . precision . get ( )
9295 }
9396
9497 /// The scale is the maximum number of digits relative to the decimal point.
@@ -100,14 +103,15 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
100103 /// Validate whether a given value of type `D` fits within the precision and scale.
101104 #[ inline]
102105 pub fn is_valid ( & self , value : D ) -> bool {
103- self . precision <= D :: MAX_PRECISION
104- && value >= D :: MIN_BY_PRECISION [ self . precision as usize ]
105- && value <= D :: MAX_BY_PRECISION [ self . precision as usize ]
106+ self . precision . get ( ) <= D :: MAX_PRECISION
107+ && value >= D :: MIN_BY_PRECISION [ self . precision . get ( ) as usize ]
108+ && value <= D :: MAX_BY_PRECISION [ self . precision . get ( ) as usize ]
106109 }
107110}
108111
109112impl < D : NativeDecimalType > From < PrecisionScale < D > > for DecimalDType {
110113 fn from ( value : PrecisionScale < D > ) -> Self {
114+ // SAFETY: precision is already NonZero<u8>, so we can use it directly
111115 DecimalDType {
112116 precision : value. precision ,
113117 scale : value. scale ,
@@ -119,7 +123,7 @@ impl<D: NativeDecimalType> TryFrom<&DecimalDType> for PrecisionScale<D> {
119123 type Error = vortex_error:: VortexError ;
120124
121125 fn try_from ( value : & DecimalDType ) -> VortexResult < Self > {
122- PrecisionScale :: try_new ( value. precision , value. scale )
126+ PrecisionScale :: try_new ( value. precision ( ) , value. scale )
123127 }
124128}
125129
0 commit comments