Skip to content

Commit f77fa63

Browse files
committed
wip
Signed-off-by: Alexander Droste <[email protected]>
1 parent de46787 commit f77fa63

File tree

1 file changed

+52
-19
lines changed
  • vortex-array/src/arrays/list/compute

1 file changed

+52
-19
lines changed

vortex-array/src/arrays/list/compute/take.rs

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,17 @@ impl TakeKernel for ListVTable {
4343
match_each_integer_ptype!(indices.ptype(), |I| {
4444
let offsets_slice = offsets.as_slice::<O>();
4545
let indices_slice: &[I] = indices.as_slice::<I>();
46-
let approx_total_count = indices_slice
46+
let total_element_count = indices_slice
4747
.iter()
4848
.map(|idx| {
4949
let idx: usize = idx.as_();
50-
let length: usize = (offsets_slice[idx + 1] - offsets_slice[idx]).as_();
51-
length
50+
let diff: usize = (offsets_slice[idx + 1] - offsets_slice[idx]).as_();
51+
diff
5252
})
53-
.max()
54-
.unwrap_or(0);
53+
.sum::<usize>();
5554

56-
match_smallest_offset_type!(approx_total_count, |AccumType| {
57-
_take::<I, O, AccumType>(
55+
match_smallest_offset_type!(total_element_count, |OutputOffsetType| {
56+
_take::<I, O, OutputOffsetType>(
5857
array,
5958
offsets_slice,
6059
&indices,
@@ -69,7 +68,7 @@ impl TakeKernel for ListVTable {
6968

7069
register_kernel!(TakeKernelAdapter(ListVTable).lift());
7170

72-
fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
71+
fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
7372
array: &ListArray,
7473
offsets: &[O],
7574
indices_array: &PrimitiveArray,
@@ -79,7 +78,7 @@ fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
7978
let indices: &[I] = indices_array.as_slice::<I>();
8079

8180
if !indices_validity_mask.all_true() || !data_validity.all_true() {
82-
return _take_nullable::<I, O, AccumType>(
81+
return _take_nullable::<I, O, OutputOffsetType>(
8382
array,
8483
offsets,
8584
indices,
@@ -88,12 +87,14 @@ fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
8887
);
8988
}
9089

91-
let mut new_offsets =
92-
PrimitiveBuilder::<AccumType>::with_capacity(Nullability::NonNullable, indices.len());
90+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
91+
Nullability::NonNullable,
92+
indices.len(),
93+
);
9394
let mut elements_to_take =
9495
PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
9596

96-
let mut current_offset = AccumType::zero();
97+
let mut current_offset = OutputOffsetType::zero();
9798
new_offsets.append_zero();
9899

99100
for &data_idx in indices {
@@ -114,7 +115,7 @@ fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
114115
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
115116
}
116117
current_offset +=
117-
AccumType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
118+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
118119
new_offsets.append_value(current_offset);
119120
}
120121

@@ -134,15 +135,17 @@ fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
134135
.to_array())
135136
}
136137

137-
fn _take_nullable<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
138+
fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
138139
array: &ListArray,
139140
offsets: &[O],
140141
indices: &[I],
141142
data_validity: Mask,
142143
indices_validity: Mask,
143144
) -> VortexResult<ArrayRef> {
144-
let mut new_offsets =
145-
PrimitiveBuilder::<AccumType>::with_capacity(Nullability::NonNullable, indices.len());
145+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
146+
Nullability::NonNullable,
147+
indices.len(),
148+
);
146149

147150
// This will be the indices we push down to the child array to call `take` with.
148151
//
@@ -154,7 +157,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
154157
let mut elements_to_take =
155158
PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
156159

157-
let mut current_offset = AccumType::zero();
160+
let mut current_offset = OutputOffsetType::zero();
158161
new_offsets.append_zero();
159162

160163
// Set all bits to invalid and selectively set which values are valid.
@@ -186,7 +189,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
186189
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
187190
}
188191
current_offset +=
189-
AccumType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
192+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
190193
new_offsets.append_value(current_offset);
191194
new_validity.set(idx);
192195
}
@@ -209,7 +212,7 @@ mod test {
209212
use vortex_buffer::buffer;
210213
use vortex_dtype::DType;
211214
use vortex_dtype::Nullability;
212-
use vortex_dtype::PType::I32;
215+
use vortex_dtype::PType::{I32, U8, U16, U32};
213216
use vortex_scalar::Scalar;
214217

215218
use crate::Array;
@@ -464,4 +467,34 @@ mod test {
464467
assert!(result_view.is_invalid(1));
465468
assert!(result_view.is_valid(2));
466469
}
470+
471+
#[rstest]
472+
#[case(10, U8)]
473+
#[case(300, U16)]
474+
#[case(70000, U32)]
475+
fn test_output_offset_type_selection(
476+
#[case] element_count: usize,
477+
#[case] expected_ptype: vortex_dtype::PType,
478+
) {
479+
let elements: Vec<i32> = (0..element_count as i32).collect();
480+
let elements_array = PrimitiveArray::from_iter(elements).to_array();
481+
482+
let mut offsets = Vec::with_capacity(element_count + 1);
483+
for idx in 0..element_count {
484+
offsets.push(idx as u64);
485+
}
486+
let offsets_array = PrimitiveArray::from_iter(offsets).to_array();
487+
488+
let list = ListArray::try_new(elements_array, offsets_array, Validity::NonNullable)
489+
.unwrap()
490+
.to_array();
491+
492+
let indices: Vec<u32> = (0..element_count as u32).collect();
493+
let result = take(&list, &PrimitiveArray::from_iter(indices).to_array()).unwrap();
494+
495+
assert_eq!(
496+
result.to_listview().offsets().dtype().as_ptype(),
497+
expected_ptype
498+
);
499+
}
467500
}

0 commit comments

Comments
 (0)