Skip to content

Commit c1e9448

Browse files
fix[vortex-array]: struct validity + runend validity (#3645)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent a2aab47 commit c1e9448

File tree

4 files changed

+130
-38
lines changed

4 files changed

+130
-38
lines changed

encodings/runend/benches/run_end_filter.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#![allow(clippy::cast_possible_truncation)]
33

44
use divan::Bencher;
5+
use vortex_array::validity::Validity;
56
use vortex_array::{Array, IntoArray};
67
use vortex_buffer::Buffer;
78
use vortex_mask::Mask;
@@ -42,7 +43,9 @@ fn take_indices(bencher: Bencher, (n, run_step, filter_density): (usize, usize,
4243

4344
bencher
4445
.with_inputs(|| (&array, indices))
45-
.bench_refs(|(array, indices)| take_indices_unchecked(array, indices).unwrap());
46+
.bench_refs(|(array, indices)| {
47+
take_indices_unchecked(array, indices, &Validity::NonNullable).unwrap()
48+
});
4649
}
4750

4851
#[divan::bench(args = BENCH_ARGS)]

encodings/runend/src/compute/filter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ impl FilterKernel for RunEndVTable {
2727

2828
if runs_ratio < FILTER_TAKE_THRESHOLD || mask_values.true_count() < 25 {
2929
// This strategy is directly proportional to the number of indices.
30-
take_indices_unchecked(array, mask_values.indices())
30+
take_indices_unchecked(array, mask_values.indices(), &Validity::NonNullable)
3131
} else {
3232
// This strategy ends up being close to fixed cost based on the number of runs,
3333
// rather than the number of indices.

encodings/runend/src/compute/take.rs

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use num_traits::{AsPrimitive, NumCast};
2+
use vortex_array::arrays::PrimitiveArray;
23
use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
34
use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
4-
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
5+
use vortex_array::validity::Validity;
6+
use vortex_array::vtable::ValidityHelper;
7+
use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
58
use vortex_buffer::Buffer;
69
use vortex_dtype::match_each_integer_ptype;
710
use vortex_error::{VortexResult, vortex_bail};
@@ -28,7 +31,7 @@ impl TakeKernel for RunEndVTable {
2831
.collect::<VortexResult<Vec<_>>>()?
2932
});
3033

31-
take_indices_unchecked(array, &checked_indices)
34+
take_indices_unchecked(array, &checked_indices, primitive_indices.validity())
3235
}
3336
}
3437

@@ -38,13 +41,15 @@ register_kernel!(TakeKernelAdapter(RunEndVTable).lift());
3841
pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
3942
array: &RunEndArray,
4043
indices: &[T],
44+
validity: &Validity,
4145
) -> VortexResult<ArrayRef> {
4246
let ends = array.ends().to_primitive()?;
4347
let ends_len = ends.len();
4448

49+
// TODO(joe): use the validity mask to skip search sorted.
4550
let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
4651
let end_slices = ends.as_slice::<I>();
47-
indices
52+
let buffer = indices
4853
.iter()
4954
.map(|idx| idx.as_() + array.offset())
5055
.map(|idx| {
@@ -57,18 +62,21 @@ pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
5762
}
5863
})
5964
.map(|result| result.to_ends_index(ends_len) as u64)
60-
.collect::<Buffer<u64>>()
61-
.into_array()
65+
.collect::<Buffer<u64>>();
66+
67+
PrimitiveArray::new(buffer, validity.clone())
6268
});
6369

64-
take(array.values(), &physical_indices)
70+
take(array.values(), physical_indices.as_ref())
6571
}
6672

6773
#[cfg(test)]
6874
mod test {
6975
use vortex_array::arrays::PrimitiveArray;
7076
use vortex_array::compute::take;
7177
use vortex_array::{Array, IntoArray, ToCanonical};
78+
use vortex_dtype::{DType, Nullability, PType};
79+
use vortex_scalar::{Scalar, ScalarValue};
7280

7381
use crate::RunEndArray;
7482

@@ -126,4 +134,25 @@ mod test {
126134
assert_eq!(taken.scalar_at(1).unwrap(), 2.into());
127135
assert_eq!(taken.scalar_at(2).unwrap(), 5.into());
128136
}
137+
138+
// Checks `take` on a run-end array with nullable indices: a valid index yields a nullable
// scalar carrying the value, and a null index yields a null scalar.
#[test]
139+
fn ree_take_nullable() {
140+
// Indices are `[Some(1), None]`, so the result has one valid and one null position.
let taken = take(
141+
ree_array().as_ref(),
142+
PrimitiveArray::from_option_iter([Some(1), None]).as_ref(),
143+
)
144+
.unwrap();
145+
146+
// Position 0 (index 1) is expected to be the value 1, typed as nullable i32.
assert_eq!(
147+
taken.scalar_at(0).unwrap(),
148+
Scalar::new(
149+
DType::Primitive(PType::I32, Nullability::Nullable),
150+
ScalarValue::from(1i32)
151+
)
152+
);
153+
// Position 1 (null index) is expected to be a null of the same nullable i32 dtype.
assert_eq!(
154+
taken.scalar_at(1).unwrap(),
155+
Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
156+
);
157+
}
129158
}

