Skip to content

Commit 000c99a

Browse files
authored
Implement take for Mask and FromIterator<Mask> for Mask (#2689)
1 parent 175883d commit 000c99a

File tree

22 files changed

+96
-119
lines changed

22 files changed

+96
-119
lines changed

encodings/bytebool/src/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ impl ArrayValidityImpl for ByteBoolArray {
117117
}
118118

119119
fn _validity_mask(&self) -> VortexResult<Mask> {
120-
self.validity.to_logical(self.len())
120+
self.validity.to_mask(self.len())
121121
}
122122
}
123123

encodings/datetime-parts/src/canonical.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ mod test {
131131

132132
assert_eq!(
133133
date_times.validity_mask().unwrap(),
134-
validity.to_logical(date_times.len()).unwrap()
134+
validity.to_mask(date_times.len()).unwrap()
135135
);
136136

137137
let primitive_values = decode_to_temporal(&date_times)

encodings/fastlanes/src/bitpacking/compress.rs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use vortex_dtype::{
1616
match_each_unsigned_integer_ptype,
1717
};
1818
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
19-
use vortex_mask::AllOr;
19+
use vortex_mask::{AllOr, Mask};
2020
use vortex_scalar::Scalar;
2121

2222
use crate::BitPackedArray;
@@ -274,52 +274,52 @@ fn apply_patches<T: NativePType>(dst: &mut UninitRange<T>, patches: &Patches) ->
274274

275275
let indices = indices.to_primitive()?;
276276
let values = values.to_primitive()?;
277-
let validity = values.validity();
277+
let validity = values.validity_mask()?;
278278
let values = values.as_slice::<T>();
279279
match_each_unsigned_integer_ptype!(indices.ptype(), |$P| {
280-
insert_values_and_validity_at_indices::<T, $P>(
280+
insert_values_and_validity_at_indices::<T, _>(
281281
dst,
282-
indices,
282+
indices.as_slice::<$P>(),
283283
values,
284-
validity.clone(),
284+
validity,
285285
indices_offset,
286286
)
287-
})
287+
});
288+
Ok(())
288289
}
289290

