Skip to content

Commit 7472429

Browse files
committed
fix: list take offset type
Signed-off-by: Alexander Droste <[email protected]>
1 parent 366b694 commit 7472429

File tree

1 file changed

+50
-6
lines changed
  • vortex-array/src/arrays/list/compute

1 file changed

+50
-6
lines changed

vortex-array/src/arrays/list/compute/take.rs

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,12 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7373
);
7474
}
7575

76-
let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
76+
let mut new_offsets =
77+
PrimitiveBuilder::<u64>::with_capacity(Nullability::NonNullable, indices.len());
7778
let mut elements_to_take =
7879
PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
7980

80-
let mut current_offset = O::zero();
81+
let mut current_offset = 0u64;
8182
new_offsets.append_zero();
8283

8384
for &data_idx in indices {
@@ -101,7 +102,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
101102
for i in 0..additional {
102103
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
103104
}
104-
current_offset += stop - start;
105+
current_offset += (stop - start).as_() as u64;
105106
new_offsets.append_value(current_offset);
106107
}
107108

@@ -128,7 +129,8 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
128129
data_validity: Mask,
129130
indices_validity: Mask,
130131
) -> VortexResult<ArrayRef> {
131-
let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
132+
let mut new_offsets =
133+
PrimitiveBuilder::<u64>::with_capacity(Nullability::NonNullable, indices.len());
132134

133135
// This will be the indices we push down to the child array to call `take` with.
134136
//
@@ -140,7 +142,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
140142
let mut elements_to_take =
141143
PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
142144

143-
let mut current_offset = O::zero();
145+
let mut current_offset = 0u64;
144146
new_offsets.append_zero();
145147

146148
// Set all bits to invalid and selectively set which values are valid.
@@ -175,7 +177,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
175177
for i in 0..additional {
176178
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
177179
}
178-
current_offset += stop - start;
180+
current_offset += (stop - start).as_() as u64;
179181
new_offsets.append_value(current_offset);
180182
new_validity.set(idx);
181183
}
@@ -411,4 +413,46 @@ mod test {
411413
fn test_take_list_conformance(#[case] list: ListArray) {
412414
test_take_conformance(list.as_ref());
413415
}
416+
417+
#[test]
418+
fn test_u64_offset_accumulation_non_nullable() {
419+
let elements = buffer![0i32; 200].into_array();
420+
let offsets = buffer![0u8, 200].into_array();
421+
let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
422+
.unwrap()
423+
.to_array();
424+
425+
// Take the same large list twice - would overflow u8 but works with u64.
426+
let idx = buffer![0u8, 0].into_array();
427+
let result = take(&list, &idx).unwrap();
428+
429+
assert_eq!(result.len(), 2);
430+
431+
let result_view = result.to_listview();
432+
assert_eq!(result_view.len(), 2);
433+
assert!(result_view.is_valid(0));
434+
assert!(result_view.is_valid(1));
435+
}
436+
437+
#[test]
438+
fn test_u64_offset_accumulation_nullable() {
439+
let elements = buffer![0i32; 150].into_array();
440+
let offsets = buffer![0u8, 150, 150].into_array();
441+
let validity = BoolArray::from_iter(vec![true, false]).to_array();
442+
let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
443+
.unwrap()
444+
.to_array();
445+
446+
// Take the same large list twice - would overflow u8 but works with u64.
447+
let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
448+
let result = take(&list, &idx).unwrap();
449+
450+
assert_eq!(result.len(), 3);
451+
452+
let result_view = result.to_listview();
453+
assert_eq!(result_view.len(), 3);
454+
assert!(result_view.is_valid(0));
455+
assert!(result_view.is_invalid(1));
456+
assert!(result_view.is_valid(2));
457+
}
414458
}

0 commit comments

Comments
 (0)