Skip to content

Commit 1ab0455

Browse files
authored
fix: ConstantArray#take handles nullable indices (#2631)
1 parent 153b26f commit 1ab0455

File tree

13 files changed

+127
-43
lines changed

13 files changed

+127
-43
lines changed

fuzz/src/sort.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,7 @@ pub fn sort_canonical_array(array: &dyn Array) -> VortexResult<ArrayRef> {
1616
let mut opt_values = bool_array
1717
.boolean_buffer()
1818
.iter()
19-
.zip(
20-
bool_array
21-
.validity_mask()
22-
.vortex_expect("Failed to get logical validity")
23-
.to_boolean_buffer()
24-
.iter(),
25-
)
19+
.zip(bool_array.validity_mask()?.to_boolean_buffer().iter())
2620
.map(|(b, v)| v.then_some(b))
2721
.collect::<Vec<_>>();
2822
opt_values.sort();

vortex-array/src/arrays/constant/compute/mod.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ mod cast;
44
mod compare;
55
mod invert;
66
mod search_sorted;
7+
mod take;
78

89
use num_traits::{CheckedMul, ToPrimitive};
910
use vortex_dtype::{NativePType, PType, match_each_native_ptype};
@@ -58,15 +59,15 @@ impl ComputeVTable for ConstantEncoding {
5859
Some(self)
5960
}
6061

61-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
62+
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
6263
Some(self)
6364
}
6465

65-
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
66+
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
6667
Some(self)
6768
}
6869

69-
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
70+
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
7071
Some(self)
7172
}
7273
}
@@ -77,12 +78,6 @@ impl ScalarAtFn<&ConstantArray> for ConstantEncoding {
7778
}
7879
}
7980

