Skip to content

Commit 1937030

Browse files
committed
Filter runend kernel
Signed-off-by: Nicholas Gates <[email protected]>
2 parents d1db0bd + ad1de1e commit 1937030

File tree

5 files changed

+138
-56
lines changed

5 files changed

+138
-56
lines changed

encodings/sparse/src/canonical.rs

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use vortex_dtype::StructFields;
4141
use vortex_dtype::match_each_decimal_value_type;
4242
use vortex_dtype::match_each_integer_ptype;
4343
use vortex_dtype::match_each_native_ptype;
44+
use vortex_dtype::match_smallest_offset_type;
4445
use vortex_error::VortexError;
4546
use vortex_error::VortexExpect;
4647
use vortex_error::vortex_panic;
@@ -124,27 +125,6 @@ fn canonicalize_sparse_lists(
124125
values_dtype: Arc<DType>,
125126
nullability: Nullability,
126127
) -> Canonical {
127-
// TODO(connor): We should move this to `vortex-dtype` so that we can use this elsewhere.
128-
macro_rules! match_smallest_offset_type {
129-
($n_elements:expr, | $offset_type:ident | $body:block) => {{
130-
let n_elements = $n_elements;
131-
if n_elements <= u8::MAX as usize {
132-
type $offset_type = u8;
133-
$body
134-
} else if n_elements <= u16::MAX as usize {
135-
type $offset_type = u16;
136-
$body
137-
} else if n_elements <= u32::MAX as usize {
138-
type $offset_type = u32;
139-
$body
140-
} else {
141-
assert!(u64::try_from(n_elements).is_ok());
142-
type $offset_type = u64;
143-
$body
144-
}
145-
}};
146-
}
147-
148128
let resolved_patches = array.resolved_patches();
149129

150130
let indices = resolved_patches.indices().to_primitive();

vortex-array/src/arrays/constant/vtable/mod.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,27 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::fmt::Debug;
5+
46
use vortex_buffer::BufferHandle;
57
use vortex_dtype::DType;
68
use vortex_error::VortexResult;
79
use vortex_error::vortex_bail;
810
use vortex_error::vortex_ensure;
11+
use vortex_mask::Mask;
912
use vortex_scalar::Scalar;
1013
use vortex_scalar::ScalarValue;
1114
use vortex_vector::ScalarOps;
15+
use vortex_vector::Vector;
1216
use vortex_vector::VectorMutOps;
1317

1418
use crate::ArrayRef;
1519
use crate::EmptyMetadata;
1620
use crate::arrays::ConstantArray;
1721
use crate::kernel::BindCtx;
22+
use crate::kernel::Kernel;
1823
use crate::kernel::KernelRef;
19-
use crate::kernel::kernel;
24+
use crate::kernel::PushDownResult;
2025
use crate::serde::ArrayChildren;
2126
use crate::vtable;
2227
use crate::vtable::ArrayId;
@@ -89,9 +94,10 @@ impl VTable for ConstantVTable {
8994
}
9095

9196
fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult<KernelRef> {
92-
let scalar = array.scalar().to_vector_scalar();
93-
let len = array.len();
94-
Ok(kernel(move || Ok(scalar.clone().repeat(len).freeze())))
97+
Ok(Box::new(ConstantKernel {
98+
value: array.scalar.to_vector_scalar(),
99+
len: array.len,
100+
}))
95101
}
96102

97103
fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
@@ -103,3 +109,23 @@ impl VTable for ConstantVTable {
103109
Ok(())
104110
}
105111
}
112+
113+
#[derive(Debug)]
114+
struct ConstantKernel {
115+
value: vortex_vector::Scalar,
116+
len: usize,
117+
}
118+
119+
impl Kernel for ConstantKernel {
120+
fn execute(self: Box<Self>) -> VortexResult<Vector> {
121+
Ok(self.value.repeat(self.len).freeze())
122+
}
123+
124+
fn push_down_filter(self: Box<Self>, selection: &Mask) -> VortexResult<PushDownResult> {
125+
vortex_ensure!(self.len == selection.len());
126+
Ok(PushDownResult::Pushed(Box::new(ConstantKernel {
127+
value: self.value,
128+
len: selection.true_count(),
129+
})))
130+
}
131+
}

