Skip to content

Commit bc019fb

Browse files
committed
pick smallest type
Signed-off-by: Alexander Droste <[email protected]>
1 parent e6c12db commit bc019fb

File tree

4 files changed

+92
-53
lines changed

4 files changed

+92
-53
lines changed

encodings/sparse/src/canonical.rs

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use vortex_dtype::StructFields;
4141
use vortex_dtype::match_each_decimal_value_type;
4242
use vortex_dtype::match_each_integer_ptype;
4343
use vortex_dtype::match_each_native_ptype;
44+
use vortex_dtype::match_smallest_offset_type;
4445
use vortex_error::VortexError;
4546
use vortex_error::VortexExpect;
4647
use vortex_error::vortex_panic;
@@ -124,27 +125,6 @@ fn canonicalize_sparse_lists(
124125
values_dtype: Arc<DType>,
125126
nullability: Nullability,
126127
) -> Canonical {
127-
// TODO(connor): We should move this to `vortex-dtype` so that we can use this elsewhere.
128-
macro_rules! match_smallest_offset_type {
129-
($n_elements:expr, | $offset_type:ident | $body:block) => {{
130-
let n_elements = $n_elements;
131-
if n_elements <= u8::MAX as usize {
132-
type $offset_type = u8;
133-
$body
134-
} else if n_elements <= u16::MAX as usize {
135-
type $offset_type = u16;
136-
$body
137-
} else if n_elements <= u32::MAX as usize {
138-
type $offset_type = u32;
139-
$body
140-
} else {
141-
assert!(u64::try_from(n_elements).is_ok());
142-
type $offset_type = u64;
143-
$body
144-
}
145-
}};
146-
}
147-
148128
let resolved_patches = array.resolved_patches();
149129

150130
let indices = resolved_patches.indices().to_primitive();

vortex-array/src/arrays/list/compute/take.rs

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ use vortex_buffer::BitBufferMut;
55
use vortex_dtype::IntegerPType;
66
use vortex_dtype::Nullability;
77
use vortex_dtype::match_each_integer_ptype;
8+
use vortex_dtype::match_smallest_offset_type;
89
use vortex_error::VortexExpect;
910
use vortex_error::VortexResult;
10-
use vortex_error::vortex_panic;
1111
use vortex_mask::Mask;
1212

1313
use crate::Array;
@@ -34,27 +34,33 @@ use crate::vtable::ValidityHelper;
3434
/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
3535
/// non-contiguous indices would violate this requirement.
3636
impl TakeKernel for ListVTable {
37+
#[expect(clippy::cognitive_complexity)]
3738
fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
3839
let indices = indices.to_primitive();
3940
let offsets = array.offsets().to_primitive();
41+
// This is an over-approximation of the total number of elements in the resulting array.
42+
let total_approx = array.elements().len() * indices.len();
4043

4144
match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
45+
let offsets_slice = offsets.as_slice::<O>();
4246
match_each_integer_ptype!(indices.ptype(), |I| {
43-
_take::<I, O>(
44-
array,
45-
offsets.as_slice::<O>(),
46-
&indices,
47-
array.validity_mask(),
48-
indices.validity_mask(),
49-
)
47+
match_smallest_offset_type!(total_approx, |OutputOffsetType| {
48+
_take::<I, O, OutputOffsetType>(
49+
array,
50+
offsets_slice,
51+
&indices,
52+
array.validity_mask(),
53+
indices.validity_mask(),
54+
)
55+
})
5056
})
5157
})
5258
}
5359
}
5460

5561
register_kernel!(TakeKernelAdapter(ListVTable).lift());
5662

57-
fn _take<I: IntegerPType, O: IntegerPType>(
63+
fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
5864
array: &ListArray,
5965
offsets: &[O],
6066
indices_array: &PrimitiveArray,
@@ -64,7 +70,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6470
let indices: &[I] = indices_array.as_slice::<I>();
6571

6672
if !indices_validity_mask.all_true() || !data_validity.all_true() {
67-
return _take_nullable::<I, O>(
73+
return _take_nullable::<I, O, OutputOffsetType>(
6874
array,
6975
offsets,
7076
indices,
@@ -73,18 +79,18 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7379
);
7480
}
7581

76-
let mut new_offsets =
77-
PrimitiveBuilder::<u64>::with_capacity(Nullability::NonNullable, indices.len());
82+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
83+
Nullability::NonNullable,
84+
indices.len(),
85+
);
7886
let mut elements_to_take =
7987
PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
8088

81-
let mut current_offset = 0u64;
89+
let mut current_offset = OutputOffsetType::zero();
8290
new_offsets.append_zero();
8391

8492
for &data_idx in indices {
85-
let data_idx = data_idx
86-
.to_usize()
87-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
93+
let data_idx: usize = data_idx.as_();
8894

8995
let start = offsets[data_idx];
9096
let stop = offsets[data_idx + 1];
@@ -94,15 +100,14 @@ fn _take<I: IntegerPType, O: IntegerPType>(
94100
// We could convert start and end to usize, but that would impose a potentially
95101
// harder constraint - now we don't care if they fit into usize as long as their
96102
// difference does.
97-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
98-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
99-
});
103+
let additional: usize = (stop - start).as_();
100104

101105
elements_to_take.reserve_exact(additional);
102106
for i in 0..additional {
103107
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
104108
}
105-
current_offset += (stop - start).as_() as u64;
109+
current_offset +=
110+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
106111
new_offsets.append_value(current_offset);
107112
}
108113