vortex-array/src/arrays/struct_/compute/mod.rs

Lines changed: 90 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,43 @@ mod filter;
33
mod mask;
44

55
use itertools::Itertools;
6+
use vortex_dtype::Nullability::NonNullable;
67
use vortex_error::VortexResult;
8+
use vortex_scalar::Scalar;
79

810
use crate::arrays::StructVTable;
911
use crate::arrays::struct_::StructArray;
1012
use crate::compute::{
1113
IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
12-
MinMaxResult, TakeKernel, TakeKernelAdapter, is_constant_opts, take,
14+
MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
1315
};
16+
use crate::validity::Validity;
1417
use crate::vtable::ValidityHelper;
1518
use crate::{Array, ArrayRef, IntoArray, register_kernel};
1619

1720
impl TakeKernel for StructVTable {
1821
fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22+
// If the struct array is empty then the indices must be all null, otherwise it will access
23+
// an out of bounds element
24+
if array.is_empty() {
25+
return StructArray::try_new_with_dtype(
26+
array.fields().to_vec(),
27+
array.struct_fields().clone(),
28+
indices.len(),
29+
Validity::AllInvalid,
30+
)
31+
.map(StructArray::into_array);
32+
}
33+
// Fill null indices with a default value before taking from the fields; the nullness of the
// indices is applied to the struct validity instead.
34+
let inner_indices = &fill_null(
35+
indices,
36+
&Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
37+
)?;
1938
StructArray::try_new_with_dtype(
2039
array
2140
.fields()
2241
.iter()
23-
.map(|field| take(field, indices))
42+
.map(|field| take(field, inner_indices))
2443
.try_collect()?,
2544
array.struct_fields().clone(),
2645
indices.len(),
@@ -71,13 +90,15 @@ register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
7190
mod tests {
7291
use std::sync::Arc;
7392

93+
use Nullability::{NonNullable, Nullable};
7494
use vortex_buffer::buffer;
7595
use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
7696
use vortex_mask::Mask;
97+
use vortex_scalar::Scalar;
7798

7899
use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
79100
use crate::compute::conformance::mask::test_mask;
80-
use crate::compute::{cast, filter};
101+
use crate::compute::{cast, filter, take};
81102
use crate::validity::Validity;
82103
use crate::{Array, IntoArray as _};
83104

@@ -92,6 +113,52 @@ mod tests {
92113
assert_eq!(filtered.len(), 5);
93114
}
94115

116+
// Checks `take` with nullable indices on a struct array that has no fields (but length 10):
// the result has one entry per index, with a null struct where the index is null.
#[test]
117+
fn take_empty_struct() {
118+
// A non-nullable struct array with zero fields and a length of 10.
let struct_arr =
119+
StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
120+
let indices = PrimitiveArray::from_option_iter([Some(1), None]);
121+
let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
122+
assert_eq!(taken.len(), 2);
123+
124+
// The valid index is expected to produce a field-less struct scalar with a nullable dtype.
assert_eq!(
125+
taken.scalar_at(0).unwrap(),
126+
Scalar::struct_(
127+
DType::Struct(Arc::new(StructFields::new([].into(), vec![])), Nullable),
128+
vec![]
129+
)
130+
);
131+
// The null index is expected to produce a null scalar of the same nullable struct dtype.
assert_eq!(
132+
taken.scalar_at(1).unwrap(),
133+
Scalar::null(DType::Struct(
134+
Arc::new(StructFields::new([].into(), vec![])),
135+
Nullable
136+
))
137+
);
138+
}
139+
140+
// Checks `take` with nullable indices on a struct array with a single primitive field:
// the null index nulls the struct itself, while the field scalar stays non-nullable.
#[test]
141+
fn take_field_struct() {
142+
// One field "a" holding the values 0..10.
let struct_arr =
143+
StructArray::from_fields(&[("a", PrimitiveArray::from_iter(0..10).to_array())])
144+
.unwrap();
145+
let indices = PrimitiveArray::from_option_iter([Some(1), None]);
146+
let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
147+
assert_eq!(taken.len(), 2);
148+
149+
// The valid index 1 is expected to select field value 1; the struct dtype becomes nullable.
assert_eq!(
150+
taken.scalar_at(0).unwrap(),
151+
Scalar::struct_(
152+
struct_arr.dtype().union_nullability(Nullable),
153+
vec![Scalar::primitive(1, NonNullable)],
154+
)
155+
);
156+
// The null index is expected to produce a null struct scalar of the nullable struct dtype.
assert_eq!(
157+
taken.scalar_at(1).unwrap(),
158+
Scalar::null(struct_arr.dtype().union_nullability(Nullable),)
159+
);
160+
}
161+
95162
#[test]
96163
fn filter_empty_struct_with_empty_filter() {
97164
let struct_arr =
@@ -114,7 +181,7 @@ mod tests {
114181
let xs = buffer![0i64, 1, 2, 3, 4].into_array();
115182
let ys = VarBinArray::from_iter(
116183
[Some("a"), Some("b"), None, Some("d"), None],
117-
DType::Utf8(Nullability::Nullable),
184+
DType::Utf8(Nullable),
118185
)
119186
.into_array();
120187
let zs =
@@ -148,17 +215,13 @@ mod tests {
148215
let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
149216
.unwrap()
150217
.into_array();
151-
let non_nullable_dtype = DType::Struct(
152-
Arc::from(StructFields::new([].into(), vec![])),
153-
Nullability::NonNullable,
154-
);
218+
let non_nullable_dtype =
219+
DType::Struct(Arc::from(StructFields::new([].into(), vec![])), NonNullable);
155220
let casted = cast(&array, &non_nullable_dtype).unwrap();
156221
assert_eq!(casted.dtype(), &non_nullable_dtype);
157222

158-
let nullable_dtype = DType::Struct(
159-
Arc::from(StructFields::new([].into(), vec![])),
160-
Nullability::Nullable,
161-
);
223+
let nullable_dtype =
224+
DType::Struct(Arc::from(StructFields::new([].into(), vec![])), Nullable);
162225
let casted = cast(&array, &nullable_dtype).unwrap();
163226
assert_eq!(casted.dtype(), &nullable_dtype);
164227
}
@@ -177,7 +240,7 @@ mod tests {
177240
)
178241
.unwrap();
179242

180-
let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);
243+
let tu8 = DType::Primitive(PType::U8, NonNullable);
181244

182245
let result = cast(
183246
array.as_ref(),
@@ -186,7 +249,7 @@ mod tests {
186249
FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
187250
vec![tu8.clone(), tu8.clone(), tu8],
188251
)),
189-
Nullability::NonNullable,
252+
NonNullable,
190253
),
191254
);
192255
assert!(
@@ -201,10 +264,7 @@ mod tests {
201264
#[test]
202265
fn test_cast_complex_struct() {
203266
let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
204-
let ys = VarBinArray::from_vec(
205-
vec!["a", "b", "c", "d", "e"],
206-
DType::Utf8(Nullability::Nullable),
207-
);
267+
let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
208268
let zs = BoolArray::new(
209269
BooleanBuffer::from_iter([true, true, false, false, true]),
210270
Validity::AllValid,
@@ -241,17 +301,17 @@ mod tests {
241301
Arc::from(StructFields::new(
242302
["left".into(), "right".into()].into(),
243303
vec![
244-
DType::Primitive(PType::I64, Nullability::NonNullable),
245-
DType::Primitive(PType::I64, Nullability::Nullable),
304+
DType::Primitive(PType::I64, NonNullable),
305+
DType::Primitive(PType::I64, Nullable),
246306
],
247307
)),
248-
Nullability::Nullable,
308+
Nullable,
249309
),
250-
DType::Utf8(Nullability::Nullable),
251-
DType::Bool(Nullability::Nullable),
310+
DType::Utf8(Nullable),
311+
DType::Bool(Nullable),
252312
],
253313
)),
254-
Nullability::Nullable,
314+
Nullable,
255315
);
256316
let casted = cast(&fully_nullable_array, &non_null_xs_right).unwrap();
257317
assert_eq!(casted.dtype(), &non_null_xs_right);
@@ -264,17 +324,17 @@ mod tests {
264324
Arc::from(StructFields::new(
265325
["left".into(), "right".into()].into(),
266326
vec![
267-
DType::Primitive(PType::I64, Nullability::Nullable),
268-
DType::Primitive(PType::I64, Nullability::Nullable),
327+
DType::Primitive(PType::I64, Nullable),
328+
DType::Primitive(PType::I64, Nullable),
269329
],
270330
)),
271-
Nullability::NonNullable,
331+
NonNullable,
272332
),
273-
DType::Utf8(Nullability::Nullable),
274-
DType::Bool(Nullability::Nullable),
333+
DType::Utf8(Nullable),
334+
DType::Bool(Nullable),
275335
],
276336
)),
277-
Nullability::Nullable,
337+
Nullable,
278338
);
279339
let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
280340
assert_eq!(casted.dtype(), &non_null_xs);

0 commit comments

Comments
 (0)