Skip to content

Commit 56e6f3f

Browse files
committed
wip
Signed-off-by: Alexander Droste <[email protected]>
1 parent 7c3a1e7 commit 56e6f3f

File tree

1 file changed

+52
-27
lines changed
  • vortex-array/src/arrays/list/compute

1 file changed

+52
-27
lines changed

vortex-array/src/arrays/list/compute/take.rs

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use num_traits::ToPrimitive;
45
use vortex_buffer::BitBufferMut;
56
use vortex_dtype::IntegerPType;
67
use vortex_dtype::Nullability;
@@ -40,21 +41,44 @@ impl TakeKernel for ListVTable {
4041

4142
match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
4243
match_each_integer_ptype!(indices.ptype(), |I| {
43-
_take::<I, O>(
44-
array,
45-
offsets.as_slice::<O>(),
46-
&indices,
47-
array.validity_mask(),
48-
indices.validity_mask(),
49-
)
44+
let offsets_slice = offsets.as_slice::<O>();
45+
let indices_slice: &[I] = indices.as_slice::<I>();
46+
47+
// Calculate total count to determine appropriate accumulation type
48+
let total_count = indices_slice
49+
.iter()
50+
.map(|idx| {
51+
let idx = idx.to_usize().unwrap_or_else(|| {
52+
vortex_panic!("Failed to convert index to usize: {}", idx)
53+
});
54+
(offsets_slice[idx + 1] - offsets_slice[idx])
55+
.to_usize()
56+
.unwrap_or_else(|| {
57+
vortex_panic!(
58+
"Failed to convert offset difference to usize: {}",
59+
offsets_slice[idx + 1] - offsets_slice[idx]
60+
)
61+
})
62+
})
63+
.sum::<usize>();
64+
65+
match_smallest_offset_type!(total_count, |AccumType| {
66+
_take::<I, O, AccumType>(
67+
array,
68+
offsets_slice,
69+
&indices,
70+
array.validity_mask(),
71+
indices.validity_mask(),
72+
)
73+
})
5074
})
5175
})
5276
}
5377
}
5478

5579
register_kernel!(TakeKernelAdapter(ListVTable).lift());
5680

57-
fn _take<I: IntegerPType, O: IntegerPType>(
81+
fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
5882
array: &ListArray,
5983
offsets: &[O],
6084
indices_array: &PrimitiveArray,
@@ -64,7 +88,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6488
let indices: &[I] = indices_array.as_slice::<I>();
6589

6690
if !indices_validity_mask.all_true() || !data_validity.all_true() {
67-
return _take_nullable::<I, O>(
91+
return _take_nullable::<I, O, AccumType>(
6892
array,
6993
offsets,
7094
indices,
@@ -74,24 +98,13 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7498
}
7599

76100
let mut new_offsets =
77-
PrimitiveBuilder::<u64>::with_capacity(Nullability::NonNullable, indices.len());
101+
PrimitiveBuilder::<AccumType>::with_capacity(Nullability::NonNullable, indices.len());
78102
let mut elements_to_take =
79103
PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
80104

81-
let mut current_offset = 0u64;
105+
let mut current_offset = AccumType::zero();
82106
new_offsets.append_zero();
83107

84-
// Total element count.
85-
let total_count = indices
86-
.iter()
87-
.map(|idx| {
88-
let idx = idx
89-
.to_usize()
90-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
91-
(offsets[idx + 1] - offsets[idx]).as_() as usize
92-
})
93-
.sum::<usize>();
94-
95108
for &data_idx in indices {
96109
let data_idx = data_idx
97110
.to_usize()
@@ -113,7 +126,13 @@ fn _take<I: IntegerPType, O: IntegerPType>(
113126
for i in 0..additional {
114127
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
115128
}
116-
current_offset += (stop - start).as_() as u64;
129+
current_offset += AccumType::from_usize((stop - start).to_usize().unwrap_or_else(|| {
130+
vortex_panic!(
131+
"Failed to convert offset difference to usize: {}",
132+
stop - start
133+
)
134+
}))
135+
.vortex_expect("offset conversion");
117136
new_offsets.append_value(current_offset);
118137
}
119138

@@ -133,15 +152,15 @@ fn _take<I: IntegerPType, O: IntegerPType>(
133152
.to_array())
134153
}
135154

136-
fn _take_nullable<I: IntegerPType, O: IntegerPType>(
155+
fn _take_nullable<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
137156
array: &ListArray,
138157
offsets: &[O],
139158
indices: &[I],
140159
data_validity: Mask,
141160
indices_validity: Mask,
142161
) -> VortexResult<ArrayRef> {
143162
let mut new_offsets =
144-
PrimitiveBuilder::<u64>::with_capacity(Nullability::NonNullable, indices.len());
163+
PrimitiveBuilder::<AccumType>::with_capacity(Nullability::NonNullable, indices.len());
145164

146165
// This will be the indices we push down to the child array to call `take` with.
147166
//
@@ -153,7 +172,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
153172
let mut elements_to_take =
154173
PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
155174

156-
let mut current_offset = 0u64;
175+
let mut current_offset = AccumType::zero();
157176
new_offsets.append_zero();
158177

159178
// Set all bits to invalid and selectively set which values are valid.
@@ -188,7 +207,13 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
188207
for i in 0..additional {
189208
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
190209
}
191-
current_offset += (stop - start).as_() as u64;
210+
current_offset += AccumType::from_usize((stop - start).to_usize().unwrap_or_else(|| {
211+
vortex_panic!(
212+
"Failed to convert offset difference to usize: {}",
213+
stop - start
214+
)
215+
}))
216+
.vortex_expect("offset conversion");
192217
new_offsets.append_value(current_offset);
193218
new_validity.set(idx);
194219
}

0 commit comments

Comments
 (0)