Skip to content

Commit a798757

Browse files
committed
fix: list take offset type
Signed-off-by: Alexander Droste <[email protected]>
1 parent 5d85587 commit a798757

File tree

3 files changed

+101
-50
lines changed

3 files changed

+101
-50
lines changed

encodings/sparse/src/canonical.rs

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use vortex_dtype::StructFields;
4141
use vortex_dtype::match_each_decimal_value_type;
4242
use vortex_dtype::match_each_integer_ptype;
4343
use vortex_dtype::match_each_native_ptype;
44+
use vortex_dtype::match_smallest_offset_type;
4445
use vortex_error::VortexError;
4546
use vortex_error::VortexExpect;
4647
use vortex_error::vortex_panic;
@@ -124,27 +125,6 @@ fn canonicalize_sparse_lists(
124125
values_dtype: Arc<DType>,
125126
nullability: Nullability,
126127
) -> Canonical {
127-
// TODO(connor): We should move this to `vortex-dtype` so that we can use this elsewhere.
128-
macro_rules! match_smallest_offset_type {
129-
($n_elements:expr, | $offset_type:ident | $body:block) => {{
130-
let n_elements = $n_elements;
131-
if n_elements <= u8::MAX as usize {
132-
type $offset_type = u8;
133-
$body
134-
} else if n_elements <= u16::MAX as usize {
135-
type $offset_type = u16;
136-
$body
137-
} else if n_elements <= u32::MAX as usize {
138-
type $offset_type = u32;
139-
$body
140-
} else {
141-
assert!(u64::try_from(n_elements).is_ok());
142-
type $offset_type = u64;
143-
$body
144-
}
145-
}};
146-
}
147-
148128
let resolved_patches = array.resolved_patches();
149129

150130
let indices = resolved_patches.indices().to_primitive();

vortex-array/src/arrays/list/compute/take.rs

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ use vortex_buffer::BitBufferMut;
55
use vortex_dtype::IntegerPType;
66
use vortex_dtype::Nullability;
77
use vortex_dtype::match_each_integer_ptype;
8+
use vortex_dtype::match_smallest_offset_type;
89
use vortex_error::VortexExpect;
910
use vortex_error::VortexResult;
10-
use vortex_error::vortex_panic;
1111
use vortex_mask::Mask;
1212

1313
use crate::Array;
@@ -34,27 +34,33 @@ use crate::vtable::ValidityHelper;
3434
/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
3535
/// non-contiguous indices would violate this requirement.
3636
impl TakeKernel for ListVTable {
37+
#[expect(clippy::cognitive_complexity)]
3738
fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
3839
let indices = indices.to_primitive();
3940
let offsets = array.offsets().to_primitive();
41+
// This is an over-approximation of the total number of elements in the resulting array.
42+
let total_approx = array.elements().len().saturating_mul(indices.len());
4043

4144
match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
45+
let offsets_slice = offsets.as_slice::<O>();
4246
match_each_integer_ptype!(indices.ptype(), |I| {
43-
_take::<I, O>(
44-
array,
45-
offsets.as_slice::<O>(),
46-
&indices,
47-
array.validity_mask(),
48-
indices.validity_mask(),
49-
)
47+
match_smallest_offset_type!(total_approx, |OutputOffsetType| {
48+
_take::<I, O, OutputOffsetType>(
49+
array,
50+
offsets_slice,
51+
&indices,
52+
array.validity_mask(),
53+
indices.validity_mask(),
54+
)
55+
})
5056
})
5157
})
5258
}
5359
}
5460

5561
register_kernel!(TakeKernelAdapter(ListVTable).lift());
5662

57-
fn _take<I: IntegerPType, O: IntegerPType>(
63+
fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
5864
array: &ListArray,
5965
offsets: &[O],
6066
indices_array: &PrimitiveArray,
@@ -64,7 +70,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6470
let indices: &[I] = indices_array.as_slice::<I>();
6571

6672
if !indices_validity_mask.all_true() || !data_validity.all_true() {
67-
return _take_nullable::<I, O>(
73+
return _take_nullable::<I, O, OutputOffsetType>(
6874
array,
6975
offsets,
7076
indices,
@@ -73,17 +79,18 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7379
);
7480
}
7581

76-
let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
82+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
83+
Nullability::NonNullable,
84+
indices.len(),
85+
);
7786
let mut elements_to_take =
7887
PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
7988

80-
let mut current_offset = O::zero();
89+
let mut current_offset = OutputOffsetType::zero();
8190
new_offsets.append_zero();
8291

8392
for &data_idx in indices {
84-
let data_idx = data_idx
85-
.to_usize()
86-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
93+
let data_idx: usize = data_idx.as_();
8794

8895
let start = offsets[data_idx];
8996
let stop = offsets[data_idx + 1];
@@ -93,15 +100,15 @@ fn _take<I: IntegerPType, O: IntegerPType>(
93100
// We could convert start and end to usize, but that would impose a potentially
94101
// harder constraint - now we don't care if they fit into usize as long as their
95102
// difference does.
96-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
97-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
98-
});
103+
let additional: usize = (stop - start).as_();
99104

105+
// TODO(0ax1): optimize this
100106
elements_to_take.reserve_exact(additional);
101107
for i in 0..additional {
102108
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
103109
}
104-
current_offset += stop - start;
110+
current_offset +=
111+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
105112
new_offsets.append_value(current_offset);
106113
}
107114

