Skip to content

Commit 02db2a1

Browse files
committed
chore[dtype]: use NonZero fro decimal precision
Signed-off-by: Joe Isaacs <[email protected]>
1 parent f046c8a commit 02db2a1

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

vortex-dtype/src/decimal/mod.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod precision;
66
mod types;
77

88
use std::fmt::{Display, Formatter};
9+
use std::num::NonZero;
910

1011
use num_traits::ToPrimitive;
1112
pub use precision::*;
@@ -23,7 +24,7 @@ const MAX_SCALE: i8 = <i256 as NativeDecimalType>::MAX_SCALE;
2324
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
2425
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
2526
pub struct DecimalDType {
26-
precision: u8,
27+
precision: NonZero<u8>,
2728
scale: i8,
2829
}
2930

@@ -34,13 +35,14 @@ impl DecimalDType {
3435
///
3536
/// Returns an error if precision exceeds MAX_PRECISION or scale is outside [MIN_SCALE, MAX_SCALE].
3637
pub fn try_new(precision: u8, scale: i8) -> VortexResult<Self> {
37-
if precision == 0 {
38-
vortex_bail!(
38+
let precision = NonZero::new(precision).ok_or_else(|| {
39+
vortex_error::vortex_err!(
3940
"decimal precision must be between 1 and {} (inclusive)",
4041
MAX_PRECISION
41-
);
42-
}
43-
if precision > MAX_PRECISION {
42+
)
43+
})?;
44+
45+
if precision.get() > MAX_PRECISION {
4446
vortex_bail!(
4547
"decimal precision {} exceeds MAX_PRECISION {}",
4648
precision,
@@ -52,7 +54,7 @@ impl DecimalDType {
5254
vortex_bail!("decimal scale {} exceeds MAX_SCALE {}", scale, MAX_SCALE);
5355
}
5456

55-
if scale > 0 && scale as u8 > precision {
57+
if scale > 0 && scale as u8 > precision.get() {
5658
vortex_bail!(
5759
"decimal scale {} is greater than precision {}",
5860
scale,
@@ -76,7 +78,7 @@ impl DecimalDType {
7678

7779
/// The precision is the number of significant figures that the decimal tracks.
7880
pub fn precision(&self) -> u8 {
79-
self.precision
81+
self.precision.get()
8082
}
8183

8284
/// The scale is the maximum number of digits relative to the decimal point.
@@ -89,7 +91,7 @@ impl DecimalDType {
8991

9092
/// Return the max number of bits required to fit a decimal with `precision` in.
9193
pub fn required_bit_width(&self) -> usize {
92-
(self.precision as f32 * 10.0f32.log(2.0))
94+
(self.precision.get() as f32 * 10.0f32.log(2.0))
9395
.ceil()
9496
.to_usize()
9597
.vortex_expect("too many bits required")

vortex-dtype/src/decimal/precision.rs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use std::any::type_name;
55
use std::fmt::Display;
66
use std::marker::PhantomData;
7+
use std::num::NonZero;
78

89
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
910

@@ -13,7 +14,7 @@ use crate::{DecimalDType, NativeDecimalType};
1314
/// by the native type `D`.
1415
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1516
pub struct PrecisionScale<D> {
16-
precision: u8,
17+
precision: NonZero<u8>,
1718
scale: i8,
1819
phantom: PhantomData<D>,
1920
}
@@ -42,13 +43,14 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
4243

4344
/// Try to create a new [`PrecisionScale`] with the given precision and scale.
4445
pub fn try_new(precision: u8, scale: i8) -> VortexResult<Self> {
45-
if precision == 0 {
46-
vortex_bail!(
46+
let precision = NonZero::new(precision).ok_or_else(|| {
47+
vortex_error::vortex_err!(
4748
"precision cannot be 0, has to be between [1, {}]",
4849
D::MAX_PRECISION
49-
);
50-
}
51-
if precision > D::MAX_PRECISION {
50+
)
51+
})?;
52+
53+
if precision.get() > D::MAX_PRECISION {
5254
vortex_bail!(
5355
"Precision {} is greater than max {}",
5456
precision,
@@ -58,7 +60,7 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
5860
if scale > D::MAX_SCALE {
5961
vortex_bail!("Scale {} is greater than max {}", scale, D::MAX_SCALE);
6062
}
61-
if scale > 0 && scale as u8 > precision {
63+
if scale > 0 && scale as u8 > precision.get() {
6264
vortex_bail!("Scale {} is greater than precision {}", scale, precision);
6365
}
6466
Ok(Self {
@@ -78,7 +80,8 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
7880
Self::new(precision, scale)
7981
} else {
8082
Self {
81-
precision,
83+
// SAFETY: Caller guarantees precision is non-zero
84+
precision: unsafe { NonZero::new_unchecked(precision) },
8285
scale,
8386
phantom: Default::default(),
8487
}
@@ -88,7 +91,7 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
8891
/// The precision is the number of significant figures that the decimal tracks.
8992
#[inline(always)]
9093
pub fn precision(&self) -> u8 {
91-
self.precision
94+
self.precision.get()
9295
}
9396

9497
/// The scale is the maximum number of digits relative to the decimal point.
@@ -100,14 +103,15 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
100103
/// Validate whether a given value of type `D` fits within the precision and scale.
101104
#[inline]
102105
pub fn is_valid(&self, value: D) -> bool {
103-
self.precision <= D::MAX_PRECISION
104-
&& value >= D::MIN_BY_PRECISION[self.precision as usize]
105-
&& value <= D::MAX_BY_PRECISION[self.precision as usize]
106+
self.precision.get() <= D::MAX_PRECISION
107+
&& value >= D::MIN_BY_PRECISION[self.precision.get() as usize]
108+
&& value <= D::MAX_BY_PRECISION[self.precision.get() as usize]
106109
}
107110
}
108111

109112
impl<D: NativeDecimalType> From<PrecisionScale<D>> for DecimalDType {
110113
fn from(value: PrecisionScale<D>) -> Self {
114+
// SAFETY: precision is already NonZero<u8>, so we can use it directly
111115
DecimalDType {
112116
precision: value.precision,
113117
scale: value.scale,
@@ -119,6 +123,6 @@ impl<D: NativeDecimalType> TryFrom<&DecimalDType> for PrecisionScale<D> {
119123
type Error = vortex_error::VortexError;
120124

121125
fn try_from(value: &DecimalDType) -> VortexResult<Self> {
122-
PrecisionScale::try_new(value.precision, value.scale)
126+
PrecisionScale::try_new(value.precision(), value.scale)
123127
}
124128
}

0 commit comments

Comments
 (0)