vortex-array/src/arrays/list/compute/take.rs

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ use vortex_buffer::BitBufferMut;
55
use vortex_dtype::IntegerPType;
66
use vortex_dtype::Nullability;
77
use vortex_dtype::match_each_integer_ptype;
8+
use vortex_dtype::match_smallest_offset_type;
89
use vortex_error::VortexExpect;
910
use vortex_error::VortexResult;
10-
use vortex_error::vortex_panic;
1111
use vortex_mask::Mask;
1212

1313
use crate::Array;
@@ -34,27 +34,33 @@ use crate::vtable::ValidityHelper;
3434
/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
3535
/// non-contiguous indices would violate this requirement.
3636
impl TakeKernel for ListVTable {
37+
#[expect(clippy::cognitive_complexity)]
3738
fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
3839
let indices = indices.to_primitive();
3940
let offsets = array.offsets().to_primitive();
41+
// This is an over-approximation of the total number of elements in the resulting array.
42+
let total_approx = array.elements().len().saturating_mul(indices.len());
4043

4144
match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
45+
let offsets_slice = offsets.as_slice::<O>();
4246
match_each_integer_ptype!(indices.ptype(), |I| {
43-
_take::<I, O>(
44-
array,
45-
offsets.as_slice::<O>(),
46-
&indices,
47-
array.validity_mask(),
48-
indices.validity_mask(),
49-
)
47+
match_smallest_offset_type!(total_approx, |OutputOffsetType| {
48+
_take::<I, O, OutputOffsetType>(
49+
array,
50+
offsets_slice,
51+
&indices,
52+
array.validity_mask(),
53+
indices.validity_mask(),
54+
)
55+
})
5056
})
5157
})
5258
}
5359
}
5460

5561
register_kernel!(TakeKernelAdapter(ListVTable).lift());
5662

57-
fn _take<I: IntegerPType, O: IntegerPType>(
63+
fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
5864
array: &ListArray,
5965
offsets: &[O],
6066
indices_array: &PrimitiveArray,
@@ -64,7 +70,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6470
let indices: &[I] = indices_array.as_slice::<I>();
6571

6672
if !indices_validity_mask.all_true() || !data_validity.all_true() {
67-
return _take_nullable::<I, O>(
73+
return _take_nullable::<I, O, OutputOffsetType>(
6874
array,
6975
offsets,
7076
indices,
@@ -73,17 +79,18 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7379
);
7480
}
7581

76-
let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
82+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
83+
Nullability::NonNullable,
84+
indices.len(),
85+
);
7786
let mut elements_to_take =
7887
PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
7988

80-
let mut current_offset = O::zero();
89+
let mut current_offset = OutputOffsetType::zero();
8190
new_offsets.append_zero();
8291

8392
for &data_idx in indices {
84-
let data_idx = data_idx
85-
.to_usize()
86-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
93+
let data_idx: usize = data_idx.as_();
8794

8895
let start = offsets[data_idx];
8996
let stop = offsets[data_idx + 1];
@@ -93,15 +100,15 @@ fn _take<I: IntegerPType, O: IntegerPType>(
93100
// We could convert start and end to usize, but that would impose a potentially
94101
// harder constraint - now we don't care if they fit into usize as long as their
95102
// difference does.
96-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
97-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
98-
});
103+
let additional: usize = (stop - start).as_();
99104

105+
// TODO(0ax1): optimize this
100106
elements_to_take.reserve_exact(additional);
101107
for i in 0..additional {
102108
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
103109
}
104-
current_offset += stop - start;
110+
current_offset +=
111+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
105112
new_offsets.append_value(current_offset);
106113
}
107114

@@ -121,14 +128,17 @@ fn _take<I: IntegerPType, O: IntegerPType>(
121128
.to_array())
122129
}
123130

124-
fn _take_nullable<I: IntegerPType, O: IntegerPType>(
131+
fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
125132
array: &ListArray,
126133
offsets: &[O],
127134
indices: &[I],
128135
data_validity: Mask,
129136
indices_validity: Mask,
130137
) -> VortexResult<ArrayRef> {
131-
let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
138+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
139+
Nullability::NonNullable,
140+
indices.len(),
141+
);
132142

133143
// This will be the indices we push down to the child array to call `take` with.
134144
//
@@ -140,7 +150,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
140150
let mut elements_to_take =
141151
PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
142152

143-
let mut current_offset = O::zero();
153+
let mut current_offset = OutputOffsetType::zero();
144154
new_offsets.append_zero();
145155

