Skip to content

Commit d8d67f2

Browse files
authored
fix: teach SparseArray to account for null values when the fill is null (#3846)
1 parent 815b135 commit d8d67f2

File tree

3 files changed

+66
-36
lines changed

3 files changed

+66
-36
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encodings/sparse/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ workspace = true
1818

1919
[dependencies]
2020
itertools = { workspace = true }
21+
num-traits = { workspace = true }
2122
prost = { workspace = true }
2223
rstest_reuse = { workspace = true }
2324
vortex-array = { workspace = true }

encodings/sparse/src/lib.rs

Lines changed: 64 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -3,14 +3,16 @@
33

44
use std::fmt::Debug;
55

6+
use itertools::Itertools as _;
7+
use num_traits::NumCast;
68
use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
79
use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
810
use vortex_array::patches::Patches;
911
use vortex_array::stats::{ArrayStats, StatsSetRef};
1012
use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
1113
use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
1214
use vortex_buffer::Buffer;
13-
use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
15+
use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
1416
use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
1517
use vortex_mask::{AllOr, Mask};
1618
use vortex_scalar::Scalar;
@@ -268,48 +270,60 @@ impl ValidityVTable<SparseVTable> for SparseVTable {
268270

269271
#[allow(clippy::unnecessary_fallible_conversions)]
270272
fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
271-
let indices = array.patches().indices().to_primitive()?;
273+
let fill_is_valid = array.fill_scalar().is_valid();
274+
let values_validity = array.patches().values().validity_mask()?;
275+
let len = array.len();
272276

273-
if array.fill_scalar().is_null() {
274-
// If we have a null fill value, then we set each patch value to true.
275-
let mut buffer = BooleanBufferBuilder::new(array.len());
276-
// TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
277-
buffer.append_n(array.len(), false);
278-
279-
match_each_integer_ptype!(indices.ptype(), |I| {
280-
indices.as_slice::<I>().iter().for_each(|&index| {
281-
buffer.set_bit(
282-
usize::try_from(index).vortex_expect("Failed to cast to usize")
283-
- array.patches().offset(),
284-
true,
285-
);
286-
});
287-
});
288-
289-
return Ok(Mask::from_buffer(buffer.finish()));
277+
if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
278+
return Ok(Mask::AllTrue(len));
279+
}
280+
if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
281+
return Ok(Mask::AllFalse(len));
290282
}
291283

292-
// If the fill_value is non-null, then the validity is based on the validity of the
293-
// patch values.
294-
let mut buffer = BooleanBufferBuilder::new(array.len());
295-
buffer.append_n(array.len(), true);
284+
// TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
285+
let mut is_valid_buffer = BooleanBufferBuilder::new(len);
286+
is_valid_buffer.append_n(len, fill_is_valid);
287+
288+
let indices = array.patches().indices().to_primitive()?;
289+
let index_offset = array.patches().offset();
296290

297-
let values_validity = array.patches().values().validity_mask()?;
298291
match_each_integer_ptype!(indices.ptype(), |I| {
299-
indices
300-
.as_slice::<I>()
301-
.iter()
302-
.enumerate()
303-
.for_each(|(patch_idx, &index)| {
304-
buffer.set_bit(
305-
usize::try_from(index).vortex_expect("Failed to cast to usize")
306-
- array.patches().offset(),
307-
values_validity.value(patch_idx),
308-
);
309-
})
292+
let indices = indices.as_slice::<I>();
293+
patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
310294
});
311295

312-
Ok(Mask::from_buffer(buffer.finish()))
296+
Ok(Mask::from_buffer(is_valid_buffer.finish()))
297+
}
298+
}
299+
300+
fn patch_validity<I: NativePType>(
301+
is_valid_buffer: &mut BooleanBufferBuilder,
302+
indices: &[I],
303+
index_offset: usize,
304+
values_validity: Mask,
305+
) {
306+
let indices = indices.iter().map(|index| {
307+
let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
308+
index - index_offset
309+
});
310+
match values_validity {
311+
Mask::AllTrue(_) => {
312+
for index in indices {
313+
is_valid_buffer.set_bit(index, true);
314+
}
315+
}
316+
Mask::AllFalse(_) => {
317+
for index in indices {
318+
is_valid_buffer.set_bit(index, false);
319+
}
320+
}
321+
Mask::Values(mask_values) => {
322+
let is_valid = mask_values.boolean_buffer().iter();
323+
for (index, is_valid) in indices.zip_eq(is_valid) {
324+
is_valid_buffer.set_bit(index, is_valid);
325+
}
326+
}
313327
}
314328
}
315329

@@ -519,4 +533,18 @@ mod test {
519533
vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
520534
);
521535
}
536+
537+
/// With a null fill, a position is valid only if it is patched with a non-null value;
/// null patch values must stay invalid.
#[test]
fn validity_mask_includes_null_values_when_fill_is_null() {
    let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
    let patch_values =
        PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
    let sparse =
        SparseArray::try_new(patch_indices, patch_values, 10, Scalar::null_typed::<i16>())
            .unwrap();

    let mask = sparse.validity_mask().unwrap();

    // Only positions 0, 2 and 8 hold non-null patches; positions 4 and 6 hold null patches
    // and every other position takes the null fill.
    let expected = Mask::from_iter([
        true, false, true, false, false, false, false, false, true, false,
    ]);
    assert_eq!(mask, expected);
}
522550
}

0 commit comments

Comments
 (0)