@@ -122,15 +127,17 @@ fn _take<I: IntegerPType, O: IntegerPType>(
122127
.to_array())
123128
}
124129

125-
fn _take_nullable<I: IntegerPType, O: IntegerPType>(
130+
fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
126131
array: &ListArray,
127132
offsets: &[O],
128133
indices: &[I],
129134
data_validity: Mask,
130135
indices_validity: Mask,
131136
) -> VortexResult<ArrayRef> {
132-
let mut new_offsets =
133-
PrimitiveBuilder::<u64>::with_capacity(Nullability::NonNullable, indices.len());
137+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
138+
Nullability::NonNullable,
139+
indices.len(),
140+
);
134141

135142
// This will be the indices we push down to the child array to call `take` with.
136143
//
@@ -142,7 +149,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
142149
let mut elements_to_take =
143150
PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
144151

145-
let mut current_offset = 0u64;
152+
let mut current_offset = OutputOffsetType::zero();
146153
new_offsets.append_zero();
147154

148155
// Set all bits to invalid and selectively set which values are valid.
@@ -155,9 +162,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
155162
continue;
156163
}
157164

158-
let data_idx = data_idx
159-
.to_usize()
160-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
165+
let data_idx: usize = data_idx.as_();
161166

162167
if !data_validity.value(data_idx) {
163168
new_offsets.append_value(current_offset);
@@ -169,15 +174,14 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
169174
let stop = offsets[data_idx + 1];
170175

171176
// See the note in `_take` for the reasoning.
172-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
173-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
174-
});
177+
let additional: usize = (stop - start).as_();
175178

176179
elements_to_take.reserve_exact(additional);
177180
for i in 0..additional {
178181
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
179182
}
180-
current_offset += (stop - start).as_() as u64;
183+
current_offset +=
184+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
181185
new_offsets.append_value(current_offset);
182186
new_validity.set(idx);
183187
}
@@ -201,6 +205,9 @@ mod test {
201205
use vortex_dtype::DType;
202206
use vortex_dtype::Nullability;
203207
use vortex_dtype::PType::I32;
208+
use vortex_dtype::PType::U8;
209+
use vortex_dtype::PType::U16;
210+
use vortex_dtype::PType::U32;
204211
use vortex_scalar::Scalar;
205212

206213
use crate::Array;
@@ -455,4 +462,34 @@ mod test {
455462
assert!(result_view.is_invalid(1));
456463
assert!(result_view.is_valid(2));
457464
}
465+
466+
#[rstest]
467+
#[case(10, U8)]
468+
#[case(300, U16)]
469+
#[case(70000, U32)]
470+
fn test_output_offset_type_selection(
471+
#[case] element_count: u32,
472+
#[case] expected_ptype: vortex_dtype::PType,
473+
) {
474+
let elements: Vec<_> = (0..element_count).collect();
475+
let elements_array = PrimitiveArray::from_iter(elements).to_array();
476+
477+
let mut offsets = Vec::with_capacity((element_count + 1) as usize);
478+
for idx in 0..=element_count {
479+
offsets.push(idx as u64);
480+
}
481+
let offsets_array = PrimitiveArray::from_iter(offsets).to_array();
482+
483+
let list = ListArray::try_new(elements_array, offsets_array, Validity::NonNullable)
484+
.unwrap()
485+
.to_array();
486+
487+
let indices: Vec<u32> = (0..element_count).collect();
488+
let result = take(&list, &PrimitiveArray::from_iter(indices).to_array()).unwrap();
489+
490+
assert_eq!(
491+
result.to_listview().offsets().dtype().as_ptype(),
492+
expected_ptype
493+
);
494+
}
458495
}

vortex-dtype/src/ptype.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,28 @@ macro_rules! match_each_native_simd_ptype {
668668
}};
669669
}
670670

671+
/// Macro that binds `$offset_type` to the smallest of `u8`, `u16`, `u32`, or `u64` that can hold `$n_elements`, then evaluates `$body`.
672+
#[macro_export]
673+
macro_rules! match_smallest_offset_type {
674+
($n_elements:expr, | $offset_type:ident | $body:block) => {{
675+
let n_elements = $n_elements;
676+
if n_elements <= u8::MAX as usize {
677+
type $offset_type = u8;
678+
$body
679+
} else if n_elements <= u16::MAX as usize {
680+
type $offset_type = u16;
681+
$body
682+
} else if n_elements <= u32::MAX as usize {
683+
type $offset_type = u32;
684+
$body
685+
} else {
686+
assert!(u64::try_from(n_elements).is_ok());
687+
type $offset_type = u64;
688+
$body
689+
}
690+
}};
691+
}
692+
671693
impl PType {
672694
/// Returns `true` iff this PType is an unsigned integer type
673695
#[inline]

vortex-io/src/file/driver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ impl State {
125125
}
126126
}
127127

128-
#[allow(clippy::cognitive_complexity)]
128+
#[expect(clippy::cognitive_complexity)]
129129
fn on_event(&mut self, event: ReadEvent) {
130130
tracing::debug!(?event, "Received ReadEvent");
131131
match event {

0 commit comments

Comments
 (0)