Skip to content

Commit cb46737

Browse files
authored
Remove selection mask from batch vtable (#5424)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 815a4e2 commit cb46737

File tree

25 files changed

+349
-90
lines changed

25 files changed

+349
-90
lines changed

encodings/sequence/src/array.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
use std::hash::Hash;
55
use std::ops::Range;
66

7+
use num_traits::One;
78
use num_traits::cast::FromPrimitive;
89
use vortex_array::arrays::PrimitiveArray;
10+
use vortex_array::execution::ExecutionCtx;
911
use vortex_array::serde::ArrayChildren;
1012
use vortex_array::stats::{ArrayStats, StatsSetRef};
1113
use vortex_array::vtable::{
@@ -24,6 +26,8 @@ use vortex_dtype::{
2426
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2527
use vortex_mask::Mask;
2628
use vortex_scalar::{PValue, Scalar, ScalarValue};
29+
use vortex_vector::Vector;
30+
use vortex_vector::primitive::PVector;
2731

2832
vtable!(Sequence);
2933

@@ -242,6 +246,26 @@ impl VTable for SequenceVTable {
242246
len,
243247
))
244248
}
249+
250+
fn execute(array: &Self::Array, _ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
251+
Ok(match_each_native_ptype!(array.ptype(), |P| {
252+
let base = array.base().cast::<P>();
253+
let multiplier = array.multiplier().cast::<P>();
254+
255+
let values = if multiplier == <P>::one() {
256+
BufferMut::from_iter(
257+
(0..array.len()).map(|i| base + <P>::from_usize(i).vortex_expect("must fit")),
258+
)
259+
} else {
260+
BufferMut::from_iter(
261+
(0..array.len())
262+
.map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
263+
)
264+
};
265+
266+
PVector::<P>::new(values.freeze(), Mask::new_true(array.len())).into()
267+
}))
268+
}
245269
}
246270