290291
fn insert_values_and_validity_at_indices<
291292
T: NativePType,
292293
IndexT: NativePType + AsPrimitive<usize>,
293294
>(
294295
dst: &mut UninitRange<T>,
295-
indices: PrimitiveArray,
296+
indices: &[IndexT],
296297
values: &[T],
297-
validity: Validity,
298+
values_validity: Mask,
298299
indices_offset: usize,
299-
) -> VortexResult<()> {
300-
match validity {
301-
Validity::NonNullable => {
302-
for (compressed_index, decompressed_index) in
303-
indices.as_slice::<IndexT>().iter().enumerate()
304-
{
300+
) {
301+
match values_validity {
302+
Mask::AllTrue(_) => {
303+
for (compressed_index, decompressed_index) in indices.iter().enumerate() {
305304
dst[decompressed_index.as_() - indices_offset] =
306305
MaybeUninit::new(values[compressed_index]);
307306
}
308307
}
309-
_ => {
310-
let validity = validity.to_logical(indices.len())?;
311-
for (compressed_index, decompressed_index) in
312-
indices.as_slice::<IndexT>().iter().enumerate()
313-
{
308+
Mask::AllFalse(_) => {
309+
for (compressed_index, decompressed_index) in indices.iter().enumerate() {
314310
let out_index = decompressed_index.as_() - indices_offset;
315-
dst[decompressed_index.as_() - indices_offset] =
316-
MaybeUninit::new(values[compressed_index]);
317-
dst.set_bit(out_index, validity.value(out_index));
311+
dst[out_index] = MaybeUninit::new(values[compressed_index]);
312+
dst.set_bit(out_index, false);
313+
}
314+
}
315+
Mask::Values(vb) => {
316+
for (compressed_index, decompressed_index) in indices.iter().enumerate() {
317+
let out_index = decompressed_index.as_() - indices_offset;
318+
dst[out_index] = MaybeUninit::new(values[compressed_index]);
319+
dst.set_bit(out_index, vb.value(out_index));
318320
}
319321
}
320322
}
321-
322-
Ok(())
323323
}
324324

325325
fn unpack_values_into<T: NativePType, UnsignedT: NativePType + BitPacking, F, G>(

encodings/fastlanes/src/bitpacking/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ impl ArrayValidityImpl for BitPackedArray {
279279
}
280280

281281
fn _validity_mask(&self) -> VortexResult<Mask> {
282-
self.validity.to_logical(self.len())
282+
self.validity.to_mask(self.len())
283283
}
284284
}
285285

encodings/fastlanes/src/delta/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ impl ArrayValidityImpl for DeltaArray {
250250
}
251251

252252
fn _validity_mask(&self) -> VortexResult<Mask> {
253-
self.validity.to_logical(self.len)
253+
self.validity.to_mask(self.len)
254254
}
255255
}
256256

vortex-array/src/arrays/bool/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ impl ArrayValidityImpl for BoolArray {
152152

153153
#[inline]
154154
fn _validity_mask(&self) -> VortexResult<Mask> {
155-
self.validity.to_logical(self.len())
155+
self.validity.to_mask(self.len())
156156
}
157157
}
158158

vortex-array/src/arrays/bool/compute/take.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use arrow_buffer::BooleanBuffer;
22
use itertools::Itertools;
33
use num_traits::AsPrimitive;
4-
use vortex_dtype::{NativePType, match_each_integer_ptype};
4+
use vortex_dtype::match_each_integer_ptype;
55
use vortex_error::VortexResult;
66
use vortex_mask::Mask;
77
use vortex_scalar::Scalar;
88

9-
use crate::arrays::{BoolArray, BoolEncoding, ConstantArray, PrimitiveArray};
9+
use crate::arrays::{BoolArray, BoolEncoding, ConstantArray};
1010
use crate::builders::ArrayBuilder;
1111
use crate::compute::{TakeFn, fill_null};
1212
use crate::variants::PrimitiveArrayTrait;
@@ -27,7 +27,7 @@ impl TakeFn<&BoolArray> for BoolEncoding {
2727
};
2828
let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive()?;
2929
let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |$I| {
30-
take_valid_indices::<$I>(array, &indices_nulls_zeroed)
30+
take_valid_indices(array.boolean_buffer(), indices_nulls_zeroed.as_slice::<$I>())
3131
});
3232

3333
Ok(BoolArray::new(buffer, array.validity().take(indices)?).into_array())
@@ -43,30 +43,30 @@ impl TakeFn<&BoolArray> for BoolEncoding {
4343
}
4444
}
4545

46-
fn take_valid_indices<I: AsPrimitive<usize> + NativePType>(
47-
array: &BoolArray,
48-
indices: &PrimitiveArray,
46+
fn take_valid_indices<I: AsPrimitive<usize>>(
47+
bools: &BooleanBuffer,
48+
indices: &[I],
4949
) -> BooleanBuffer {
5050
// For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
5151
// the overhead to convert to a Vec<bool>.
52-
if array.len() <= 4096 {
53-
let bools = array.boolean_buffer().into_iter().collect_vec();
54-
take_byte_bool(bools, indices.as_slice::<I>())
52+
if bools.len() <= 4096 {
53+
let bools = bools.into_iter().collect_vec();
54+
take_byte_bool(bools, indices)
5555
} else {
56-
take_bool(array.boolean_buffer(), indices.as_slice::<I>())
56+
take_bool(bools, indices)
5757
}
5858
}
5959

6060
fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BooleanBuffer {
6161
BooleanBuffer::collect_bool(indices.len(), |idx| {
62-
bools[unsafe { (*indices.get_unchecked(idx)).as_() }]
62+
bools[unsafe { indices.get_unchecked(idx).as_() }]
6363
})
6464
}
6565

6666
fn take_bool<I: AsPrimitive<usize>>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer {
6767
BooleanBuffer::collect_bool(indices.len(), |idx| {
6868
// We can always take from the indices unchecked since collect_bool just iterates len.
69-
bools.value(unsafe { (*indices.get_unchecked(idx)).as_() })
69+
bools.value(unsafe { indices.get_unchecked(idx).as_() })
7070
})
7171
}
7272

vortex-array/src/arrays/chunked/mod.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ use crate::iter::{ArrayIterator, ArrayIteratorAdapter};
1717
use crate::nbytes::NBytes;
1818
use crate::stats::{ArrayStats, StatsSetRef};
1919
use crate::stream::{ArrayStream, ArrayStreamAdapter};
20-
use crate::validity::Validity;
2120
use crate::vtable::VTableRef;
2221
use crate::{Array, ArrayImpl, ArrayRef, ArrayStatisticsImpl, EmptyMetadata, Encoding, IntoArray};
2322

@@ -240,14 +239,10 @@ impl ArrayValidityImpl for ChunkedArray {
240239
}
241240

242241
fn _validity_mask(&self) -> VortexResult<Mask> {
243-
// TODO(ngates): implement FromIterator<LogicalValidity> for LogicalValidity.
244-
// TODO(ngates): or use a boolean array builder?
245-
let validity: Validity = self
246-
.chunks()
242+
self.chunks()
247243
.iter()
248244
.map(|a| a.validity_mask())
249-
.try_collect()?;
250-
validity.to_logical(self.len())
245+
.try_collect()
251246
}
252247
}
253248

vortex-array/src/arrays/list/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ impl ArrayValidityImpl for ListArray {
173173
}
174174

175175
fn _validity_mask(&self) -> VortexResult<Mask> {
176-
self.validity.to_logical(self.len())
176+
self.validity.to_mask(self.len())
177177
}
178178
}
179179

vortex-array/src/arrays/primitive/compute/cast.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ impl CastFn<&PrimitiveArray> for PrimitiveEncoding {
2222
} else if new_nullability == Nullability::Nullable {
2323
// from non-nullable to nullable
2424
array.validity().clone().into_nullable()
25-
} else if new_nullability == Nullability::NonNullable
26-
&& array.validity().to_logical(array.len())?.all_true()
27-
{
25+
} else if new_nullability == Nullability::NonNullable && array.validity().all_valid()? {
2826
// from nullable but all valid, to non-nullable
2927
Validity::NonNullable
3028
} else {

0 commit comments

Comments
 (0)