Skip to content

Commit 65e0571

Browse files
authored
fix: RunEndBool array take respects validity (#1684)
1 parent 18268bc commit 65e0571

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

encodings/runend-bool/src/compute/mod.rs

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
mod invert;
22

3+
use arrow_buffer::BooleanBuffer;
34
use vortex_array::array::BoolArray;
45
use vortex_array::compute::{slice, ComputeVTable, InvertFn, ScalarAtFn, SliceFn, TakeFn};
56
use vortex_array::variants::PrimitiveArrayTrait;
6-
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData};
7+
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant};
78
use vortex_dtype::match_each_integer_ptype;
89
use vortex_error::{vortex_bail, VortexResult};
910
use vortex_scalar::Scalar;
@@ -53,10 +54,15 @@ impl TakeFn<RunEndBoolArray> for RunEndBoolEncoding {
5354
.collect::<VortexResult<Vec<_>>>()?
5455
});
5556
let start = array.start();
56-
Ok(
57-
BoolArray::from_iter(physical_indices.iter().map(|&it| value_at_index(it, start)))
58-
.to_array(),
57+
BoolArray::try_new(
58+
BooleanBuffer::from_iter(
59+
physical_indices
60+
.into_iter()
61+
.map(|it| value_at_index(it, start)),
62+
),
63+
array.validity().take(indices)?,
5964
)
65+
.map(|a| a.into_array())
6066
}
6167
}
6268

@@ -90,9 +96,11 @@ impl SliceFn<RunEndBoolArray> for RunEndBoolEncoding {
9096

9197
#[cfg(test)]
9298
mod tests {
93-
use vortex_array::compute::{scalar_at, slice};
99+
use arrow_buffer::BooleanBuffer;
100+
use vortex_array::array::PrimitiveArray;
101+
use vortex_array::compute::{scalar_at, slice, take};
94102
use vortex_array::validity::Validity;
95-
use vortex_array::{ArrayLen, IntoArrayData};
103+
use vortex_array::{ArrayDType, ArrayLen, IntoArrayData, IntoArrayVariant};
96104
use vortex_dtype::Nullability;
97105
use vortex_scalar::Scalar;
98106

@@ -124,4 +132,24 @@ mod tests {
124132
Scalar::bool(false, Nullability::Nullable)
125133
);
126134
}
135+
136+
#[test]
137+
fn take_nullable() {
138+
let re_array = RunEndBoolArray::try_new(
139+
vec![7_u64, 10].into_array(),
140+
false,
141+
Validity::from(BooleanBuffer::from(vec![
142+
false, false, true, true, true, true, true, true, false, false,
143+
])),
144+
)
145+
.unwrap();
146+
147+
let taken = take(&re_array, PrimitiveArray::from(vec![6, 9])).unwrap();
148+
let taken_bool = taken.into_bool().unwrap();
149+
assert_eq!(taken_bool.dtype(), re_array.dtype());
150+
assert_eq!(
151+
taken_bool.boolean_buffer(),
152+
BooleanBuffer::from(vec![false, true])
153+
);
154+
}
127155
}

vortex-array/src/array/bool/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::sync::Arc;
33

44
use arrow_array::BooleanArray;
55
use arrow_buffer::{BooleanBufferBuilder, MutableBuffer};
6-
use itertools::Itertools;
76
use serde::{Deserialize, Serialize};
87
use vortex_buffer::Buffer;
98
use vortex_dtype::{DType, Nullability};
@@ -129,7 +128,7 @@ impl BoolArray {
129128
first_byte_bit_offset,
130129
}),
131130
Some(Buffer::from(inner)),
132-
validity.into_array().into_iter().collect_vec().into(),
131+
validity.into_array().into_iter().collect(),
133132
StatsSet::default(),
134133
)?
135134
.try_into()

vortex-array/src/compute/take.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use log::info;
21
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
32

43
use crate::encoding::Encoding;
@@ -72,17 +71,26 @@ pub fn take(
7271
// If TakeFn defined for the encoding, delegate to TakeFn.
7372
// If we know from stats that indices are all valid, we can avoid all bounds checks.
7473
if let Some(take_fn) = array.encoding().take_fn() {
75-
return if checked_indices {
74+
let result = if checked_indices {
7675
// SAFETY: indices are all inbounds per stats.
7776
// TODO(aduffy): this means stats must be trusted, can still trigger UB if stats are bad.
7877
unsafe { take_fn.take_unchecked(array, indices) }
7978
} else {
8079
take_fn.take(array, indices)
81-
};
80+
}?;
81+
if array.dtype() != result.dtype() {
82+
vortex_bail!(
83+
"TakeFn {} changed array dtype from {} to {}",
84+
array.encoding().id(),
85+
array.dtype(),
86+
result.dtype()
87+
);
88+
}
89+
return Ok(result);
8290
}
8391

8492
// Otherwise, flatten and try again.
85-
info!("TakeFn not implemented for {}, flattening", array);
93+
log::debug!("No take implementation found for {}", array.encoding().id());
8694
let canonical = array.clone().into_canonical()?.into_array();
8795
let canonical_take_fn = canonical
8896
.encoding()

0 commit comments

Comments
 (0)