Skip to content

Commit f95caac

Browse files
authored
fix: list take offset type (#5679)
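When doing a `take` on a `ListArray`, the resulting offsets can become larger than what fits into the array's current offset type. E.g. taking the same list of length 200 twice with offset type `u8` overflows, because the final offset would be 400. We therefore choose the output offset type in `_take` and `_take_nullable` independently of the input offset type: the smallest of `u8`/`u16`/`u32`/`u64` that fits an upper bound on the number of resulting elements. Addresses: #5592 Signed-off-by: Alexander Droste <[email protected]>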
1 parent 062d9a1 commit f95caac

File tree

3 files changed

+101
-50
lines changed

3 files changed

+101
-50
lines changed

encodings/sparse/src/canonical.rs

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use vortex_dtype::StructFields;
4141
use vortex_dtype::match_each_decimal_value_type;
4242
use vortex_dtype::match_each_integer_ptype;
4343
use vortex_dtype::match_each_native_ptype;
44+
use vortex_dtype::match_smallest_offset_type;
4445
use vortex_error::VortexError;
4546
use vortex_error::VortexExpect;
4647
use vortex_error::vortex_panic;
@@ -124,27 +125,6 @@ fn canonicalize_sparse_lists(
124125
values_dtype: Arc<DType>,
125126
nullability: Nullability,
126127
) -> Canonical {
127-
// TODO(connor): We should move this to `vortex-dtype` so that we can use this elsewhere.
128-
macro_rules! match_smallest_offset_type {
129-
($n_elements:expr, | $offset_type:ident | $body:block) => {{
130-
let n_elements = $n_elements;
131-
if n_elements <= u8::MAX as usize {
132-
type $offset_type = u8;
133-
$body
134-
} else if n_elements <= u16::MAX as usize {
135-
type $offset_type = u16;
136-
$body
137-
} else if n_elements <= u32::MAX as usize {
138-
type $offset_type = u32;
139-
$body
140-
} else {
141-
assert!(u64::try_from(n_elements).is_ok());
142-
type $offset_type = u64;
143-
$body
144-
}
145-
}};
146-
}
147-
148128
let resolved_patches = array.resolved_patches();
149129

150130
let indices = resolved_patches.indices().to_primitive();

vortex-array/src/arrays/list/compute/take.rs

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ use vortex_buffer::BitBufferMut;
55
use vortex_dtype::IntegerPType;
66
use vortex_dtype::Nullability;
77
use vortex_dtype::match_each_integer_ptype;
8+
use vortex_dtype::match_smallest_offset_type;
89
use vortex_error::VortexExpect;
910
use vortex_error::VortexResult;
10-
use vortex_error::vortex_panic;
1111
use vortex_mask::Mask;
1212

1313
use crate::Array;
@@ -34,27 +34,33 @@ use crate::vtable::ValidityHelper;
3434
/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
3535
/// non-contiguous indices would violate this requirement.
3636
impl TakeKernel for ListVTable {
37+
#[expect(clippy::cognitive_complexity)]
3738
fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
3839
let indices = indices.to_primitive();
3940
let offsets = array.offsets().to_primitive();
41+
// This is an over-approximation of the total number of elements in the resulting array.
42+
let total_approx = array.elements().len().saturating_mul(indices.len());
4043

4144
match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
45+
let offsets_slice = offsets.as_slice::<O>();
4246
match_each_integer_ptype!(indices.ptype(), |I| {
43-
_take::<I, O>(
44-
array,
45-
offsets.as_slice::<O>(),
46-
&indices,
47-
array.validity_mask(),
48-
indices.validity_mask(),
49-
)
47+
match_smallest_offset_type!(total_approx, |OutputOffsetType| {
48+
_take::<I, O, OutputOffsetType>(
49+
array,
50+
offsets_slice,
51+
&indices,
52+
array.validity_mask(),
53+
indices.validity_mask(),
54+
)
55+
})
5056
})
5157
})
5258
}
5359
}
5460

5561
register_kernel!(TakeKernelAdapter(ListVTable).lift());
5662

57-
fn _take<I: IntegerPType, O: IntegerPType>(
63+
fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
5864
array: &ListArray,
5965
offsets: &[O],
6066
indices_array: &PrimitiveArray,
@@ -64,7 +70,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6470
let indices: &[I] = indices_array.as_slice::<I>();
6571

6672
if !indices_validity_mask.all_true() || !data_validity.all_true() {
67-
return _take_nullable::<I, O>(
73+
return _take_nullable::<I, O, OutputOffsetType>(
6874
array,
6975
offsets,
7076
indices,
@@ -73,17 +79,18 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7379
);
7480
}
7581

76-
let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
82+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
83+
Nullability::NonNullable,
84+
indices.len(),
85+
);
7786
let mut elements_to_take =
7887
PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
7988

80-
let mut current_offset = O::zero();
89+
let mut current_offset = OutputOffsetType::zero();
8190
new_offsets.append_zero();
8291

8392
for &data_idx in indices {
84-
let data_idx = data_idx
85-
.to_usize()
86-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
93+
let data_idx: usize = data_idx.as_();
8794

8895
let start = offsets[data_idx];
8996
let stop = offsets[data_idx + 1];
@@ -93,15 +100,15 @@ fn _take<I: IntegerPType, O: IntegerPType>(
93100
// We could convert start and end to usize, but that would impose a potentially
94101
// harder constraint - now we don't care if they fit into usize as long as their
95102
// difference does.
96-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
97-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
98-
});
103+
let additional: usize = (stop - start).as_();
99104

105+
// TODO(0ax1): optimize this
100106
elements_to_take.reserve_exact(additional);
101107
for i in 0..additional {
102108
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
103109
}
104-
current_offset += stop - start;
110+
current_offset +=
111+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
105112
new_offsets.append_value(current_offset);
106113
}
107114

@@ -121,14 +128,17 @@ fn _take<I: IntegerPType, O: IntegerPType>(
121128
.to_array())
122129
}
123130

124-
fn _take_nullable<I: IntegerPType, O: IntegerPType>(
131+
fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
125132
array: &ListArray,
126133
offsets: &[O],
127134
indices: &[I],
128135
data_validity: Mask,
129136
indices_validity: Mask,
130137
) -> VortexResult<ArrayRef> {
131-
let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
138+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
139+
Nullability::NonNullable,
140+
indices.len(),
141+
);
132142

133143
// This will be the indices we push down to the child array to call `take` with.
134144
//
@@ -140,7 +150,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
140150
let mut elements_to_take =
141151
PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
142152

143-
let mut current_offset = O::zero();
153+
let mut current_offset = OutputOffsetType::zero();
144154
new_offsets.append_zero();
145155

146156
// Set all bits to invalid and selectively set which values are valid.
@@ -153,9 +163,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
153163
continue;
154164
}
155165

156-
let data_idx = data_idx
157-
.to_usize()
158-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
166+
let data_idx: usize = data_idx.as_();
159167

160168
if !data_validity.value(data_idx) {
161169
new_offsets.append_value(current_offset);
@@ -167,15 +175,14 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
167175
let stop = offsets[data_idx + 1];
168176

169177
// See the note in `_take` for the reasoning.
170-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
171-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
172-
});
178+
let additional: usize = (stop - start).as_();
173179

174180
elements_to_take.reserve_exact(additional);
175181
for i in 0..additional {
176182
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
177183
}
178-
current_offset += stop - start;
184+
current_offset +=
185+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
179186
new_offsets.append_value(current_offset);
180187
new_validity.set(idx);
181188
}
@@ -411,4 +418,46 @@ mod test {
411418
fn test_take_list_conformance(#[case] list: ListArray) {
412419
test_take_conformance(list.as_ref());
413420
}
421+
422+
#[test]
423+
fn test_u64_offset_accumulation_non_nullable() {
424+
let elements = buffer![0i32; 200].into_array();
425+
let offsets = buffer![0u8, 200].into_array();
426+
let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
427+
.unwrap()
428+
.to_array();
429+
430+
// Take the same large list twice - the accumulated offset (400) would overflow u8, so the output needs a wider offset type.
431+
let idx = buffer![0u8, 0].into_array();
432+
let result = take(&list, &idx).unwrap();
433+
434+
assert_eq!(result.len(), 2);
435+
436+
let result_view = result.to_listview();
437+
assert_eq!(result_view.len(), 2);
438+
assert!(result_view.is_valid(0));
439+
assert!(result_view.is_valid(1));
440+
}
441+
442+
#[test]
443+
fn test_u64_offset_accumulation_nullable() {
444+
let elements = buffer![0i32; 150].into_array();
445+
let offsets = buffer![0u8, 150, 150].into_array();
446+
let validity = BoolArray::from_iter(vec![true, false]).to_array();
447+
let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
448+
.unwrap()
449+
.to_array();
450+
451+
// Take the same large list twice - the accumulated offset (300) would overflow u8, so the output needs a wider offset type.
452+
let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
453+
let result = take(&list, &idx).unwrap();
454+
455+
assert_eq!(result.len(), 3);
456+
457+
let result_view = result.to_listview();
458+
assert_eq!(result_view.len(), 3);
459+
assert!(result_view.is_valid(0));
460+
assert!(result_view.is_invalid(1));
461+
assert!(result_view.is_valid(2));
462+
}
414463
}

vortex-dtype/src/ptype.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,28 @@ macro_rules! match_each_native_simd_ptype {
668668
}};
669669
}
670670

671+
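/// Macro that binds `$offset_type` to the smallest of `u8`, `u16`, `u32`, or `u64` whose maximum value is at least `$n_elements` (a `usize`), and then evaluates `$body` with that type alias in scope.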
672+
#[macro_export]
673+
macro_rules! match_smallest_offset_type {
674+
($n_elements:expr, | $offset_type:ident | $body:block) => {{
675+
let n_elements = $n_elements;
676+
if n_elements <= u8::MAX as usize {
677+
type $offset_type = u8;
678+
$body
679+
} else if n_elements <= u16::MAX as usize {
680+
type $offset_type = u16;
681+
$body
682+
} else if n_elements <= u32::MAX as usize {
683+
type $offset_type = u32;
684+
$body
685+
} else {
686+
assert!(u64::try_from(n_elements).is_ok());
687+
type $offset_type = u64;
688+
$body
689+
}
690+
}};
691+
}
692+
671693
impl PType {
672694
/// Returns `true` iff this PType is an unsigned integer type
673695
#[inline]

0 commit comments

Comments
 (0)