Skip to content

Commit 8a8b7c1

Browse files
committed
leaky decimal fixes not implemented
Signed-off-by: Connor Tsui <[email protected]>
1 parent c7b0ef9 commit 8a8b7c1

File tree

4 files changed

+172
-33
lines changed

4 files changed

+172
-33
lines changed

vortex-array/src/arrays/decimal/array.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ impl DecimalArray {
120120
decimal_dtype: DecimalDType,
121121
validity: Validity,
122122
) -> VortexResult<Self> {
123-
Self::validate(&buffer, &validity)?;
123+
Self::validate(&buffer, decimal_dtype, &validity)?;
124124

125125
// SAFETY: validate ensures all invariants are met.
126126
Ok(unsafe { Self::new_unchecked(buffer, decimal_dtype, validity) })
@@ -136,16 +136,18 @@ impl DecimalArray {
136136
///
137137
/// The caller must ensure all of the following invariants are satisfied:
138138
///
139+
/// - The storage type `T` must be compatible with the precision (i.e., able to represent all
140+
/// values of the declared precision).
139141
/// - All non-null values in `buffer` must be representable within the specified precision.
140-
/// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
142+
/// For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
141143
/// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
142144
pub unsafe fn new_unchecked<T: NativeDecimalType>(
143145
buffer: Buffer<T>,
144146
decimal_dtype: DecimalDType,
145147
validity: Validity,
146148
) -> Self {
147149
#[cfg(debug_assertions)]
148-
Self::validate(&buffer, &validity)
150+
Self::validate(&buffer, decimal_dtype, &validity)
149151
.vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
150152

151153
Self {
@@ -162,8 +164,18 @@ impl DecimalArray {
162164
/// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
163165
pub fn validate<T: NativeDecimalType>(
164166
buffer: &Buffer<T>,
167+
// TODO(connor): The decimal array storage type should be able to represent the entire
168+
// domain of the decimal type.
169+
_decimal_dtype: DecimalDType,
165170
validity: &Validity,
166171
) -> VortexResult<()> {
172+
// vortex_ensure!(
173+
// T::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype),
174+
// "Storage type {:?} cannot represent all values of precision {}",
175+
// T::DECIMAL_TYPE,
176+
// decimal_dtype.precision()
177+
// );
178+
167179
if let Some(len) = validity.maybe_len() {
168180
vortex_ensure!(
169181
buffer.len() == len,

vortex-array/src/arrays/decimal/compute/fill_null.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ impl FillNullKernel for DecimalVTable {
2828
let is_invalid = is_valid.to_bool().bit_buffer().not();
2929
match_each_decimal_value_type!(array.values_type(), |T| {
3030
let mut buffer = array.buffer::<T>().into_mut();
31-
let fill_value = fill_value
32-
.as_decimal()
31+
let decimal_scalar = fill_value.as_decimal();
32+
let decimal_value = decimal_scalar
3333
.decimal_value()
34-
.and_then(|v| v.cast::<T>())
35-
.vortex_expect("top-level fill_null ensure non-null fill value");
34+
.vortex_expect("fill_null requires a non-null fill value");
35+
let fill_value = decimal_value
36+
.cast::<T>()
37+
.vortex_expect("fill value does not fit in array's decimal storage type");
3638
for invalid_index in is_invalid.set_indices() {
3739
buffer[invalid_index] = fill_value;
3840
}

vortex-array/src/compute/fill_null.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use std::sync::LazyLock;
55

66
use arcref::ArcRef;
77
use vortex_dtype::DType;
8+
use vortex_dtype::DecimalType;
9+
use vortex_dtype::match_each_decimal_value_type;
810
use vortex_error::VortexError;
11+
use vortex_error::VortexExpect;
912
use vortex_error::VortexResult;
1013
use vortex_error::vortex_bail;
1114
use vortex_error::vortex_err;
@@ -15,6 +18,7 @@ use crate::Array;
1518
use crate::ArrayRef;
1619
use crate::IntoArray;
1720
use crate::arrays::ConstantArray;
21+
use crate::arrays::DecimalVTable;
1822
use crate::compute::ComputeFn;
1923
use crate::compute::ComputeFnVTable;
2024
use crate::compute::InvocationArgs;
@@ -59,6 +63,14 @@ pub fn fill_null(array: &dyn Array, fill_value: &Scalar) -> VortexResult<ArrayRe
5963
}
6064

6165
pub trait FillNullKernel: VTable {
66+
/// Kernel for replacing null values in an array with a fill value.
67+
///
68+
/// TODO(connor): Actually enforce these constraints (so that casts do not fail).
69+
///
70+
/// Implementations can assume that:
71+
/// - The array has at least one null value (not all valid, not all invalid).
72+
/// - The fill value is non-null.
73+
/// - For decimal arrays, the fill value can be successfully cast to the array's storage type.
6274
fn fill_null(&self, array: &Self::Array, fill_value: &Scalar) -> VortexResult<ArrayRef>;
6375
}
6476

@@ -110,6 +122,34 @@ impl ComputeFnVTable for FillNull {
110122
vortex_bail!("Cannot fill_null with a null value")
111123
}
112124

125+
/*
126+
// For decimal arrays, validate that the fill value fits in the storage type.
127+
if let Some(decimal_dtype) = array.dtype().as_decimal_opt() {
128+
// Try to get the actual storage type from a DecimalArray. Otherwise, use the smallest
129+
// type that can represent the precision.
130+
let storage_type = array
131+
.as_opt::<DecimalVTable>()
132+
.map(|arr| arr.values_type())
133+
.unwrap_or_else(|| DecimalType::smallest_decimal_value_type(decimal_dtype));
134+
let decimal_value = fill_value
135+
.as_decimal()
136+
.decimal_value()
137+
.vortex_expect("fill_null checked is_null above");
138+
139+
let fits = match_each_decimal_value_type!(storage_type, |T| {
140+
decimal_value.cast::<T>().is_some()
141+
});
142+
143+
if !fits {
144+
vortex_bail!(
145+
"fill value {} does not fit in array's decimal storage type {:?}",
146+
decimal_value,
147+
storage_type
148+
)
149+
}
150+
}
151+
*/
152+
113153
for kernel in kernels {
114154
if let Some(output) = kernel.invoke(args)? {
115155
return Ok(output);

vortex-array/src/validity.rs

Lines changed: 111 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -272,37 +272,47 @@ impl Validity {
272272
indices: &dyn Array,
273273
patches: &Validity,
274274
) -> Self {
275+
use Validity::*;
276+
275277
match (&self, patches) {
276-
(Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
277-
(Validity::NonNullable, _) => {
278-
vortex_panic!("Can't patch a non-nullable validity with nullable validity")
278+
(NonNullable, NonNullable | AllValid) => {
279+
return NonNullable;
280+
}
281+
(NonNullable, Array(_) | AllInvalid) => {
282+
vortex_panic!("Can't patch a non-nullable validity with null values")
279283
}
280-
(_, Validity::NonNullable) => {
281-
vortex_panic!("Can't patch a nullable validity with non-nullable validity")
284+
285+
(AllValid | Array(_) | AllInvalid, NonNullable) => {
286+
vortex_panic!("Can't patch a nullable validity with a non-nullable validity")
282287
}
283-
(Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
284-
(Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
285-
_ => {}
288+
289+
(AllValid, AllValid) => return AllValid,
290+
(AllValid, Array(_) | AllInvalid) => {}
291+
292+
(AllInvalid, AllInvalid) => return AllInvalid,
293+
(AllInvalid, AllValid | Array(_)) => {}
294+
295+
(Array(_), _) => {}
286296
};
287297

288-
let own_nullability = if self == Validity::NonNullable {
298+
let own_nullability = if self == NonNullable {
289299
Nullability::NonNullable
290300
} else {
291301
Nullability::Nullable
292302
};
293303

294304
let source = match self {
295-
Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
296-
Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
297-
Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
298-
Validity::Array(a) => a.to_bool(),
305+
NonNullable => BoolArray::from(BitBuffer::new_set(len)),
306+
AllValid => BoolArray::from(BitBuffer::new_set(len)),
307+
AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
308+
Array(a) => a.to_bool(),
299309
};
300310

301311
let patch_values = match patches {
302-
Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
303-
Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
304-
Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
305-
Validity::Array(a) => a.to_bool(),
312+
NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
313+
AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
314+
AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
315+
Array(a) => a.to_bool(),
306316
};
307317

308318
let patches = Patches::new(
@@ -513,21 +523,96 @@ mod tests {
513523
use crate::validity::Validity;
514524

515525
#[rstest]
516-
#[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
517-
#[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
526+
#[case(
527+
Validity::AllValid,
528+
5,
529+
&[2, 4],
530+
Validity::AllValid,
531+
Validity::AllValid
532+
)]
533+
#[case(
534+
Validity::AllValid,
535+
5,
536+
&[2, 4],
537+
Validity::AllInvalid,
538+
Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
539+
)]
540+
#[case(
541+
Validity::AllValid,
542+
5,
543+
&[2, 4],
544+
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
545+
Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
546+
)]
547+
#[case(
548+
Validity::AllInvalid,
549+
5,
550+
&[2, 4],
551+
Validity::AllValid,
552+
Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
553+
)]
554+
#[case(
555+
Validity::AllInvalid,
556+
5,
557+
&[2, 4],
558+
Validity::AllInvalid,
559+
Validity::AllInvalid
560+
)]
561+
#[case(
562+
Validity::AllInvalid,
563+
5,
564+
&[2, 4],
565+
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
566+
Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
567+
)]
568+
#[case(
569+
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
570+
5,
571+
&[2, 4],
572+
Validity::AllValid,
573+
Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
518574
)]
519-
#[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
575+
#[case(
576+
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
577+
5,
578+
&[2, 4],
579+
Validity::AllInvalid,
580+
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
520581
)]
521-
#[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
582+
#[case(
583+
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
584+
5,
585+
&[2, 4],
586+
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
587+
Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
522588
)]
523-
#[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
524-
#[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
589+
#[case(
590+
Validity::NonNullable,
591+
5,
592+
&[2, 4],
593+
Validity::AllValid,
594+
Validity::NonNullable
525595
)]
526-
#[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
596+
#[case(
597+
Validity::AllValid,
598+
5,
599+
&[2, 4],
600+
Validity::NonNullable,
601+
Validity::AllValid
527602
)]
528-
#[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
603+
#[case(
604+
Validity::AllInvalid,
605+
5,
606+
&[2, 4],
607+
Validity::NonNullable,
608+
Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
529609
)]
530-
#[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
610+
#[case(
611+
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
612+
5,
613+
&[2, 4],
614+
Validity::NonNullable,
615+
Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
531616
)]
532617
fn patch_validity(
533618
#[case] validity: Validity,

0 commit comments

Comments
 (0)