Skip to content

Commit 33e8cf0

Browse files
authored
Batch execution API (#5242)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent bdfa4c2 commit 33e8cf0

File tree

13 files changed

+208
-77
lines changed

13 files changed

+208
-77
lines changed

encodings/sequence/src/operator.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ mod tests {
188188
.unwrap()
189189
.into_array();
190190

191-
let selection = bitbuffer![1 0 1 0 1].into_array();
191+
let selection = bitbuffer![1 0 1 0 1].into();
192192
let result = seq
193-
.execute_with_selection(Some(&selection))
193+
.execute_with_selection(&selection)
194194
.unwrap()
195195
.into_primitive()
196196
.into_i32();
@@ -208,9 +208,9 @@ mod tests {
208208
.unwrap()
209209
.into_array();
210210

211-
let selection = bitbuffer![1 1 0 0 0].into_array();
211+
let selection = bitbuffer![1 1 0 0 0].into();
212212
let result = seq
213-
.execute_with_selection(Some(&selection))
213+
.execute_with_selection(&selection)
214214
.unwrap()
215215
.into_primitive()
216216
.into_i64();
@@ -225,9 +225,9 @@ mod tests {
225225
.unwrap()
226226
.into_array();
227227

228-
let selection = bitbuffer![0 0 1 1].into_array();
228+
let selection = bitbuffer![0 0 1 1].into();
229229
let result = seq
230-
.execute_with_selection(Some(&selection))
230+
.execute_with_selection(&selection)
231231
.unwrap()
232232
.into_primitive()
233233
.into_u64();
@@ -245,8 +245,8 @@ mod tests {
245245
.unwrap()
246246
.into_array();
247247

248-
let selection = bitbuffer![0 0 0 0].into_array();
249-
let result = seq.execute_with_selection(Some(&selection)).unwrap();
248+
let selection = bitbuffer![0 0 0 0].into();
249+
let result = seq.execute_with_selection(&selection).unwrap();
250250
assert!(result.is_empty())
251251
}
252252

@@ -257,9 +257,9 @@ mod tests {
257257
.unwrap()
258258
.into_array();
259259

260-
let selection = bitbuffer![1 1 1 1].into_array();
260+
let selection = bitbuffer![1 1 1 1].into();
261261
let result = seq
262-
.execute_with_selection(Some(&selection))
262+
.execute_with_selection(&selection)
263263
.unwrap()
264264
.into_primitive()
265265
.into_i16();
@@ -277,9 +277,9 @@ mod tests {
277277
.unwrap()
278278
.into_array();
279279

280-
let selection = bitbuffer![1 0 0 1 0 1].into_array();
280+
let selection = bitbuffer![1 0 0 1 0 1].into();
281281
let result = seq
282-
.execute_with_selection(Some(&selection))
282+
.execute_with_selection(&selection)
283283
.unwrap()
284284
.into_primitive()
285285
.into_i32();

vortex-array/src/array/operator.rs

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
use std::sync::Arc;
55

6-
use vortex_dtype::DType;
7-
use vortex_error::{VortexResult, vortex_bail};
8-
use vortex_vector::Vector;
6+
use vortex_error::{VortexResult, vortex_panic};
7+
use vortex_mask::Mask;
8+
use vortex_vector::{Vector, VectorOps, vector_matches_dtype};
99

10-
use crate::execution::{BatchKernelRef, BindCtx};
10+
use crate::execution::{BatchKernelRef, BindCtx, DummyExecutionCtx, ExecutionCtx};
1111
use crate::vtable::{OperatorVTable, VTable};
1212
use crate::{Array, ArrayAdapter, ArrayRef};
1313

@@ -16,13 +16,13 @@ use crate::{Array, ArrayAdapter, ArrayRef};
1616
/// Note: the public functions such as "execute" should move onto the main `Array` trait when
1717
/// operators is stabilized. The other functions should remain on a `pub(crate)` trait.
1818
pub trait ArrayOperator: 'static + Send + Sync {
19-
/// Execute the array producing a canonical vector.
20-
fn execute(&self) -> VortexResult<Vector> {
21-
self.execute_with_selection(None)
22-
}
23-
24-
/// Execute the array with a selection mask, producing a canonical vector.
25-
fn execute_with_selection(&self, selection: Option<&ArrayRef>) -> VortexResult<Vector>;
19+
/// Execute the array's batch kernel with the given selection mask.
20+
///
21+
/// # Panics
22+
///
23+
/// If the mask length does not match the array length.
24+
/// If the array's implementation returns an invalid vector (wrong length, wrong type, etc).
25+
fn execute_batch(&self, selection: &Mask, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector>;
2626

2727
/// Optimize the array by running the optimization rules.
2828
fn reduce_children(&self) -> VortexResult<Option<ArrayRef>>;
@@ -39,8 +39,8 @@ pub trait ArrayOperator: 'static + Send + Sync {
3939
}
4040

4141
impl ArrayOperator for Arc<dyn Array> {
42-
fn execute_with_selection(&self, selection: Option<&ArrayRef>) -> VortexResult<Vector> {
43-
self.as_ref().execute_with_selection(selection)
42+
fn execute_batch(&self, selection: &Mask, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
43+
self.as_ref().execute_batch(selection, ctx)
4444
}
4545

4646
fn reduce_children(&self) -> VortexResult<Option<ArrayRef>> {
@@ -61,23 +61,31 @@ impl ArrayOperator for Arc<dyn Array> {
6161
}
6262

6363
impl<V: VTable> ArrayOperator for ArrayAdapter<V> {
64-
fn execute_with_selection(&self, selection: Option<&ArrayRef>) -> VortexResult<Vector> {
65-
if let Some(selection) = selection.as_ref() {
66-
if !matches!(selection.dtype(), DType::Bool(_)) {
67-
vortex_bail!(
68-
"Selection array must be of boolean type, got {}",
69-
selection.dtype()
70-
);
71-
}
72-
if selection.len() != self.len() {
73-
vortex_bail!(
74-
"Selection array length {} does not match array length {}",
75-
selection.len(),
76-
self.len()
64+
fn execute_batch(&self, selection: &Mask, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
65+
let vector =
66+
<V::OperatorVTable as OperatorVTable<V>>::execute_batch(&self.0, selection, ctx)?;
67+
68+
// Such a cheap check that we run it always. More expensive DType checks live in
69+
// debug_assertions.
70+
assert_eq!(
71+
vector.len(),
72+
selection.true_count(),
73+
"Batch execution returned vector of incorrect length"
74+
);
75+
76+
#[cfg(debug_assertions)]
77+
{
78+
// Checks for correct type and nullability.
79+
if !vector_matches_dtype(&vector, self.dtype()) {
80+
vortex_panic!(
81+
"Returned vector {:?} does not match expected dtype {}",
82+
vector,
83+
self.dtype()
7784
);
7885
}
7986
}
80-
self.bind(selection, &mut ())?.execute()
87+
88+
Ok(vector)
8189
}
8290

8391
fn reduce_children(&self) -> VortexResult<Option<ArrayRef>> {
@@ -107,3 +115,14 @@ impl BindCtx for () {
107115
array.bind(selection, self)
108116
}
109117
}
118+
119+
impl dyn Array + '_ {
120+
pub fn execute(&self) -> VortexResult<Vector> {
121+
self.execute_batch(&Mask::new_true(self.len()), &mut DummyExecutionCtx)
122+
}
123+
124+
pub fn execute_with_selection(&self, mask: &Mask) -> VortexResult<Vector> {
125+
assert_eq!(self.len(), mask.len());
126+
self.execute_batch(mask, &mut DummyExecutionCtx)
127+
}
128+
}

vortex-array/src/arrays/listview/vtable/operator.rs

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ impl OperatorVTable<ListViewVTable> for ListViewVTable {
4343

4444
#[cfg(test)]
4545
mod tests {
46-
use std::sync::Arc;
47-
4846
use vortex_dtype::PTypeDowncast;
4947
use vortex_mask::Mask;
5048
use vortex_vector::VectorOps;
@@ -53,7 +51,7 @@ mod tests {
5351
use crate::arrays::listview::tests::common::{
5452
create_basic_listview, create_nullable_listview, create_overlapping_listview,
5553
};
56-
use crate::arrays::{BoolArray, ListViewArray, PrimitiveArray};
54+
use crate::arrays::{ListViewArray, PrimitiveArray};
5755
use crate::validity::Validity;
5856

5957
#[test]
@@ -99,12 +97,10 @@ mod tests {
9997
let listview = ListViewArray::new(elements, offsets, sizes, Validity::AllValid);
10098

10199
// Create selection mask: [true, false, true, false, true, false].
102-
let selection = BoolArray::from_iter([true, false, true, false, true, false]).into_array();
100+
let selection = Mask::from_iter([true, false, true, false, true, false]);
103101

104102
// Execute with selection.
105-
let result = listview
106-
.execute_with_selection(Some(&Arc::new(selection)))
107-
.unwrap();
103+
let result = listview.execute_with_selection(&selection).unwrap();
108104

109105
// Verify filtered length (3 lists selected).
110106
assert_eq!(result.len(), 3);
@@ -133,12 +129,10 @@ mod tests {
133129
let listview = create_nullable_listview();
134130

135131
// Create selection mask: [true, true, false].
136-
let selection = BoolArray::from_iter([true, true, false]).into_array();
132+
let selection = Mask::from_iter([true, true, false]);
137133

138134
// Execute with selection.
139-
let result = listview
140-
.execute_with_selection(Some(&Arc::new(selection)))
141-
.unwrap();
135+
let result = listview.execute_with_selection(&selection).unwrap();
142136

143137
// Verify filtered length (2 lists selected, including the null).
144138
assert_eq!(result.len(), 2);
@@ -168,12 +162,10 @@ mod tests {
168162
let listview = create_overlapping_listview();
169163

170164
// Create selection mask: [true, false, true, true, false].
171-
let selection = BoolArray::from_iter([true, false, true, true, false]).into_array();
165+
let selection = Mask::from_iter([true, false, true, true, false]);
172166

173167
// Execute with selection.
174-
let result = listview
175-
.execute_with_selection(Some(&Arc::new(selection)))
176-
.unwrap();
168+
let result = listview.execute_with_selection(&selection).unwrap();
177169

178170
// Verify filtered length (3 lists selected).
179171
assert_eq!(result.len(), 3);

vortex-array/src/arrays/struct_/vtable/operator.rs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@ impl OperatorVTable<StructVTable> for StructVTable {
4141

4242
#[cfg(test)]
4343
mod tests {
44-
use std::sync::Arc;
45-
4644
use vortex_dtype::{FieldNames, PTypeDowncast};
45+
use vortex_mask::Mask;
4746
use vortex_vector::VectorOps;
4847

4948
use crate::IntoArray;
@@ -98,12 +97,10 @@ mod tests {
9897
.unwrap();
9998

10099
// Create a selection mask that selects indices 0, 2, 4 (alternating pattern).
101-
let selection = BoolArray::from_iter([true, false, true, false, true, false]).into_array();
100+
let selection = Mask::from_iter([true, false, true, false, true, false]);
102101

103102
// Execute with selection mask.
104-
let result = struct_array
105-
.execute_with_selection(Some(&Arc::new(selection)))
106-
.unwrap();
103+
let result = struct_array.execute_with_selection(&selection).unwrap();
107104

108105
// Verify the result has the filtered length.
109106
assert_eq!(result.len(), 3);
@@ -152,12 +149,10 @@ mod tests {
152149
.unwrap();
153150

154151
// Create a selection mask that selects indices 0, 1, 2, 4, 5.
155-
let selection = BoolArray::from_iter([true, true, true, false, true, true]).into_array();
152+
let selection = Mask::from_iter([true, true, true, false, true, true]);
156153

157154
// Execute with selection mask.
158-
let result = struct_array
159-
.execute_with_selection(Some(&Arc::new(selection)))
160-
.unwrap();
155+
let result = struct_array.execute_with_selection(&selection).unwrap();
161156

162157
assert_eq!(result.len(), 5);
163158

vortex-array/src/compute/arrays/arithmetic.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ mod tests {
347347

348348
use crate::arrays::PrimitiveArray;
349349
use crate::compute::arrays::arithmetic::{ArithmeticArray, ArithmeticOperator};
350-
use crate::{ArrayOperator, ArrayRef, IntoArray};
350+
use crate::{ArrayRef, IntoArray};
351351

352352
fn add(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef {
353353
ArithmeticArray::new(lhs, rhs, ArithmeticOperator::Add).into_array()
@@ -418,10 +418,8 @@ mod tests {
418418
let lhs = PrimitiveArray::from_iter([1u32, 2, 3]).into_array();
419419
let rhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array();
420420

421-
let selection = bitbuffer![1 0 1].into_array();
422-
423421
let result = add(lhs, rhs)
424-
.execute_with_selection(Some(&selection))
422+
.execute_with_selection(&bitbuffer![1 0 1].into())
425423
.unwrap()
426424
.into_primitive()
427425
.downcast::<u32>();

vortex-array/src/compute/arrays/get_item.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,10 @@ mod tests {
147147
use vortex_dtype::{FieldNames, Nullability, PTypeDowncast};
148148
use vortex_vector::VectorOps;
149149

150+
use crate::IntoArray;
150151
use crate::arrays::{BoolArray, PrimitiveArray, StructArray};
151152
use crate::compute::arrays::get_item::GetItemArray;
152153
use crate::validity::Validity;
153-
use crate::{ArrayOperator, IntoArray};
154154

155155
#[test]
156156
fn test_get_item_basic() {
@@ -233,9 +233,9 @@ mod tests {
233233
.into_array();
234234

235235
// Apply selection mask [1 0 1 0 1 0] => select indices 0, 2, 4
236-
let selection = bitbuffer![1 0 1 0 1 0].into_array();
236+
let selection = bitbuffer![1 0 1 0 1 0].into();
237237
let result = get_item
238-
.execute_with_selection(Some(&selection))
238+
.execute_with_selection(&selection)
239239
.unwrap()
240240
.into_primitive()
241241
.into_i32();

vortex-array/src/compute/arrays/logical.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ mod tests {
213213
use vortex_buffer::bitbuffer;
214214

215215
use crate::compute::arrays::logical::{LogicalArray, LogicalOperator};
216-
use crate::{ArrayOperator, ArrayRef, IntoArray};
216+
use crate::{ArrayRef, IntoArray};
217217

218218
fn and_(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef {
219219
LogicalArray::new(lhs, rhs, LogicalOperator::And).into_array()
@@ -232,10 +232,10 @@ mod tests {
232232
let lhs = bitbuffer![0 1 0].into_array();
233233
let rhs = bitbuffer![0 1 1].into_array();
234234

235-
let selection = bitbuffer![0 1 1].into_array();
235+
let selection = bitbuffer![0 1 1].into();
236236

237237
let result = and_(lhs, rhs)
238-
.execute_with_selection(Some(&selection))
238+
.execute_with_selection(&selection)
239239
.unwrap()
240240
.into_bool();
241241
assert_eq!(result.bits(), &bitbuffer![1 0]);

vortex-array/src/execution/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,18 @@ mod validity;
77

88
pub use batch::*;
99
pub use mask::*;
10+
11+
/// Execution context for batch array compute.
12+
// NOTE(ngates): This context will eventually hold cached resources for execution, such as CSE
13+
// nodes, and may well eventually support a type-map interface for arrays to stash arbitrary
14+
// execution-related data.
15+
pub trait ExecutionCtx: private::Sealed {}
16+
17+
/// A crate-internal dummy execution context.
18+
pub(crate) struct DummyExecutionCtx;
19+
impl ExecutionCtx for DummyExecutionCtx {}
20+
21+
mod private {
22+
pub trait Sealed {}
23+
impl Sealed for super::DummyExecutionCtx {}
24+
}

vortex-array/src/optimizer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ mod tests {
6363
use vortex_dtype::PTypeDowncast;
6464
use vortex_vector::VectorOps;
6565

66+
use crate::IntoArray;
6667
use crate::arrays::{BoolArray, MaskedArray, PrimitiveArray};
6768
use crate::validity::Validity;
68-
use crate::{ArrayOperator, IntoArray};
6969

7070
#[test]
7171
fn test_masked_pushdown() {

0 commit comments

Comments
 (0)