247271
impl ArrayVTable<SequenceVTable> for SequenceVTable {

vortex-array/src/array/operator.rs

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
use std::sync::Arc;
55

6+
use vortex_compute::filter::Filter;
67
use vortex_error::{VortexResult, vortex_panic};
78
use vortex_mask::Mask;
8-
use vortex_vector::{Vector, VectorOps, vector_matches_dtype};
9+
use vortex_vector::{Vector, vector_matches_dtype};
910

1011
use crate::execution::{BatchKernelRef, BindCtx, DummyExecutionCtx, ExecutionCtx};
1112
use crate::pipeline::PipelinedNode;
@@ -24,7 +25,7 @@ pub trait ArrayOperator: 'static + Send + Sync {
2425
///
2526
/// If the mask length does not match the array length.
2627
/// If the array's implementation returns an invalid vector (wrong length, wrong type, etc.).
27-
fn execute_batch(&self, selection: &Mask, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector>;
28+
fn execute_batch(&self, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector>;
2829

2930
/// Optimize the array by running the optimization rules.
3031
fn reduce(&self) -> VortexResult<Option<ArrayRef>>;
@@ -44,8 +45,8 @@ pub trait ArrayOperator: 'static + Send + Sync {
4445
}
4546

4647
impl ArrayOperator for Arc<dyn Array> {
47-
fn execute_batch(&self, selection: &Mask, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
48-
self.as_ref().execute_batch(selection, ctx)
48+
fn execute_batch(&self, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
49+
self.as_ref().execute_batch(ctx)
4950
}
5051

5152
fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
@@ -70,17 +71,8 @@ impl ArrayOperator for Arc<dyn Array> {
7071
}
7172

7273
impl<V: VTable> ArrayOperator for ArrayAdapter<V> {
73-
fn execute_batch(&self, selection: &Mask, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
74-
let vector =
75-
<V::OperatorVTable as OperatorVTable<V>>::execute_batch(&self.0, selection, ctx)?;
76-
77-
// Such a cheap check that we run it always. More expensive DType checks live in
78-
// debug_assertions.
79-
assert_eq!(
80-
vector.len(),
81-
selection.true_count(),
82-
"Batch execution returned vector of incorrect length"
83-
);
74+
fn execute_batch(&self, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
75+
let vector = V::execute(&self.0, ctx)?;
8476

8577
if cfg!(debug_assertions) {
8678
// Checks for correct type and nullability.
@@ -130,17 +122,20 @@ impl BindCtx for () {
130122

131123
impl dyn Array + '_ {
132124
pub fn execute(&self) -> VortexResult<Vector> {
133-
self.execute_with_selection(&Mask::new_true(self.len()))
125+
// Check if the array is a pipeline node
126+
if self.as_pipelined().is_some() {
127+
return PipelineDriver::new(self.to_array()).execute(&Mask::new_true(self.len()));
128+
}
129+
self.execute_batch(&mut DummyExecutionCtx)
134130
}
135131

136132
pub fn execute_with_selection(&self, selection: &Mask) -> VortexResult<Vector> {
137-
assert_eq!(self.len(), selection.len());
138-
139133
// Check if the array is a pipeline node
140134
if self.as_pipelined().is_some() {
141135
return PipelineDriver::new(self.to_array()).execute(selection);
142136
}
143-
144-
self.execute_batch(selection, &mut DummyExecutionCtx)
137+
Ok(self
138+
.execute_batch(&mut DummyExecutionCtx)?
139+
.filter(selection))
145140
}
146141
}

vortex-array/src/arrays/bool/vtable/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
use vortex_buffer::ByteBuffer;
55
use vortex_dtype::DType;
66
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
7+
use vortex_vector::Vector;
8+
use vortex_vector::bool::BoolVector;
79

810
use crate::arrays::BoolArray;
11+
use crate::execution::ExecutionCtx;
912
use crate::serde::ArrayChildren;
1013
use crate::validity::Validity;
1114
use crate::vtable::{NotSupported, VTable, ValidityVTableFromValidityHelper};
@@ -91,6 +94,10 @@ impl VTable for BoolVTable {
9194

9295
BoolArray::try_new(buffers[0].clone(), metadata.offset as usize, len, validity)
9396
}
97+
98+
fn execute(array: &Self::Array, _ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
99+
Ok(BoolVector::new(array.bit_buffer().clone(), array.validity_mask()).into())
100+
}
94101
}
95102

96103
#[derive(Clone, Debug)]

vortex-array/src/arrays/chunked/vtable/mod.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ use itertools::Itertools;
55
use vortex_buffer::ByteBuffer;
66
use vortex_dtype::{DType, Nullability, PType};
77
use vortex_error::{VortexResult, vortex_bail, vortex_err};
8+
use vortex_vector::{Vector, VectorMut, VectorMutOps};
89

910
use crate::arrays::ChunkedArray;
11+
use crate::execution::ExecutionCtx;
1012
use crate::serde::ArrayChildren;
1113
use crate::vtable::{NotSupported, VTable};
12-
use crate::{EmptyMetadata, EncodingId, EncodingRef, ToCanonical, vtable};
14+
use crate::{ArrayOperator, EmptyMetadata, EncodingId, EncodingRef, ToCanonical, vtable};
1315

1416
mod array;
1517
mod canonical;
@@ -95,6 +97,15 @@ impl VTable for ChunkedVTable {
9597
// Each chunk was validated during deserialization to match the expected dtype.
9698
unsafe { Ok(ChunkedArray::new_unchecked(chunks, dtype.clone())) }
9799
}
100+
101+
fn execute(array: &Self::Array, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
102+
let mut vector = VectorMut::with_capacity(array.dtype(), 0);
103+
for chunk in array.chunks() {
104+
let chunk_vector = chunk.execute_batch(ctx)?;
105+
vector.extend_from_vector(&chunk_vector);
106+
}
107+
Ok(vector.freeze())
108+
}
98109
}
99110

100111
#[derive(Clone, Debug)]

vortex-array/src/arrays/constant/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,7 @@ pub use array::ConstantArray;
66

77
mod compute;
88

9+
mod vector;
910
mod vtable;
11+
1012
pub use vtable::{ConstantEncoding, ConstantVTable};

vortex-array/src/arrays/constant/vtable/operator.rs renamed to vortex-array/src/arrays/constant/vector.rs

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
use vortex_dtype::{
55
DType, DecimalType, PrecisionScale, match_each_decimal_value_type, match_each_native_ptype,
66
};
7-
use vortex_error::{VortexExpect, VortexResult};
7+
use vortex_error::VortexExpect;
88
use vortex_scalar::{BinaryScalar, BoolScalar, DecimalScalar, PrimitiveScalar, Scalar, Utf8Scalar};
99
use vortex_vector::binaryview::{BinaryVectorMut, StringVectorMut};
1010
use vortex_vector::bool::BoolVectorMut;
@@ -13,29 +13,7 @@ use vortex_vector::null::NullVectorMut;
1313
use vortex_vector::primitive::{PVectorMut, PrimitiveVectorMut};
1414
use vortex_vector::{VectorMut, VectorMutOps};
1515

16-
use crate::ArrayRef;
17-
use crate::arrays::{ConstantArray, ConstantVTable};
18-
use crate::execution::{BatchKernelRef, BindCtx, kernel};
19-
use crate::vtable::OperatorVTable;
20-
21-
impl OperatorVTable<ConstantVTable> for ConstantVTable {
22-
fn bind(
23-
array: &ConstantArray,
24-
selection: Option<&ArrayRef>,
25-
ctx: &mut dyn BindCtx,
26-
) -> VortexResult<BatchKernelRef> {
27-
let mask = ctx.bind_selection(array.len, selection)?;
28-
let scalar = array.scalar().clone();
29-
30-
Ok(kernel(move || {
31-
// TODO(ngates): would be good to do a sum aggregation, rather than execution.
32-
let mask = mask.execute()?;
33-
Ok(to_vector(scalar, mask.true_count()).freeze())
34-
}))
35-
}
36-
}
37-
38-
fn to_vector(scalar: Scalar, len: usize) -> VectorMut {
16+
pub(super) fn to_vector(scalar: Scalar, len: usize) -> VectorMut {
3917
match scalar.dtype() {
4018
DType::Null => NullVectorMut::new(len).into(),
4119
DType::Bool(_) => to_vector_bool(scalar.as_bool(), len).into(),

vortex-array/src/arrays/constant/vtable/mod.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@ use vortex_buffer::ByteBuffer;
55
use vortex_dtype::DType;
66
use vortex_error::{VortexResult, vortex_bail};
77
use vortex_scalar::{Scalar, ScalarValue};
8+
use vortex_vector::{Vector, VectorMutOps};
89

910
use crate::arrays::ConstantArray;
11+
use crate::arrays::constant::vector::to_vector;
12+
use crate::execution::ExecutionCtx;
1013
use crate::serde::ArrayChildren;
1114
use crate::vtable::{NotSupported, VTable};
1215
use crate::{EmptyMetadata, EncodingId, EncodingRef, vtable};
@@ -15,7 +18,6 @@ mod array;
1518
mod canonical;
1619
mod encode;
1720
mod operations;
18-
mod operator;
1921
mod validity;
2022
mod visitor;
2123

@@ -37,7 +39,7 @@ impl VTable for ConstantVTable {
3739
// TODO(ngates): implement a compute kernel for elementwise operations
3840
type ComputeVTable = NotSupported;
3941
type EncodeVTable = Self;
40-
type OperatorVTable = Self;
42+
type OperatorVTable = NotSupported;
4143

4244
fn id(_encoding: &Self::Encoding) -> EncodingId {
4345
EncodingId::new_ref("vortex.constant")
@@ -74,4 +76,8 @@ impl VTable for ConstantVTable {
7476
let scalar = Scalar::new(dtype.clone(), sv);
7577
Ok(ConstantArray::new(scalar, len))
7678
}
79+
80+
fn execute(array: &Self::Array, _ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
81+
Ok(to_vector(array.scalar().clone(), array.len()).freeze())
82+
}
7783
}

vortex-array/src/arrays/decimal/vtable/mod.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_buffer::{Alignment, Buffer, ByteBuffer};
5-
use vortex_dtype::{DType, NativeDecimalType, match_each_decimal_value_type};
5+
use vortex_dtype::{DType, NativeDecimalType, PrecisionScale, match_each_decimal_value_type};
66
use vortex_error::{VortexResult, vortex_bail, vortex_ensure};
77
use vortex_scalar::DecimalType;
8+
use vortex_vector::Vector;
9+
use vortex_vector::decimal::DVector;
810

911
use crate::arrays::DecimalArray;
12+
use crate::execution::ExecutionCtx;
1013
use crate::serde::ArrayChildren;
1114
use crate::validity::Validity;
1215
use crate::vtable::{NotSupported, VTable, ValidityVTableFromValidityHelper};
@@ -104,6 +107,19 @@ impl VTable for DecimalVTable {
104107
DecimalArray::try_new::<D>(buffer, *decimal_dtype, validity)
105108
})
106109
}
110+
111+
fn execute(array: &Self::Array, _ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
112+
match_each_decimal_value_type!(array.values_type(), |D| {
113+
Ok(unsafe {
114+
DVector::<D>::new_unchecked(
115+
PrecisionScale::new_unchecked(array.precision(), array.scale()),
116+
array.buffer::<D>(),
117+
array.validity_mask(),
118+
)
119+
}
120+
.into())
121+
})
122+
}
107123
}
108124

109125
#[derive(Clone, Debug)]

vortex-array/src/arrays/expr/vtable/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@ use std::fmt::Debug;
1212
use vortex_buffer::ByteBuffer;
1313
use vortex_dtype::DType;
1414
use vortex_error::{VortexResult, vortex_bail};
15+
use vortex_vector::Vector;
1516

1617
use crate::arrays::expr::ExprArray;
18+
use crate::execution::ExecutionCtx;
1719
use crate::expr::Expression;
1820
use crate::serde::ArrayChildren;
1921
use crate::vtable::{NotSupported, VTable};
20-
use crate::{EncodingId, EncodingRef, vtable};
22+
use crate::{Array, ArrayOperator, EncodingId, EncodingRef, vtable};
2123

2224
vtable!(Expr);
2325

@@ -76,6 +78,11 @@ impl VTable for ExprVTable {
7678

7779
ExprArray::try_new(child, expr.clone(), dtype.clone())
7880
}
81+
82+
fn execute(array: &Self::Array, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
83+
let scope = array.child().execute_batch(ctx)?;
84+
array.expr().execute(&scope, array.child().dtype())
85+
}
7986
}
8087

8188
pub struct ExprArrayMetadata((Expression, DType));

vortex-array/src/arrays/extension/vtable/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ mod visitor;
1111
use vortex_buffer::ByteBuffer;
1212
use vortex_dtype::DType;
1313
use vortex_error::{VortexResult, vortex_bail};
14+
use vortex_vector::Vector;
1415

1516
use crate::arrays::extension::ExtensionArray;
17+
use crate::execution::ExecutionCtx;
1618
use crate::serde::ArrayChildren;
1719
use crate::vtable::{NotSupported, VTable, ValidityVTableFromChild};
18-
use crate::{EmptyMetadata, EncodingId, EncodingRef, vtable};
20+
use crate::{ArrayOperator, EmptyMetadata, EncodingId, EncodingRef, vtable};
1921

2022
vtable!(Extension);
2123

@@ -70,6 +72,10 @@ impl VTable for ExtensionVTable {
7072
let storage = children.get(0, ext_dtype.storage_dtype(), len)?;
7173
Ok(ExtensionArray::new(ext_dtype.clone(), storage))
7274
}
75+
76+
fn execute(array: &Self::Array, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
77+
array.storage().execute_batch(ctx)
78+
}
7379
}
7480

7581
#[derive(Clone, Debug)]

0 commit comments

Comments
 (0)