Skip to content

Commit 3aa9818

Browse files
authored
feat: teach PrimitiveArray fill_null (#2006)
1 parent 5556cb9 commit 3aa9818

File tree

3 files changed

+121
-8
lines changed

3 files changed

+121
-8
lines changed

vortex-array/src/array/primitive/compute/fill.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use vortex_buffer::Buffer;
22
use vortex_dtype::{match_each_native_ptype, Nullability};
33
use vortex_error::{vortex_err, VortexResult};
4+
use vortex_scalar::Scalar;
45

56
use crate::array::primitive::PrimitiveArray;
6-
use crate::array::PrimitiveEncoding;
7+
use crate::array::{ConstantArray, PrimitiveEncoding};
78
use crate::compute::FillForwardFn;
89
use crate::validity::{ArrayValidity, Validity};
910
use crate::variants::PrimitiveArrayTrait;
@@ -27,10 +28,8 @@ impl FillForwardFn<PrimitiveArray> for PrimitiveEncoding {
2728

2829
if validity.all_invalid() {
2930
match_each_native_ptype!(array.ptype(), |$T| {
30-
return Ok(PrimitiveArray::new(
31-
Buffer::<$T>::zeroed(array.len()),
32-
Validity::AllValid
33-
).into_array());
31+
let fill_value = Scalar::from($T::default()).cast(array.dtype())?;
32+
return Ok(ConstantArray::new(fill_value, array.len()).into_array())
3433
})
3534
}
3635

@@ -90,12 +89,12 @@ mod test {
9089
#[test]
9190
fn nullable_non_null() {
9291
let arr = PrimitiveArray::new(
93-
buffer![8u8, 10u8, 12u8, 14u8, 16u8],
92+
buffer![8u8, 10, 12, 14, 16],
9493
Validity::Array(BoolArray::from_iter([true, true, true, true, true]).into_array()),
9594
)
9695
.into_array();
9796
let p = fill_forward(&arr).unwrap().into_primitive().unwrap();
98-
assert_eq!(p.as_slice::<u8>(), vec![8, 10, 12, 14, 16]);
97+
assert_eq!(p.as_slice::<u8>(), vec![8u8, 10, 12, 14, 16]);
9998
assert!(p.logical_validity().all_valid());
10099
}
101100
}

vortex-array/src/array/primitive/compute/fill_null.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
use std::ops::Not;
2+
3+
use vortex_buffer::BufferMut;
4+
use vortex_dtype::{match_each_native_ptype, Nullability};
5+
use vortex_error::{VortexExpect, VortexResult};
6+
use vortex_scalar::Scalar;
7+
8+
use crate::array::primitive::PrimitiveArray;
9+
use crate::array::{ConstantArray, PrimitiveEncoding};
10+
use crate::compute::FillNullFn;
11+
use crate::validity::Validity;
12+
use crate::variants::PrimitiveArrayTrait;
13+
use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant as _};
14+
15+
// `fill_null` kernel for primitive arrays: returns an array in which every
// null slot of `array` holds `fill_value`, while valid slots keep their value.
impl FillNullFn<PrimitiveArray> for PrimitiveEncoding {
16+
// Dispatches on the input's validity representation so that only the
// `Validity::Array` case has to copy and patch the values.
fn fill_null(&self, array: &PrimitiveArray, fill_value: Scalar) -> VortexResult<ArrayData> {
17+
// The validity of a rebuilt PrimitiveArray follows the nullability of the
// fill value's dtype, not the nullability of the input array.
let result_validity = match fill_value.dtype().nullability() {
18+
Nullability::NonNullable => Validity::NonNullable,
19+
Nullability::Nullable => Validity::AllValid,
20+
};
21+
22+
Ok(match array.validity() {
23+
// Nothing to fill: re-wrap a clone of the existing buffer with the result validity.
Validity::NonNullable | Validity::AllValid => {
24+
match_each_native_ptype!(array.ptype(), |$T| {
25+
PrimitiveArray::new::<$T>(array.buffer().clone(), result_validity).into_array()
26+
})
27+
}
28+
// Every slot is null, so the result is `fill_value` repeated `array.len()` times.
Validity::AllInvalid => ConstantArray::new(fill_value, array.len()).into_array(),
29+
// Mixed validity: copy the values and overwrite only the null positions.
Validity::Array(is_valid) => {
30+
// TODO(danking): when we take PrimitiveArray by value, we should mutate in-place
31+
// Negate the validity mask so that set bits mark the null positions.
let is_invalid = is_valid.into_bool()?.boolean_buffer().not();
32+
match_each_native_ptype!(array.ptype(), |$T| {
33+
let mut buffer = BufferMut::copy_from(array.as_slice::<$T>());
34+
// NOTE(review): the expect message implies the generic `fill_null` entry
// point rejects null fill values before dispatching here - confirm.
let fill_value = fill_value
35+
.as_primitive()
36+
.typed_value::<$T>()
37+
.vortex_expect("top-level fill_null ensure non-null fill value");
38+
for invalid_index in is_invalid.set_indices() {
39+
buffer[invalid_index] = fill_value;
40+
}
41+
PrimitiveArray::new(buffer, result_validity).into_array()
42+
})
43+
}
44+
})
45+
}
46+
}
47+
48+
// Unit tests for `fill_null` on primitive arrays, one per input validity shape.
#[cfg(test)]
49+
mod test {
50+
use vortex_buffer::buffer;
51+
use vortex_scalar::Scalar;
52+
53+
use crate::array::primitive::PrimitiveArray;
54+
use crate::array::BoolArray;
55+
use crate::compute::fill_null;
56+
use crate::validity::{ArrayValidity, Validity};
57+
use crate::{IntoArrayData, IntoArrayVariant};
58+
59+
// Nulls at indices 0, 2 and 4 are replaced by 42; the valid values 8 and 10 are kept.
#[test]
60+
fn fill_null_leading_none() {
61+
let arr =
62+
PrimitiveArray::from_option_iter([None, Some(8u8), None, Some(10), None]).into_array();
63+
let p = fill_null(&arr, Scalar::from(42u8))
64+
.unwrap()
65+
.into_primitive()
66+
.unwrap();
67+
assert_eq!(p.as_slice::<u8>(), vec![42, 8, 42, 10, 42]);
68+
assert!(p.logical_validity().all_valid());
69+
}
70+
71+
// An input made only of nulls yields the fill value (255) in every slot.
#[test]
72+
fn fill_null_all_none() {
73+
let arr = PrimitiveArray::from_option_iter([Option::<u8>::None, None, None, None, None])
74+
.into_array();
75+
76+
let p = fill_null(&arr, Scalar::from(255u8))
77+
.unwrap()
78+
.into_primitive()
79+
.unwrap();
80+
assert_eq!(p.as_slice::<u8>(), vec![255, 255, 255, 255, 255]);
81+
assert!(p.logical_validity().all_valid());
82+
}
83+
84+
// Validity is an explicit boolean array whose entries are all true: values are unchanged.
#[test]
85+
fn fill_null_nullable_non_null() {
86+
let arr = PrimitiveArray::new(
87+
buffer![8u8, 10, 12, 14, 16],
88+
Validity::Array(BoolArray::from_iter([true, true, true, true, true]).into_array()),
89+
)
90+
.into_array();
91+
let p = fill_null(&arr, Scalar::from(255u8))
92+
.unwrap()
93+
.into_primitive()
94+
.unwrap();
95+
assert_eq!(p.as_slice::<u8>(), vec![8, 10, 12, 14, 16]);
96+
assert!(p.logical_validity().all_valid());
97+
}
98+
99+
// The array is built straight from a buffer with no validity given: values are unchanged.
#[test]
100+
fn fill_null_non_nullable() {
101+
let arr = buffer![8u8, 10, 12, 14, 16].into_array();
102+
let p = fill_null(&arr, Scalar::from(255u8))
103+
.unwrap()
104+
.into_primitive()
105+
.unwrap();
106+
assert_eq!(p.as_slice::<u8>(), vec![8u8, 10, 12, 14, 16]);
107+
assert!(p.logical_validity().all_valid());
108+
}
109+
}

vortex-array/src/array/primitive/compute/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use crate::array::PrimitiveEncoding;
22
use crate::compute::{
3-
CastFn, ComputeVTable, FillForwardFn, FilterFn, ScalarAtFn, SearchSortedFn,
3+
CastFn, ComputeVTable, FillForwardFn, FillNullFn, FilterFn, ScalarAtFn, SearchSortedFn,
44
SearchSortedUsizeFn, SliceFn, TakeFn,
55
};
66
use crate::ArrayData;
77

88
mod cast;
99
mod fill;
10+
mod fill_null;
1011
mod filter;
1112
mod scalar_at;
1213
mod search_sorted;
@@ -22,6 +23,10 @@ impl ComputeVTable for PrimitiveEncoding {
2223
Some(self)
2324
}
2425

26+
// Returns this encoding itself as its `FillNullFn` implementation, so the
// compute vtable reports `fill_null` as supported for primitive arrays.
fn fill_null_fn(&self) -> Option<&dyn FillNullFn<ArrayData>> {
27+
Some(self)
28+
}
29+
2530
fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
2631
Some(self)
2732
}

0 commit comments

Comments
 (0)