Skip to content

Commit 6e55318

Browse files
committed
fix: VarBin overflow on take
Signed-off-by: Andrew Duffy <[email protected]>
1 parent 615bfef commit 6e55318

File tree

1 file changed

+134
-35
lines changed
  • vortex-array/src/arrays/varbin/compute

1 file changed

+134
-35
lines changed

vortex-array/src/arrays/varbin/compute/take.rs

Lines changed: 134 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,118 @@ impl TakeKernel for VarBinVTable {
1717
let offsets = array.offsets().to_primitive();
1818
let data = array.bytes();
1919
let indices = indices.to_primitive();
20-
match_each_integer_ptype!(offsets.ptype(), |O| {
21-
match_each_integer_ptype!(indices.ptype(), |I| {
22-
Ok(take(
20+
let array = match_each_integer_ptype!(indices.ptype(), |I| {
21+
// On take, offsets get widened to either 32- or 64-bit based on the original type,
22+
// to avoid overflow issues.
23+
match offsets.ptype() {
24+
PType::U8 => take::<I, u8, u32>(
2325
array
2426
.dtype()
2527
.clone()
2628
.union_nullability(indices.dtype().nullability()),
27-
offsets.as_slice::<O>(),
29+
offsets.as_slice::<u8>(),
2830
data.as_slice(),
2931
indices.as_slice::<I>(),
3032
array.validity_mask(),
3133
indices.validity_mask(),
32-
)?
33-
.into_array())
34-
})
35-
})
34+
),
35+
PType::U16 => take::<I, u16, u32>(
36+
array
37+
.dtype()
38+
.clone()
39+
.union_nullability(indices.dtype().nullability()),
40+
offsets.as_slice::<u16>(),
41+
data.as_slice(),
42+
indices.as_slice::<I>(),
43+
array.validity_mask(),
44+
indices.validity_mask(),
45+
),
46+
PType::U32 => take::<I, u32, u32>(
47+
array
48+
.dtype()
49+
.clone()
50+
.union_nullability(indices.dtype().nullability()),
51+
offsets.as_slice::<u32>(),
52+
data.as_slice(),
53+
indices.as_slice::<I>(),
54+
array.validity_mask(),
55+
indices.validity_mask(),
56+
),
57+
PType::U64 => take::<I, u64, u64>(
58+
array
59+
.dtype()
60+
.clone()
61+
.union_nullability(indices.dtype().nullability()),
62+
offsets.as_slice::<u64>(),
63+
data.as_slice(),
64+
indices.as_slice::<I>(),
65+
array.validity_mask(),
66+
indices.validity_mask(),
67+
),
68+
PType::I8 => take::<I, i8, i32>(
69+
array
70+
.dtype()
71+
.clone()
72+
.union_nullability(indices.dtype().nullability()),
73+
offsets.as_slice::<i8>(),
74+
data.as_slice(),
75+
indices.as_slice::<I>(),
76+
array.validity_mask(),
77+
indices.validity_mask(),
78+
),
79+
PType::I16 => take::<I, i16, i32>(
80+
array
81+
.dtype()
82+
.clone()
83+
.union_nullability(indices.dtype().nullability()),
84+
offsets.as_slice::<i16>(),
85+
data.as_slice(),
86+
indices.as_slice::<I>(),
87+
array.validity_mask(),
88+
indices.validity_mask(),
89+
),
90+
PType::I32 => take::<I, i32, i32>(
91+
array
92+
.dtype()
93+
.clone()
94+
.union_nullability(indices.dtype().nullability()),
95+
offsets.as_slice::<i32>(),
96+
data.as_slice(),
97+
indices.as_slice::<I>(),
98+
array.validity_mask(),
99+
indices.validity_mask(),
100+
),
101+
PType::I64 => take::<I, i64, i64>(
102+
array
103+
.dtype()
104+
.clone()
105+
.union_nullability(indices.dtype().nullability()),
106+
offsets.as_slice::<i64>(),
107+
data.as_slice(),
108+
indices.as_slice::<I>(),
109+
array.validity_mask(),
110+
indices.validity_mask(),
111+
),
112+
_ => unreachable!("invalid PType for offsets"),
113+
}
114+
});
115+
116+
Ok(array?.into_array())
36117
}
37118
}
38119

39120
register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
40121