80-
impl TakeFn<&ConstantArray> for ConstantEncoding {
81-
fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
82-
Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
83-
}
84-
}
85-
8681
impl SliceFn<&ConstantArray> for ConstantEncoding {
8782
fn slice(&self, array: &ConstantArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
8883
Ok(ConstantArray::new(array.scalar().clone(), stop - start).into_array())
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use vortex_error::VortexResult;
2+
use vortex_mask::{AllOr, Mask};
3+
use vortex_scalar::Scalar;
4+
5+
use crate::arrays::{ConstantArray, ConstantEncoding};
6+
use crate::builders::builder_with_capacity;
7+
use crate::compute::TakeFn;
8+
use crate::{Array, ArrayRef};
9+
10+
impl TakeFn<&ConstantArray> for ConstantEncoding {
11+
fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
12+
match indices.validity_mask()?.boolean_buffer() {
13+
AllOr::All => {
14+
Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
15+
}
16+
AllOr::None => Ok(ConstantArray::new(
17+
Scalar::null(array.dtype().clone()),
18+
indices.len(),
19+
)
20+
.into_array()),
21+
AllOr::Some(v) => {
22+
let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
23+
24+
if array.scalar().is_null() {
25+
return Ok(arr);
26+
}
27+
28+
let mut result_builder =
29+
builder_with_capacity(&array.dtype().as_nullable(), indices.len());
30+
result_builder.extend_from_array(&arr)?;
31+
result_builder.set_validity(Mask::from_buffer(v.clone()));
32+
Ok(result_builder.finish())
33+
}
34+
}
35+
}
36+
}
37+
38+
#[cfg(test)]
39+
mod tests {
40+
use vortex_buffer::buffer;
41+
use vortex_mask::AllOr;
42+
43+
use crate::arrays::{ConstantArray, PrimitiveArray};
44+
use crate::compute::take;
45+
use crate::validity::Validity;
46+
use crate::{Array, ToCanonical};
47+
48+
#[test]
49+
fn take_nullable_indices() {
50+
let array = ConstantArray::new(42, 10).to_array();
51+
let taken = take(
52+
&array,
53+
&PrimitiveArray::new(
54+
buffer![0, 5, 7],
55+
Validity::from_iter(vec![false, true, false]),
56+
)
57+
.into_array(),
58+
)
59+
.unwrap();
60+
let valid_indices: &[usize] = &[1usize];
61+
assert_eq!(
62+
taken.to_primitive().unwrap().as_slice::<i32>(),
63+
&[42, 42, 42]
64+
);
65+
assert_eq!(
66+
taken.validity_mask().unwrap().indices(),
67+
AllOr::Some(valid_indices)
68+
);
69+
}
70+
}

vortex-array/src/builders/bool.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::any::Any;
33
use arrow_buffer::BooleanBufferBuilder;
44
use vortex_dtype::{DType, Nullability};
55
use vortex_error::{VortexResult, vortex_bail};
6+
use vortex_mask::Mask;
67

78
use crate::arrays::BoolArray;
89
use crate::builders::ArrayBuilder;
@@ -85,6 +86,11 @@ impl ArrayBuilder for BoolBuilder {
8586
Ok(())
8687
}
8788

89+
fn set_validity(&mut self, validity: Mask) {
90+
self.nulls = LazyNullBufferBuilder::new(validity.len());
91+
self.nulls.append_validity_mask(validity);
92+
}
93+
8894
fn finish(&mut self) -> ArrayRef {
8995
assert_eq!(
9096
self.nulls.len(),

vortex-array/src/builders/extension.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::sync::Arc;
33

44
use vortex_dtype::{DType, ExtDType};
55
use vortex_error::{VortexResult, vortex_bail};
6+
use vortex_mask::Mask;
67
use vortex_scalar::ExtScalar;
78

89
use crate::arrays::ExtensionArray;
@@ -82,6 +83,10 @@ impl ArrayBuilder for ExtensionBuilder {
8283
array.storage().append_to_builder(self.storage.as_mut())
8384
}
8485

86+
fn set_validity(&mut self, validity: Mask) {
87+
self.storage.set_validity(validity);
88+
}
89+
8590
fn finish(&mut self) -> ArrayRef {
8691
let storage = self.storage.finish();
8792
ExtensionArray::new(self.ext_dtype(), storage).into_array()

vortex-array/src/builders/lazy_validity_builder.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ impl LazyNullBufferBuilder {
5252
.append_n(n, false);
5353
}
5454

55-
#[allow(dead_code)]
5655
#[inline]
5756
pub fn append_null(&mut self) {
5857
self.materialize_if_needed();
@@ -62,7 +61,6 @@ impl LazyNullBufferBuilder {
6261
.append(false);
6362
}
6463

65-
#[allow(dead_code)]
6664
#[inline]
6765
pub fn append(&mut self, not_null: bool) {
6866
if not_null {

vortex-array/src/builders/list.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::sync::Arc;
44
use vortex_dtype::Nullability::NonNullable;
55
use vortex_dtype::{DType, NativePType, Nullability};
66
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
7+
use vortex_mask::Mask;
78
use vortex_scalar::{BinaryNumericOperator, ListScalar};
89

910
use crate::arrays::{ConstantArray, ListArray, OffsetPType};
@@ -150,6 +151,11 @@ impl<O: OffsetPType> ArrayBuilder for ListBuilder<O> {
150151
Ok(())
151152
}
152153

154+
fn set_validity(&mut self, validity: Mask) {
155+
self.nulls = LazyNullBufferBuilder::new(validity.len());
156+
self.nulls.append_validity_mask(validity);
157+
}
158+
153159
fn finish(&mut self) -> ArrayRef {
154160
assert_eq!(
155161
self.index_builder.len(),

vortex-array/src/builders/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pub use primitive::*;
1717
pub use varbinview::*;
1818
use vortex_dtype::{DType, match_each_native_ptype};
1919
use vortex_error::{VortexResult, vortex_bail, vortex_err};
20+
use vortex_mask::Mask;
2021
use vortex_scalar::{
2122
BinaryScalar, BoolScalar, ExtScalar, ListScalar, PrimitiveScalar, Scalar, ScalarValue,
2223
StructScalar, Utf8Scalar,
@@ -57,6 +58,9 @@ pub trait ArrayBuilder: Send {
5758
/// Extends the array with the provided array, canonicalizing if necessary.
5859
fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()>;
5960

61+
/// Override builders validity with the one provided
62+
fn set_validity(&mut self, validity: Mask);
63+
6064
/// Constructs an Array from the builder components.
6165
///
6266
/// # Panics

vortex-array/src/builders/null.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::any::Any;
22

33
use vortex_dtype::DType;
44
use vortex_error::VortexResult;
5+
use vortex_mask::Mask;
56

67
use crate::arrays::NullArray;
78
use crate::builders::ArrayBuilder;
@@ -54,6 +55,10 @@ impl ArrayBuilder for NullBuilder {
5455
Ok(())
5556
}
5657

58+
fn set_validity(&mut self, validity: Mask) {
59+
self.length = validity.len();
60+
}
61+
5762
fn finish(&mut self) -> ArrayRef {
5863
NullArray::new(self.length).into_array()
5964
}

vortex-array/src/builders/primitive.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ impl<T: NativePType> ArrayBuilder for PrimitiveBuilder<T> {
178178
Ok(())
179179
}
180180

181+
fn set_validity(&mut self, validity: Mask) {
182+
self.nulls = LazyNullBufferBuilder::new(validity.len());
183+
self.nulls.append_validity_mask(validity);
184+
}
185+
181186
fn finish(&mut self) -> ArrayRef {
182187
self.finish_into_primitive().into_array()
183188
}

0 commit comments

Comments
 (0)