Commit 07f5039

fix: VarBin overflow on take (#5361)
Fix for #5347.

Naively, on take we widen to either 32- or 64-bit offsets based on the input offsets' PType. This is similar to what arrow-rs does. I think it's hard to be more intelligent than this without doing the take in 2 passes.

---------

Signed-off-by: Andrew Duffy <[email protected]>
1 parent 1f50f1c commit 07f5039
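
For reference, the widening rule this commit applies can be restated as a small stand-alone helper. This is only a sketch of the mapping visible in the dispatch in the diff below, not part of the Vortex API; it assumes PType is importable from vortex_dtype, and the helper name widened_offset_ptype is hypothetical.

use vortex_dtype::PType;

/// Hypothetical helper (not in the crate): the offset type used for the result
/// of take, given the input offsets' PType, matching the dispatch in the diff.
fn widened_offset_ptype(offsets: PType) -> PType {
    match offsets {
        // Narrow unsigned offsets widen to u32; u64 is already wide enough.
        PType::U8 | PType::U16 | PType::U32 => PType::U32,
        PType::U64 => PType::U64,
        // Narrow signed offsets widen to i32; i64 is already wide enough.
        PType::I8 | PType::I16 | PType::I32 => PType::I32,
        PType::I64 => PType::I64,
        // Float PTypes are never valid offsets.
        _ => unreachable!("invalid PType for offsets"),
    }
}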

File tree

1 file changed: +118 −39 lines

  • vortex-array/src/arrays/varbin/compute/take.rs

vortex-array/src/arrays/varbin/compute/take.rs

Lines changed: 118 additions & 39 deletions
@@ -17,37 +17,98 @@ impl TakeKernel for VarBinVTable {
         let offsets = array.offsets().to_primitive();
         let data = array.bytes();
         let indices = indices.to_primitive();
-        match_each_integer_ptype!(offsets.ptype(), |O| {
-            match_each_integer_ptype!(indices.ptype(), |I| {
-                Ok(take(
-                    array
-                        .dtype()
-                        .clone()
-                        .union_nullability(indices.dtype().nullability()),
-                    offsets.as_slice::<O>(),
+        let dtype = array
+            .dtype()
+            .clone()
+            .union_nullability(indices.dtype().nullability());
+        let array = match_each_integer_ptype!(indices.ptype(), |I| {
+            // On take, offsets get widened to either 32- or 64-bit based on the original type,
+            // to avoid overflow issues.
+            match offsets.ptype() {
+                PType::U8 => take::<I, u8, u32>(
+                    dtype,
+                    offsets.as_slice::<u8>(),
                     data.as_slice(),
                     indices.as_slice::<I>(),
                     array.validity_mask(),
                     indices.validity_mask(),
-                )?
-                .into_array())
-            })
-        })
+                ),
+                PType::U16 => take::<I, u16, u32>(
+                    dtype,
+                    offsets.as_slice::<u16>(),
+                    data.as_slice(),
+                    indices.as_slice::<I>(),
+                    array.validity_mask(),
+                    indices.validity_mask(),
+                ),
+                PType::U32 => take::<I, u32, u32>(
+                    dtype,
+                    offsets.as_slice::<u32>(),
+                    data.as_slice(),
+                    indices.as_slice::<I>(),
+                    array.validity_mask(),
+                    indices.validity_mask(),
+                ),
+                PType::U64 => take::<I, u64, u64>(
+                    dtype,
+                    offsets.as_slice::<u64>(),
+                    data.as_slice(),
+                    indices.as_slice::<I>(),
+                    array.validity_mask(),
+                    indices.validity_mask(),
+                ),
+                PType::I8 => take::<I, i8, i32>(
+                    dtype,
+                    offsets.as_slice::<i8>(),
+                    data.as_slice(),
+                    indices.as_slice::<I>(),
+                    array.validity_mask(),
+                    indices.validity_mask(),
+                ),
+                PType::I16 => take::<I, i16, i32>(
+                    dtype,
+                    offsets.as_slice::<i16>(),
+                    data.as_slice(),
+                    indices.as_slice::<I>(),
+                    array.validity_mask(),
+                    indices.validity_mask(),
+                ),
+                PType::I32 => take::<I, i32, i32>(
+                    dtype,
+                    offsets.as_slice::<i32>(),
+                    data.as_slice(),
+                    indices.as_slice::<I>(),
+                    array.validity_mask(),
+                    indices.validity_mask(),
+                ),
+                PType::I64 => take::<I, i64, i64>(
+                    dtype,
+                    offsets.as_slice::<i64>(),
+                    data.as_slice(),
+                    indices.as_slice::<I>(),
+                    array.validity_mask(),
+                    indices.validity_mask(),
+                ),
+                _ => unreachable!("invalid PType for offsets"),
+            }
+        });
+
+        Ok(array?.into_array())
     }
 }
 
 register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
 
-fn take<I: IntegerPType, O: IntegerPType>(
+fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
     dtype: DType,
-    offsets: &[O],
+    offsets: &[Offset],
     data: &[u8],
-    indices: &[I],
+    indices: &[Index],
     validity_mask: Mask,
     indices_validity_mask: Mask,
 ) -> VortexResult<VarBinArray> {
     if !validity_mask.all_true() || !indices_validity_mask.all_true() {
-        return Ok(take_nullable(
+        return Ok(take_nullable::<Index, Offset, NewOffset>(
             dtype,
             offsets,
             data,
@@ -57,25 +118,22 @@ fn take<I: IntegerPType, O: IntegerPType>(
         ));
     }
 
-    let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
-    new_offsets.push(O::zero());
-    let mut current_offset = O::zero();
+    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
+    new_offsets.push(NewOffset::zero());
+    let mut current_offset = NewOffset::zero();
 
     for &idx in indices {
         let idx = idx
             .to_usize()
            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
         let start = offsets[idx];
         let stop = offsets[idx + 1];
-        current_offset += stop - start;
+
+        current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
         new_offsets.push(current_offset);
     }
 
-    let mut new_data = ByteBufferMut::with_capacity(
-        current_offset
-            .to_usize()
-            .vortex_expect("Failed to cast max offset to usize"),
-    );
+    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
 
     for idx in indices {
         let idx = idx
@@ -104,17 +162,17 @@ fn take<I: IntegerPType, O: IntegerPType>(
     }
 }
 
-fn take_nullable<I: IntegerPType, O: IntegerPType>(
+fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
     dtype: DType,
-    offsets: &[O],
+    offsets: &[Offset],
     data: &[u8],
-    indices: &[I],
+    indices: &[Index],
     data_validity: Mask,
     indices_validity: Mask,
 ) -> VarBinArray {
-    let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
-    new_offsets.push(O::zero());
-    let mut current_offset = O::zero();
+    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
+    new_offsets.push(NewOffset::zero());
+    let mut current_offset = NewOffset::zero();
 
     let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
 
@@ -135,7 +193,7 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
             validity_buffer.append(true);
             let start = offsets[data_idx_usize];
             let stop = offsets[data_idx_usize + 1];
-            current_offset += stop - start;
+            current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
             new_offsets.push(current_offset);
             valid_indices.push(data_idx_usize);
         } else {
@@ -144,11 +202,7 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
         }
     }
 
-    let mut new_data = ByteBufferMut::with_capacity(
-        current_offset
-            .to_usize()
-            .vortex_expect("Failed to cast max offset to usize"),
-    );
+    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
 
     // Second pass: copy data for valid indices only
     for data_idx in valid_indices {
@@ -178,12 +232,14 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
 #[cfg(test)]
 mod tests {
     use rstest::rstest;
+    use vortex_buffer::{ByteBuffer, buffer};
     use vortex_dtype::{DType, Nullability};
 
-    use crate::Array;
-    use crate::arrays::{PrimitiveArray, VarBinArray};
+    use crate::arrays::{PrimitiveArray, VarBinArray, VarBinVTable};
     use crate::compute::conformance::take::test_take_conformance;
     use crate::compute::take;
+    use crate::validity::Validity;
+    use crate::{Array, IntoArray};
 
     #[test]
     fn test_null_take() {
@@ -221,4 +277,27 @@ mod tests {
     fn test_take_varbin_conformance(#[case] array: VarBinArray) {
         test_take_conformance(array.as_ref());
     }
+
+    #[test]
+    fn test_take_overflow() {
+        let scream = std::iter::once("a").cycle().take(128).collect::<String>();
+        let bytes = ByteBuffer::copy_from(scream.as_bytes());
+        let offsets = buffer![0u8, 128u8].into_array();
+
+        let array = VarBinArray::new(
+            offsets,
+            bytes,
+            DType::Utf8(Nullability::NonNullable),
+            Validity::NonNullable,
+        );
+
+        let indices = buffer![0u32, 0u32, 0u32].into_array();
+        let taken = take(array.as_ref(), indices.as_ref()).unwrap();
+
+        let taken_str = taken.as_::<VarBinVTable>();
+        assert_eq!(taken_str.len(), 3);
+        assert_eq!(taken_str.bytes_at(0).as_bytes(), scream.as_bytes());
+        assert_eq!(taken_str.bytes_at(1).as_bytes(), scream.as_bytes());
+        assert_eq!(taken_str.bytes_at(2).as_bytes(), scream.as_bytes());
+    }
 }
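
A rough sanity check of the scenario test_take_overflow exercises, using plain arithmetic only (no Vortex code): the test stores one 128-byte value with u8 offsets [0, 128] and takes it three times, so the final offset must reach 384, which overflows u8 but fits easily in the widened u32.

fn main() {
    // The test stores one 128-byte value and takes it 3 times.
    let value_len: u64 = 128;
    let takes: u64 = 3;
    let final_offset = value_len * takes; // 384

    // 384 does not fit in the original u8 offsets...
    assert!(final_offset > u8::MAX as u64);
    // ...but fits comfortably in the widened u32 offsets.
    assert!(final_offset <= u32::MAX as u64);
    println!("final offset after take: {final_offset}");
}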
