Skip to content

Commit 956f5c2

Browse files
authored
Fix zero-copy pipeline input (#5324)
The `into_mut` call was triggering a copy of the input before this change.

---------

Signed-off-by: Nicholas Gates <[email protected]>
1 parent 3b9f233 commit 956f5c2

File tree

4 files changed

+62
-3
lines changed

4 files changed

+62
-3
lines changed

vortex-array/src/pipeline/driver/bind.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::pipeline::driver::{Node, NodeKind};
1111
use crate::pipeline::{BindContext, Kernel, VectorId};
1212

1313
pub(crate) fn bind_kernels(
14-
dag: &[Node],
14+
dag: Vec<Node>,
1515
allocation_plan: &VectorAllocation,
1616
mut all_batch_inputs: Vec<Option<Vector>>,
1717
) -> VortexResult<Vec<Box<dyn Kernel>>> {
@@ -40,6 +40,10 @@ pub(crate) fn bind_kernels(
4040
assert_eq!(node.batch_inputs.len(), 1);
4141
let batch_id = node.batch_inputs[0];
4242

43+
// Release ownership of the array before trying to call into_mut on the vector.
44+
// This is in case the vector was constructed zero-copy from the array's data.
45+
drop(node.array);
46+
4347
let batch = batch_inputs[batch_id]
4448
.take()
4549
.vortex_expect("Batch input vector has already been consumed")

vortex-array/src/pipeline/driver/input.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,13 @@ impl Kernel for InputKernel {
7373

7474
#[cfg(test)]
7575
mod test {
76-
use vortex_buffer::{bitbuffer, buffer};
76+
use vortex_buffer::{BitBuffer, bitbuffer, buffer};
7777
use vortex_dtype::PTypeDowncastExt;
7878
use vortex_mask::Mask;
7979

80+
use crate::arrays::{BoolArray, PrimitiveArray};
8081
use crate::pipeline::driver::PipelineDriver;
82+
use crate::validity::Validity;
8183
use crate::{Array, ArrayOperator, IntoArray};
8284

8385
#[test]
@@ -113,4 +115,38 @@ mod test {
113115
.downcast::<u32>();
114116
assert_eq!(vector.elements().as_ref(), &[0u32, 2, 4]);
115117
}
118+
119+
/// Ensures that we can feed an input into a pipeline with zero-copy.
120+
/// This can require careful bookkeeping to make sure we don't hold references to arrays
121+
/// around longer than necessary.
122+
#[test]
123+
fn test_pipeline_input_zero_copy() {
124+
let elements = buffer![123u32; 8000];
125+
let elements_ptr = elements.as_ptr();
126+
let validity = BitBuffer::from_iter((0..8000).map(|i| i % 2 == 0));
127+
let validity_ptr = validity.inner().as_ptr();
128+
129+
let array = PrimitiveArray::new(
130+
elements,
131+
Validity::Array(BoolArray::from(validity).into_array()),
132+
)
133+
.into_array();
134+
assert!(
135+
array.as_pipelined().is_none(),
136+
"We're explicitly testing non-pipelined arrays to trigger the input case"
137+
);
138+
139+
let selection = Mask::new_true(array.len());
140+
let vector = PipelineDriver::new(array)
141+
.execute(&selection)
142+
.unwrap()
143+
.into_primitive()
144+
.downcast::<u32>();
145+
146+
let (vector_elements, vector_validity) = vector.into_parts();
147+
let vector_validity = vector_validity.into_bit_buffer().into_inner();
148+
149+
assert_eq!(vector_elements.as_ptr(), elements_ptr);
150+
assert_eq!(vector_validity.as_ptr(), validity_ptr);
151+
}
116152
}

vortex-array/src/pipeline/driver/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ impl PipelineDriver {
209209
let allocation_plan = allocate_vectors(&self.dag, &exec_order)?;
210210

211211
// Bind each node in the DAG to create its kernel
212-
let kernels = bind_kernels(&self.dag, &allocation_plan, batch_inputs)?;
212+
let kernels = bind_kernels(self.dag, &allocation_plan, batch_inputs)?;
213213

214214
// Construct the kernel execution context
215215
let ctx = KernelCtx::new(allocation_plan.vectors);

vortex-mask/src/lib.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,19 @@ impl Mask {
441441
}
442442
}
443443

444+
/// Return a boolean buffer representation of the mask, allocating new buffers for all-true
445+
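/// and all-false variants. For the `Values` variant, the existing buffer is moved out when
/// this mask holds the only reference to it, and cloned otherwise.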
446+
#[inline]
447+
pub fn into_bit_buffer(self) -> BitBuffer {
448+
match self {
449+
Self::AllTrue(l) => BitBuffer::new_set(l),
450+
Self::AllFalse(l) => BitBuffer::new_unset(l),
451+
Self::Values(values) => Arc::try_unwrap(values)
452+
.map(|v| v.into_bit_buffer())
453+
.unwrap_or_else(|v| v.bit_buffer().clone()),
454+
}
455+
}
456+
444457
/// Return the indices representation of the mask.
445458
#[inline]
446459
pub fn indices(&self) -> AllOr<&[usize]> {
@@ -598,6 +611,12 @@ impl MaskValues {
598611
&self.buffer
599612
}
600613

614+
/// Consumes the mask values and returns the owned boolean buffer representation of the mask.
615+
#[inline]
616+
pub fn into_bit_buffer(self) -> BitBuffer {
617+
self.buffer
618+
}
619+
601620
/// Returns the boolean value at a given index.
602621
#[inline]
603622
pub fn value(&self, index: usize) -> bool {

0 commit comments

Comments
 (0)