Skip to content

Commit d47e47b

Browse files
chore[dtype]: use NonZero for decimal precision (#5146)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent a008ea6 commit d47e47b

File tree

2 files changed

+33
-24
lines changed

2 files changed

+33
-24
lines changed

vortex-dtype/src/decimal/mod.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ mod precision;
66
mod types;
77

88
use std::fmt::{Display, Formatter};
9+
use std::num::NonZero;
910

1011
use num_traits::ToPrimitive;
1112
pub use precision::*;
1213
pub use types::*;
13-
use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_panic};
14+
use vortex_error::{
15+
VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
16+
};
1417

1518
use crate::{DType, i256};
1619

@@ -20,10 +23,11 @@ const MAX_SCALE: i8 = <i256 as NativeDecimalType>::MAX_SCALE;
2023
/// Parameters that define the precision and scale of a decimal type.
2124
///
2225
/// Decimal types allow real numbers with a similar precision and scale to be represented exactly.
26+
/// Precision must be non-zero.
2327
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
2428
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
2529
pub struct DecimalDType {
26-
precision: u8,
30+
precision: NonZero<u8>,
2731
scale: i8,
2832
}
2933

@@ -34,13 +38,14 @@ impl DecimalDType {
3438
///
3539
/// Returns an error if precision exceeds MAX_PRECISION or scale is outside [MIN_SCALE, MAX_SCALE].
3640
pub fn try_new(precision: u8, scale: i8) -> VortexResult<Self> {
37-
if precision == 0 {
38-
vortex_bail!(
41+
let precision = NonZero::new(precision).ok_or_else(|| {
42+
vortex_err!(
3943
"decimal precision must be between 1 and {} (inclusive)",
4044
MAX_PRECISION
41-
);
42-
}
43-
if precision > MAX_PRECISION {
45+
)
46+
})?;
47+
48+
if precision.get() > MAX_PRECISION {
4449
vortex_bail!(
4550
"decimal precision {} exceeds MAX_PRECISION {}",
4651
precision,
@@ -52,7 +57,7 @@ impl DecimalDType {
5257
vortex_bail!("decimal scale {} exceeds MAX_SCALE {}", scale, MAX_SCALE);
5358
}
5459

55-
if scale > 0 && scale as u8 > precision {
60+
if scale > 0 && scale as u8 > precision.get() {
5661
vortex_bail!(
5762
"decimal scale {} is greater than precision {}",
5863
scale,
@@ -76,7 +81,7 @@ impl DecimalDType {
7681

7782
/// The precision is the number of significant figures that the decimal tracks.
7883
pub fn precision(&self) -> u8 {
79-
self.precision
84+
self.precision.get()
8085
}
8186

8287
/// The scale is the maximum number of digits relative to the decimal point.
@@ -89,7 +94,7 @@ impl DecimalDType {
8994

9095
/// Return the max number of bits required to fit a decimal with `precision` in.
9196
pub fn required_bit_width(&self) -> usize {
92-
(self.precision as f32 * 10.0f32.log(2.0))
97+
(self.precision.get() as f32 * 10.0f32.log(2.0))
9398
.ceil()
9499
.to_usize()
95100
.vortex_expect("too many bits required")

vortex-dtype/src/decimal/precision.rs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
use std::any::type_name;
55
use std::fmt::Display;
66
use std::marker::PhantomData;
7+
use std::num::NonZero;
78

8-
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9+
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
910

1011
use crate::{DecimalDType, NativeDecimalType};
1112

1213
/// A struct representing the precision and scale of a decimal type, to be represented
1314
/// by the native type `D`.
1415
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1516
pub struct PrecisionScale<D> {
16-
precision: u8,
17+
precision: NonZero<u8>,
1718
scale: i8,
1819
phantom: PhantomData<D>,
1920
}
@@ -42,13 +43,14 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
4243

4344
/// Try to create a new [`PrecisionScale`] with the given precision and scale.
4445
pub fn try_new(precision: u8, scale: i8) -> VortexResult<Self> {
45-
if precision == 0 {
46-
vortex_bail!(
46+
let precision = NonZero::new(precision).ok_or_else(|| {
47+
vortex_err!(
4748
"precision cannot be 0, has to be between [1, {}]",
4849
D::MAX_PRECISION
49-
);
50-
}
51-
if precision > D::MAX_PRECISION {
50+
)
51+
})?;
52+
53+
if precision.get() > D::MAX_PRECISION {
5254
vortex_bail!(
5355
"Precision {} is greater than max {}",
5456
precision,
@@ -58,7 +60,7 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
5860
if scale > D::MAX_SCALE {
5961
vortex_bail!("Scale {} is greater than max {}", scale, D::MAX_SCALE);
6062
}
61-
if scale > 0 && scale as u8 > precision {
63+
if scale > 0 && scale as u8 > precision.get() {
6264
vortex_bail!("Scale {} is greater than precision {}", scale, precision);
6365
}
6466
Ok(Self {
@@ -78,7 +80,8 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
7880
Self::new(precision, scale)
7981
} else {
8082
Self {
81-
precision,
83+
// SAFETY: Caller guarantees precision is non-zero
84+
precision: unsafe { NonZero::new_unchecked(precision) },
8285
scale,
8386
phantom: Default::default(),
8487
}
@@ -88,7 +91,7 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
8891
/// The precision is the number of significant figures that the decimal tracks.
8992
#[inline(always)]
9093
pub fn precision(&self) -> u8 {
91-
self.precision
94+
self.precision.get()
9295
}
9396

9497
/// The scale is the maximum number of digits relative to the decimal point.
@@ -100,14 +103,15 @@ impl<D: NativeDecimalType> PrecisionScale<D> {
100103
/// Validate whether a given value of type `D` fits within the precision and scale.
101104
#[inline]
102105
pub fn is_valid(&self, value: D) -> bool {
103-
self.precision <= D::MAX_PRECISION
104-
&& value >= D::MIN_BY_PRECISION[self.precision as usize]
105-
&& value <= D::MAX_BY_PRECISION[self.precision as usize]
106+
self.precision.get() <= D::MAX_PRECISION
107+
&& value >= D::MIN_BY_PRECISION[self.precision.get() as usize]
108+
&& value <= D::MAX_BY_PRECISION[self.precision.get() as usize]
106109
}
107110
}
108111

109112
impl<D: NativeDecimalType> From<PrecisionScale<D>> for DecimalDType {
110113
fn from(value: PrecisionScale<D>) -> Self {
114+
// SAFETY: precision is already NonZero<u8>, so we can use it directly
111115
DecimalDType {
112116
precision: value.precision,
113117
scale: value.scale,
@@ -119,7 +123,7 @@ impl<D: NativeDecimalType> TryFrom<&DecimalDType> for PrecisionScale<D> {
119123
type Error = vortex_error::VortexError;
120124

121125
fn try_from(value: &DecimalDType) -> VortexResult<Self> {
122-
PrecisionScale::try_new(value.precision, value.scale)
126+
PrecisionScale::try_new(value.precision(), value.scale)
123127
}
124128
}
125129

0 commit comments

Comments
 (0)