Skip to content

Commit ea4a432

Browse files
authored
Fix take with nullable indices on arrays with patches (#2336)
If the array is nonnullable after a take it might end up being nullable, however, same take might not affect the values of patches thus resulting in nonullable patches. Need to ensure that after take the patches values have same nullability as the target array
1 parent 9d1e36f commit ea4a432

File tree

5 files changed

+111
-29
lines changed

5 files changed

+111
-29
lines changed

encodings/alp/src/alp/compute/mod.rs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ use vortex_scalar::Scalar;
1313
use crate::{match_each_alp_float_ptype, ALPArray, ALPEncoding, ALPFloat};
1414

1515
impl ComputeVTable for ALPEncoding {
16+
fn compare_fn(&self) -> Option<&dyn CompareFn<Array>> {
17+
Some(self)
18+
}
19+
1620
fn filter_fn(&self) -> Option<&dyn FilterFn<Array>> {
1721
Some(self)
1822
}
@@ -28,10 +32,6 @@ impl ComputeVTable for ALPEncoding {
2832
fn take_fn(&self) -> Option<&dyn TakeFn<Array>> {
2933
Some(self)
3034
}
31-
32-
fn compare_fn(&self) -> Option<&dyn CompareFn<Array>> {
33-
Some(self)
34-
}
3535
}
3636

3737
impl ScalarAtFn<ALPArray> for ALPEncoding {
@@ -60,16 +60,21 @@ impl ScalarAtFn<ALPArray> for ALPEncoding {
6060

6161
impl TakeFn<ALPArray> for ALPEncoding {
6262
fn take(&self, array: &ALPArray, indices: &Array) -> VortexResult<Array> {
63-
Ok(ALPArray::try_new(
64-
take(array.encoded(), indices)?,
65-
array.exponents(),
66-
array
67-
.patches()
68-
.map(|p| p.take(indices))
69-
.transpose()?
70-
.flatten(),
71-
)?
72-
.into_array())
63+
let taken_encoded = take(array.encoded(), indices)?;
64+
let taken_patches = array
65+
.patches()
66+
.map(|p| p.take(indices))
67+
.transpose()?
68+
.flatten()
69+
.map(|p| {
70+
p.cast_values(
71+
&array
72+
.dtype()
73+
.with_nullability(taken_encoded.dtype().nullability()),
74+
)
75+
})
76+
.transpose()?;
77+
Ok(ALPArray::try_new(taken_encoded, array.exponents(), taken_patches)?.into_array())
7378
}
7479
}
7580

encodings/alp/src/alp_rd/array.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,6 @@ impl ALPRDArray {
7171
PType::try_from(left_parts.dtype()).vortex_expect("left_parts dtype must be uint");
7272

7373
// we enforce right_parts to be non-nullable uint
74-
if right_parts.dtype().is_nullable() {
75-
vortex_bail!("right_parts dtype must be non-nullable");
76-
}
7774
if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
7875
vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
7976
}
@@ -82,11 +79,11 @@ impl ALPRDArray {
8279

8380
let patches = left_parts_patches
8481
.map(|patches| {
85-
if patches.values().dtype().is_nullable() {
86-
vortex_bail!("patches must be non-nullable: {}", patches.values());
82+
if !patches.values().all_valid()? {
83+
vortex_bail!("patches must be all valid: {}", patches.values());
8784
}
88-
let metadata =
89-
patches.to_metadata(left_parts.len(), &left_parts.dtype().as_nonnullable());
85+
let patches = patches.cast_values(left_parts.dtype())?;
86+
let metadata = patches.to_metadata(left_parts.len(), left_parts.dtype());
9087
let (_, _, indices, values) = patches.into_parts();
9188
children.push(indices);
9289
children.push(values);
@@ -146,7 +143,7 @@ impl ALPRDArray {
146143
/// The dtype of the patches of the left parts of the array.
147144
#[inline]
148145
fn left_parts_patches_dtype(&self) -> DType {
149-
DType::Primitive(self.metadata().left_parts_ptype, Nullability::NonNullable)
146+
DType::Primitive(self.metadata().left_parts_ptype, self.dtype().nullability())
150147
}
151148

152149
/// The leftmost (most significant) bits of the floating point values stored in the array.

encodings/alp/src/alp_rd/compute/take.rs

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,31 @@
1-
use vortex_array::compute::{take, TakeFn};
1+
use vortex_array::compute::{fill_null, take, TakeFn};
22
use vortex_array::{Array, IntoArray};
33
use vortex_error::VortexResult;
4+
use vortex_scalar::{Scalar, ScalarValue};
45

56
use crate::{ALPRDArray, ALPRDEncoding};
67

78
impl TakeFn<ALPRDArray> for ALPRDEncoding {
89
fn take(&self, array: &ALPRDArray, indices: &Array) -> VortexResult<Array> {
10+
let taken_left_parts = take(array.left_parts(), indices)?;
911
let left_parts_exceptions = array
1012
.left_parts_patches()
1113
.map(|patches| patches.take(indices))
1214
.transpose()?
13-
.flatten();
15+
.flatten()
16+
.map(|p| {
17+
let values_dtype = p
18+
.values()
19+
.dtype()
20+
.with_nullability(taken_left_parts.dtype().nullability());
21+
p.cast_values(&values_dtype)
22+
})
23+
.transpose()?;
1424

15-
let taken_left_parts = take(array.left_parts(), indices)?;
25+
let right_parts = fill_null(
26+
take(array.right_parts(), indices)?,
27+
Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
28+
)?;
1629
Ok(ALPRDArray::try_new(
1730
if taken_left_parts.dtype().is_nullable() {
1831
array.dtype().as_nullable()
@@ -21,7 +34,7 @@ impl TakeFn<ALPRDArray> for ALPRDEncoding {
2134
},
2235
taken_left_parts,
2336
array.left_parts_dict(),
24-
take(array.right_parts(), indices)?,
37+
right_parts,
2538
array.right_bit_width(),
2639
left_parts_exceptions,
2740
)?
@@ -59,4 +72,31 @@ mod test {
5972

6073
assert_eq!(taken.as_slice::<T>(), &[a, outlier]);
6174
}
75+
76+
#[rstest]
77+
#[case(0.1f32, 0.2f32, 3e25f32)]
78+
#[case(0.1f64, 0.2f64, 3e100f64)]
79+
fn take_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
80+
let array = PrimitiveArray::from_iter([a, b, outlier]);
81+
let encoded = RDEncoder::new(&[a, b]).encode(&array);
82+
83+
assert!(encoded.left_parts_patches().is_some());
84+
assert!(encoded
85+
.left_parts_patches()
86+
.unwrap()
87+
.dtype()
88+
.is_unsigned_int());
89+
90+
let taken = take(
91+
encoded.as_ref(),
92+
PrimitiveArray::from_option_iter([Some(0), Some(2), None]).as_ref(),
93+
)
94+
.unwrap()
95+
.into_primitive()
96+
.unwrap();
97+
98+
assert_eq!(taken.as_slice::<T>()[0], a);
99+
assert_eq!(taken.as_slice::<T>()[1], outlier);
100+
assert!(!taken.validity_mask().unwrap().value(2));
101+
}
62102
}

encodings/fastlanes/src/bitpacking/compute/take.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ fn take_primitive<T: NativePType + BitPacking, I: NativePType>(
113113
}
114114
if let Some(patches) = array.patches() {
115115
if let Some(patches) = patches.take(indices)? {
116-
return unpatched_taken.patch(patches);
116+
let cast_patches = patches.cast_values(unpatched_taken.dtype())?;
117+
return unpatched_taken.patch(cast_patches);
117118
}
118119
}
119120

@@ -131,6 +132,7 @@ mod test {
131132
use vortex_array::{IntoArray, IntoArrayVariant};
132133
use vortex_buffer::{buffer, Buffer};
133134

135+
use crate::bitpacking::compute::take::take_primitive;
134136
use crate::BitPackedArray;
135137

136138
#[test]
@@ -217,12 +219,28 @@ mod test {
217219
let start =
218220
BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();
219221

220-
let taken_primitive = super::take_primitive::<u32, u64>(
222+
let taken_primitive = take_primitive::<u32, u64>(
221223
&start,
222224
&PrimitiveArray::from_iter([0u64, 1, 2, 3]),
223225
Validity::NonNullable,
224226
)
225227
.unwrap();
226228
assert_eq!(taken_primitive.as_slice::<i32>(), &[1i32, 2, 3, 4]);
227229
}
230+
231+
#[test]
232+
fn take_nullable_with_nullables() {
233+
let start =
234+
BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();
235+
236+
let taken_primitive = take(
237+
&start,
238+
PrimitiveArray::from_option_iter([Some(0u64), Some(1), None, Some(3)]),
239+
)
240+
.unwrap()
241+
.into_primitive()
242+
.unwrap();
243+
assert_eq!(taken_primitive.as_slice::<i32>(), &[1i32, 2, 1, 4]);
244+
assert_eq!(taken_primitive.invalid_count().unwrap(), 1);
245+
}
228246
}

vortex-array/src/patches.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use crate::aliases::hash_map::HashMap;
1616
use crate::array::PrimitiveArray;
1717
use crate::compute::{
1818
filter, scalar_at, search_sorted, search_sorted_usize, search_sorted_usize_many, slice, take,
19-
SearchResult, SearchSortedSide,
19+
try_cast, SearchResult, SearchSortedSide,
2020
};
2121
use crate::variants::PrimitiveArrayTrait;
2222
use crate::{Array, IntoArray, IntoArrayVariant};
@@ -110,6 +110,19 @@ impl Patches {
110110
offset,
111111
array_len
112112
);
113+
Self::new_unchecked(array_len, offset, indices, values)
114+
}
115+
116+
/// Construct new patches without validating any of the arguments
117+
///
118+
/// # Safety
119+
///
120+
/// Users have to assert that
121+
/// * Indices and values have the same length
122+
/// * Indices is an unsigned integer type
123+
/// * Indices must be sorted
124+
/// * Last value in indices is smaller than array_len
125+
pub fn new_unchecked(array_len: usize, offset: usize, indices: Array, values: Array) -> Self {
113126
Self {
114127
array_len,
115128
offset,
@@ -180,6 +193,15 @@ impl Patches {
180193
})
181194
}
182195

196+
pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
197+
Ok(Self::new_unchecked(
198+
self.array_len,
199+
self.offset,
200+
self.indices,
201+
try_cast(self.values, values_dtype)?,
202+
))
203+
}
204+
183205
/// Get the patched value at a given index if it exists.
184206
pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
185207
if let Some(patch_idx) = self.search_index(index)?.to_found() {

0 commit comments

Comments
 (0)