Skip to content

Commit 83c23f3

Browse files
committed
pick smallest type
Signed-off-by: Alexander Droste <[email protected]>
1 parent 7472429 commit 83c23f3

File tree

3 files changed

+103
-53
lines changed

3 files changed

+103
-53
lines changed

encodings/sparse/src/canonical.rs

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use vortex_dtype::StructFields;
4141
use vortex_dtype::match_each_decimal_value_type;
4242
use vortex_dtype::match_each_integer_ptype;
4343
use vortex_dtype::match_each_native_ptype;
44+
use vortex_dtype::match_smallest_offset_type;
4445
use vortex_error::VortexError;
4546
use vortex_error::VortexExpect;
4647
use vortex_error::vortex_panic;
@@ -124,27 +125,6 @@ fn canonicalize_sparse_lists(
124125
values_dtype: Arc<DType>,
125126
nullability: Nullability,
126127
) -> Canonical {
127-
// TODO(connor): We should move this to `vortex-dtype` so that we can use this elsewhere.
128-
macro_rules! match_smallest_offset_type {
129-
($n_elements:expr, | $offset_type:ident | $body:block) => {{
130-
let n_elements = $n_elements;
131-
if n_elements <= u8::MAX as usize {
132-
type $offset_type = u8;
133-
$body
134-
} else if n_elements <= u16::MAX as usize {
135-
type $offset_type = u16;
136-
$body
137-
} else if n_elements <= u32::MAX as usize {
138-
type $offset_type = u32;
139-
$body
140-
} else {
141-
assert!(u64::try_from(n_elements).is_ok());
142-
type $offset_type = u64;
143-
$body
144-
}
145-
}};
146-
}
147-
148128
let resolved_patches = array.resolved_patches();
149129

150130
let indices = resolved_patches.indices().to_primitive();

vortex-array/src/arrays/list/compute/take.rs

Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use num_traits::AsPrimitive;
45
use vortex_buffer::BitBufferMut;
56
use vortex_dtype::IntegerPType;
67
use vortex_dtype::Nullability;
78
use vortex_dtype::match_each_integer_ptype;
9+
use vortex_dtype::match_smallest_offset_type;
810
use vortex_error::VortexExpect;
911
use vortex_error::VortexResult;
10-
use vortex_error::vortex_panic;
1112
use vortex_mask::Mask;
1213

1314
use crate::Array;
@@ -40,21 +41,34 @@ impl TakeKernel for ListVTable {
4041

4142
match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
4243
match_each_integer_ptype!(indices.ptype(), |I| {
43-
_take::<I, O>(
44-
array,
45-
offsets.as_slice::<O>(),
46-
&indices,
47-
array.validity_mask(),
48-
indices.validity_mask(),
49-
)
44+
let offsets_slice = offsets.as_slice::<O>();
45+
let indices_slice: &[I] = indices.as_slice::<I>();
46+
let total_element_count = indices_slice
47+
.iter()
48+
.map(|idx| {
49+
let idx: usize = idx.as_();
50+
let diff: usize = (offsets_slice[idx + 1] - offsets_slice[idx]).as_();
51+
diff
52+
})
53+
.sum::<usize>();
54+
55+
match_smallest_offset_type!(total_element_count, |OutputOffsetType| {
56+
_take::<I, O, OutputOffsetType>(
57+
array,
58+
offsets_slice,
59+
&indices,
60+
array.validity_mask(),
61+
indices.validity_mask(),
62+
)
63+
})
5064
})
5165
})
5266
}
5367
}
5468

5569
register_kernel!(TakeKernelAdapter(ListVTable).lift());
5670

57-
fn _take<I: IntegerPType, O: IntegerPType>(
71+
fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
5872
array: &ListArray,
5973
offsets: &[O],
6074
indices_array: &PrimitiveArray,
@@ -64,7 +78,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6478
let indices: &[I] = indices_array.as_slice::<I>();
6579

6680
if !indices_validity_mask.all_true() || !data_validity.all_true() {
67-
return _take_nullable::<I, O>(
81+
return _take_nullable::<I, O, OutputOffsetType>(
6882
array,
6983
offsets,
7084
indices,
@@ -73,18 +87,18 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7387
);
7488
}
7589

76-
let mut new_offsets =
77-
PrimitiveBuilder::<u64>::with_capacity(Nullability::NonNullable, indices.len());
90+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
91+
Nullability::NonNullable,
92+
indices.len(),
93+
);
7894
let mut elements_to_take =
7995
PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
8096

81-
let mut current_offset = 0u64;
97+
let mut current_offset = OutputOffsetType::zero();
8298
new_offsets.append_zero();
8399

84100
for &data_idx in indices {
85-
let data_idx = data_idx
86-
.to_usize()
87-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
101+
let data_idx: usize = data_idx.as_();
88102

89103
let start = offsets[data_idx];
90104
let stop = offsets[data_idx + 1];
@@ -94,15 +108,14 @@ fn _take<I: IntegerPType, O: IntegerPType>(
94108
// We could convert start and end to usize, but that would impose a potentially
95109
// harder constraint - now we don't care if they fit into usize as long as their
96110
// difference does.
97-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
98-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
99-
});
111+
let additional: usize = (stop - start).as_();
100112

101113
elements_to_take.reserve_exact(additional);
102114
for i in 0..additional {
103115
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
104116
}
105-
current_offset += (stop - start).as_() as u64;
117+
current_offset +=
118+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
106119
new_offsets.append_value(current_offset);
107120
}
108121