@@ -121,14 +128,17 @@ fn _take<I: IntegerPType, O: IntegerPType>(
121128
.to_array())
122129
}
123130

124-
fn _take_nullable<I: IntegerPType, O: IntegerPType>(
131+
fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
125132
array: &ListArray,
126133
offsets: &[O],
127134
indices: &[I],
128135
data_validity: Mask,
129136
indices_validity: Mask,
130137
) -> VortexResult<ArrayRef> {
131-
let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
138+
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
139+
Nullability::NonNullable,
140+
indices.len(),
141+
);
132142

133143
// This will be the indices we push down to the child array to call `take` with.
134144
//
@@ -140,7 +150,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
140150
let mut elements_to_take =
141151
PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
142152

143-
let mut current_offset = O::zero();
153+
let mut current_offset = OutputOffsetType::zero();
144154
new_offsets.append_zero();
145155

146156
// Set all bits to invalid and selectively set which values are valid.
@@ -153,9 +163,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
153163
continue;
154164
}
155165

156-
let data_idx = data_idx
157-
.to_usize()
158-
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
166+
let data_idx: usize = data_idx.as_();
159167

160168
if !data_validity.value(data_idx) {
161169
new_offsets.append_value(current_offset);
@@ -167,15 +175,14 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
167175
let stop = offsets[data_idx + 1];
168176

169177
// See the note it the `take` on the reasoning
170-
let additional = (stop - start).to_usize().unwrap_or_else(|| {
171-
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
172-
});
178+
let additional: usize = (stop - start).as_();
173179

174180
elements_to_take.reserve_exact(additional);
175181
for i in 0..additional {
176182
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
177183
}
178-
current_offset += stop - start;
184+
current_offset +=
185+
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
179186
new_offsets.append_value(current_offset);
180187
new_validity.set(idx);
181188
}
@@ -411,4 +418,46 @@ mod test {
411418
fn test_take_list_conformance(#[case] list: ListArray) {
412419
test_take_conformance(list.as_ref());
413420
}
421+
422+
#[test]
423+
fn test_u64_offset_accumulation_non_nullable() {
424+
let elements = buffer![0i32; 200].into_array();
425+
let offsets = buffer![0u8, 200].into_array();
426+
let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
427+
.unwrap()
428+
.to_array();
429+
430+
// Take the same large list twice - would overflow u8 but works with u64.
431+
let idx = buffer![0u8, 0].into_array();
432+
let result = take(&list, &idx).unwrap();
433+
434+
assert_eq!(result.len(), 2);
435+
436+
let result_view = result.to_listview();
437+
assert_eq!(result_view.len(), 2);
438+
assert!(result_view.is_valid(0));
439+
assert!(result_view.is_valid(1));
440+
}
441+
442+
#[test]
443+
fn test_u64_offset_accumulation_nullable() {
444+
let elements = buffer![0i32; 150].into_array();
445+
let offsets = buffer![0u8, 150, 150].into_array();
446+
let validity = BoolArray::from_iter(vec![true, false]).to_array();
447+
let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
448+
.unwrap()
449+
.to_array();
450+
451+
// Take the same large list twice - would overflow u8 but works with u64.
452+
let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
453+
let result = take(&list, &idx).unwrap();
454+
455+
assert_eq!(result.len(), 3);
456+
457+
let result_view = result.to_listview();
458+
assert_eq!(result_view.len(), 3);
459+
assert!(result_view.is_valid(0));
460+
assert!(result_view.is_invalid(1));
461+
assert!(result_view.is_valid(2));
462+
}
414463
}

vortex-dtype/src/ptype.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,28 @@ macro_rules! match_each_native_simd_ptype {
668668
}};
669669
}
670670

671+
/// Macro to match the smallest offset type for a given value
672+
#[macro_export]
673+
macro_rules! match_smallest_offset_type {
674+
($n_elements:expr, | $offset_type:ident | $body:block) => {{
675+
let n_elements = $n_elements;
676+
if n_elements <= u8::MAX as usize {
677+
type $offset_type = u8;
678+
$body
679+
} else if n_elements <= u16::MAX as usize {
680+
type $offset_type = u16;
681+
$body
682+
} else if n_elements <= u32::MAX as usize {
683+
type $offset_type = u32;
684+
$body
685+
} else {
686+
assert!(u64::try_from(n_elements).is_ok());
687+
type $offset_type = u64;
688+
$body
689+
}
690+
}};
691+
}
692+
671693
impl PType {
672694
/// Returns `true` iff this PType is an unsigned integer type
673695
#[inline]

0 commit comments

Comments
 (0)