41-
fn take<I: IntegerPType, O: IntegerPType>(
122+
fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
42123
dtype: DType,
43-
offsets: &[O],
124+
offsets: &[Offset],
44125
data: &[u8],
45-
indices: &[I],
126+
indices: &[Index],
46127
validity_mask: Mask,
47128
indices_validity_mask: Mask,
48129
) -> VortexResult<VarBinArray> {
49130
if !validity_mask.all_true() || !indices_validity_mask.all_true() {
50-
return Ok(take_nullable(
131+
return Ok(take_nullable::<Index, Offset, NewOffset>(
51132
dtype,
52133
offsets,
53134
data,
@@ -57,25 +138,22 @@ fn take<I: IntegerPType, O: IntegerPType>(
57138
));
58139
}
59140

60-
let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
61-
new_offsets.push(O::zero());
62-
let mut current_offset = O::zero();
141+
let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
142+
new_offsets.push(NewOffset::zero());
143+
let mut current_offset = NewOffset::zero();
63144

64145
for &idx in indices {
65146
let idx = idx
66147
.to_usize()
67148
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
68149
let start = offsets[idx];
69150
let stop = offsets[idx + 1];
70-
current_offset += stop - start;
151+
152+
current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
71153
new_offsets.push(current_offset);
72154
}
73155

74-
let mut new_data = ByteBufferMut::with_capacity(
75-
current_offset
76-
.to_usize()
77-
.vortex_expect("Failed to cast max offset to usize"),
78-
);
156+
let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
79157

80158
for idx in indices {
81159
let idx = idx
@@ -104,17 +182,17 @@ fn take<I: IntegerPType, O: IntegerPType>(
104182
}
105183
}
106184

107-
fn take_nullable<I: IntegerPType, O: IntegerPType>(
185+
fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
108186
dtype: DType,
109-
offsets: &[O],
187+
offsets: &[Offset],
110188
data: &[u8],
111-
indices: &[I],
189+
indices: &[Index],
112190
data_validity: Mask,
113191
indices_validity: Mask,
114192
) -> VarBinArray {
115-
let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
116-
new_offsets.push(O::zero());
117-
let mut current_offset = O::zero();
193+
let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
194+
new_offsets.push(NewOffset::zero());
195+
let mut current_offset = NewOffset::zero();
118196

119197
let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
120198

@@ -135,7 +213,7 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
135213
validity_buffer.append(true);
136214
let start = offsets[data_idx_usize];
137215
let stop = offsets[data_idx_usize + 1];
138-
current_offset += stop - start;
216+
current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
139217
new_offsets.push(current_offset);
140218
valid_indices.push(data_idx_usize);
141219
} else {
@@ -144,11 +222,7 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
144222
}
145223
}
146224

147-
let mut new_data = ByteBufferMut::with_capacity(
148-
current_offset
149-
.to_usize()
150-
.vortex_expect("Failed to cast max offset to usize"),
151-
);
225+
let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
152226

153227
// Second pass: copy data for valid indices only
154228
for data_idx in valid_indices {
@@ -178,12 +252,14 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
178252
#[cfg(test)]
179253
mod tests {
180254
use rstest::rstest;
255+
use vortex_buffer::{ByteBuffer, buffer};
181256
use vortex_dtype::{DType, Nullability};
182257

183-
use crate::Array;
184-
use crate::arrays::{PrimitiveArray, VarBinArray};
258+
use crate::arrays::{PrimitiveArray, VarBinArray, VarBinVTable};
185259
use crate::compute::conformance::take::test_take_conformance;
186260
use crate::compute::take;
261+
use crate::validity::Validity;
262+
use crate::{Array, IntoArray};
187263

188264
#[test]
189265
fn test_null_take() {
@@ -221,4 +297,27 @@ mod tests {
221297
fn test_take_varbin_conformance(#[case] array: VarBinArray) {
222298
test_take_conformance(array.as_ref());
223299
}
300+
301+
#[test]
302+
fn test_take_overflow() {
303+
let scream = std::iter::once("a").cycle().take(128).collect::<String>();
304+
let bytes = ByteBuffer::copy_from(scream.as_bytes());
305+
let offsets = buffer![0u8, 128u8].into_array();
306+
307+
let array = VarBinArray::new(
308+
offsets,
309+
bytes,
310+
DType::Utf8(Nullability::NonNullable),
311+
Validity::NonNullable,
312+
);
313+
314+
let indices = buffer![0u32, 0u32, 0u32].into_array();
315+
let taken = take(array.as_ref(), indices.as_ref()).unwrap();
316+
317+
let taken_str = taken.as_::<VarBinVTable>();
318+
assert_eq!(taken_str.len(), 3);
319+
assert_eq!(taken_str.bytes_at(0).as_bytes(), scream.as_bytes());
320+
assert_eq!(taken_str.bytes_at(1).as_bytes(), scream.as_bytes());
321+
assert_eq!(taken_str.bytes_at(2).as_bytes(), scream.as_bytes());
322+
}
224323
}

0 commit comments

Comments
 (0)