@@ -122,15 +135,17 @@ fn _take<I: IntegerPType, O: IntegerPType>(
122135
.to_array())
123136
}
124137

125-
fn _take_nullable<I: IntegerPType, O: IntegerPType>(
138+
fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
126139
array: &ListArray,
127140
offsets: &[O],
128141
indices: &[I],
129142
data_validity: Mask,
130143
indices_validity: Mask,
131144
) -> VortexResult<ArrayRef> {
132-
let mut new_offsets =
133-
PrimitiveBuilder::<u64>::with_capacity(Nullability::NonNullable, indices.len());
145+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
146+
Nullability::NonNullable,
147+
indices.len(),
148+
);
134149

135150
// This will be the indices we push down to the child array to call `take` with.
136151
//
@@ -142,7 +157,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
142157
let mut elements_to_take =
143158
PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
144159

145-
let mut current_offset = 0u64;
160+
let mut current_offset = OutputOffsetType::zero();
146161
new_offsets.append_zero();
147162

148163
// Set all bits to invalid and selectively set which values are valid.
@@ -155,9 +170,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
155170
continue;
156171
}
157172

158-
let data_idx = data_idx
159-
.to_usize()
160-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
173+
let data_idx: usize = data_idx.as_();
161174

162175
if !data_validity.value(data_idx) {
163176
new_offsets.append_value(current_offset);
@@ -169,15 +182,14 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
169182
let stop = offsets[data_idx + 1];
170183

171184
// See the note it the `take` on the reasoning
172-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
173-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
174-
});
185+
let additional: usize = (stop - start).as_();
175186

176187
elements_to_take.reserve_exact(additional);
177188
for i in 0..additional {
178189
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
179190
}
180-
current_offset += (stop - start).as_() as u64;
191+
current_offset +=
192+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
181193
new_offsets.append_value(current_offset);
182194
new_validity.set(idx);
183195
}
@@ -200,7 +212,7 @@ mod test {
200212
use vortex_buffer::buffer;
201213
use vortex_dtype::DType;
202214
use vortex_dtype::Nullability;
203-
use vortex_dtype::PType::I32;
215+
use vortex_dtype::PType::{I32, U8, U16, U32};
204216
use vortex_scalar::Scalar;
205217

206218
use crate::Array;
@@ -455,4 +467,34 @@ mod test {
455467
assert!(result_view.is_invalid(1));
456468
assert!(result_view.is_valid(2));
457469
}
470+
471+
#[rstest]
472+
#[case(10, U8)]
473+
#[case(300, U16)]
474+
#[case(70000, U32)]
475+
fn test_output_offset_type_selection(
476+
#[case] element_count: usize,
477+
#[case] expected_ptype: vortex_dtype::PType,
478+
) {
479+
let elements: Vec<i32> = (0..element_count as i32).collect();
480+
let elements_array = PrimitiveArray::from_iter(elements).to_array();
481+
482+
let mut offsets = Vec::with_capacity(element_count + 1);
483+
for idx in 0..=element_count {
484+
offsets.push(idx as u64);
485+
}
486+
let offsets_array = PrimitiveArray::from_iter(offsets).to_array();
487+
488+
let list = ListArray::try_new(elements_array, offsets_array, Validity::NonNullable)
489+
.unwrap()
490+
.to_array();
491+
492+
let indices: Vec<u32> = (0..element_count as u32).collect();
493+
let result = take(&list, &PrimitiveArray::from_iter(indices).to_array()).unwrap();
494+
495+
assert_eq!(
496+
result.to_listview().offsets().dtype().as_ptype(),
497+
expected_ptype
498+
);
499+
}
458500
}

vortex-dtype/src/ptype.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,34 @@ macro_rules! match_each_native_simd_ptype {
668668
}};
669669
}
670670

671+
/// Macro to match the smallest offset type that can fit the given number of elements.
672+
///
673+
/// This macro selects u8, u16, u32, or u64 based on the size of `n_elements`:
674+
/// - u8 if n_elements <= 255
675+
/// - u16 if n_elements <= 65535
676+
/// - u32 if n_elements <= 4294967295
677+
/// - u64 otherwise
678+
#[macro_export]
679+
macro_rules! match_smallest_offset_type {
680+
($n_elements:expr, | $offset_type:ident | $body:block) => {{
681+
let n_elements = $n_elements;
682+
if n_elements <= u8::MAX as usize {
683+
type $offset_type = u8;
684+
$body
685+
} else if n_elements <= u16::MAX as usize {
686+
type $offset_type = u16;
687+
$body
688+
} else if n_elements <= u32::MAX as usize {
689+
type $offset_type = u32;
690+
$body
691+
} else {
692+
assert!(u64::try_from(n_elements).is_ok());
693+
type $offset_type = u64;
694+
$body
695+
}
696+
}};
697+
}
698+
671699
impl PType {
672700
/// Returns `true` iff this PType is an unsigned integer type
673701
#[inline]

0 commit comments

Comments
 (0)