146156
// Set all bits to invalid and selectively set which values are valid.
@@ -153,9 +163,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
153163
continue;
154164
}
155165

156-
let data_idx = data_idx
157-
.to_usize()
158-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
166+
let data_idx: usize = data_idx.as_();
159167

160168
if !data_validity.value(data_idx) {
161169
new_offsets.append_value(current_offset);
@@ -167,15 +175,14 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
167175
let stop = offsets[data_idx + 1];
168176

169177
// See the note it the `take` on the reasoning
170-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
171-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
172-
});
178+
let additional: usize = (stop - start).as_();
173179

174180
elements_to_take.reserve_exact(additional);
175181
for i in 0..additional {
176182
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
177183
}
178-
current_offset += stop - start;
184+
current_offset +=
185+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
179186
new_offsets.append_value(current_offset);
180187
new_validity.set(idx);
181188
}
@@ -411,4 +418,46 @@ mod test {
411418
fn test_take_list_conformance(#[case] list: ListArray) {
412419
test_take_conformance(list.as_ref());
413420
}
421+
422+
#[test]
423+
fn test_u64_offset_accumulation_non_nullable() {
424+
let elements = buffer![0i32; 200].into_array();
425+
let offsets = buffer![0u8, 200].into_array();
426+
let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
427+
.unwrap()
428+
.to_array();
429+
430+
// Take the same large list twice - would overflow u8 but works with u64.
431+
let idx = buffer![0u8, 0].into_array();
432+
let result = take(&list, &idx).unwrap();
433+
434+
assert_eq!(result.len(), 2);
435+
436+
let result_view = result.to_listview();
437+
assert_eq!(result_view.len(), 2);
438+
assert!(result_view.is_valid(0));
439+
assert!(result_view.is_valid(1));
440+
}
441+
442+
#[test]
443+
fn test_u64_offset_accumulation_nullable() {
444+
let elements = buffer![0i32; 150].into_array();
445+
let offsets = buffer![0u8, 150, 150].into_array();
446+
let validity = BoolArray::from_iter(vec![true, false]).to_array();
447+
let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
448+
.unwrap()
449+
.to_array();
450+
451+
// Take the same large list twice - would overflow u8 but works with u64.
452+
let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
453+
let result = take(&list, &idx).unwrap();
454+
455+
assert_eq!(result.len(), 3);
456+
457+
let result_view = result.to_listview();
458+
assert_eq!(result_view.len(), 3);
459+
assert!(result_view.is_valid(0));
460+
assert!(result_view.is_invalid(1));
461+
assert!(result_view.is_valid(2));
462+
}
414463
}

vortex-dtype/src/ptype.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,28 @@ macro_rules! match_each_native_simd_ptype {
668668
}};
669669
}
670670

671+
/// Macro to match the smallest offset type for a given value
672+
#[macro_export]
673+
macro_rules! match_smallest_offset_type {
674+
($n_elements:expr, | $offset_type:ident | $body:block) => {{
675+
let n_elements = $n_elements;
676+
if n_elements <= u8::MAX as usize {
677+
type $offset_type = u8;
678+
$body
679+
} else if n_elements <= u16::MAX as usize {
680+
type $offset_type = u16;
681+
$body
682+
} else if n_elements <= u32::MAX as usize {
683+
type $offset_type = u32;
684+
$body
685+
} else {
686+
assert!(u64::try_from(n_elements).is_ok());
687+
type $offset_type = u64;
688+
$body
689+
}
690+
}};
691+
}
692+
671693
impl PType {
672694
/// Returns `true` iff this PType is an unsigned integer type
673695
#[inline]

vortex-vector/src/binaryview/vector.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,13 @@ impl<T: BinaryViewType> VectorOps for BinaryViewVector<T> {
219219
BinaryViewScalar::<T>::new(self.get(index))
220220
}
221221

222-
fn slice(&self, _range: impl RangeBounds<usize> + Clone + Debug) -> Self {
223-
todo!()
222+
fn slice(&self, range: impl RangeBounds<usize> + Clone + Debug) -> Self {
223+
BinaryViewVector {
224+
views: self.views.slice(range.clone()),
225+
buffers: self.buffers().clone(),
226+
validity: self.validity.slice(range),
227+
_marker: self._marker,
228+
}
224229
}
225230

226231
fn clear(&mut self) {

0 commit comments

Comments
 (0)