Skip to content

Commit b5bf319

Browse files
authored
fix: Handle floats getting widened during flexbuffer serialization if they're part of a list containing wider type (#2926)
For integers we read all widths for all types, we need to do the same for floats.
1 parent fe12182 commit b5bf319

File tree

4 files changed

+57
-42
lines changed

4 files changed

+57
-42
lines changed

fuzz/fuzz_targets/file_io.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ use vortex_array::arrays::arbitrary::ArbitraryArray;
1111
use vortex_array::arrow::IntoArrowArray;
1212
use vortex_array::compute::{Operator, compare};
1313
use vortex_array::stream::ArrayStreamArrayExt;
14-
use vortex_array::{Array, ArrayRef, ToCanonical};
14+
use vortex_array::{Array, ArrayRef, Canonical, IntoArray, ToCanonical};
1515
use vortex_buffer::ByteBufferMut;
1616
use vortex_dtype::{DType, StructDType};
17-
use vortex_error::{VortexUnwrap, vortex_panic};
17+
use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
1818
use vortex_file::{VortexOpenOptions, VortexWriteOptions};
1919

2020
fuzz_target!(|array_data: ArbitraryArray| -> Corpus {
@@ -35,7 +35,7 @@ fuzz_target!(|array_data: ArbitraryArray| -> Corpus {
3535
.await
3636
.vortex_unwrap();
3737

38-
let output = VortexOpenOptions::in_memory()
38+
let mut output = VortexOpenOptions::in_memory()
3939
.open(full_buff)
4040
.await
4141
.vortex_unwrap()
@@ -47,27 +47,29 @@ fuzz_target!(|array_data: ArbitraryArray| -> Corpus {
4747
.await
4848
.vortex_unwrap();
4949

50-
let output = if output.is_empty() {
51-
ChunkedArray::try_new(output, array_data.dtype().clone())
52-
.vortex_unwrap()
53-
.into_array()
54-
} else {
55-
ChunkedArray::from_iter(output).into_array()
50+
let output_array = match output.len() {
51+
0 => Canonical::empty(array_data.dtype()).into_array(),
52+
1 => output.pop().vortex_expect("one output"),
53+
_ => ChunkedArray::from_iter(output).into_array(),
5654
};
5755

58-
assert_eq!(array_data.len(), output.len(), "Length was not preserved.");
56+
assert_eq!(
57+
array_data.len(),
58+
output_array.len(),
59+
"Length was not preserved."
60+
);
5961
assert_eq!(
6062
array_data.dtype(),
61-
output.dtype(),
63+
output_array.dtype(),
6264
"DTypes aren't preserved expected {}, actual {}",
6365
array_data.dtype(),
64-
output.dtype()
66+
output_array.dtype()
6567
);
6668

6769
if matches!(array_data.dtype(), DType::Struct(_, _) | DType::List(_, _)) {
68-
compare_struct(array_data, output);
70+
compare_struct(array_data, output_array);
6971
} else {
70-
let bool_result = compare(&array_data, &output, Operator::Eq)
72+
let bool_result = compare(&array_data, &output_array, Operator::Eq)
7173
.vortex_unwrap()
7274
.to_bool()
7375
.vortex_unwrap();
@@ -79,7 +81,7 @@ fuzz_target!(|array_data: ArbitraryArray| -> Corpus {
7981
vortex_panic!(
8082
"Failed to match original array {}with{}",
8183
array_data.tree_display(),
82-
output.tree_display()
84+
output_array.tree_display()
8385
);
8486
}
8587
}

vortex-scalar/src/pvalue.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ impl TryFrom<PValue> for f32 {
278278
PValue::U8(u) => Some(Self::from_bits(u as u32)),
279279
PValue::U16(u) => Some(Self::from_bits(u as u32)),
280280
PValue::U32(u) => Some(Self::from_bits(u)),
281+
PValue::U64(u) => <Self as NumCast>::from(f64::from_bits(u)),
281282
PValue::F16(f) => <Self as NumCast>::from(f),
282283
PValue::F32(f) => <Self as NumCast>::from(f),
283284
PValue::F64(f) => <Self as NumCast>::from(f),
@@ -295,6 +296,8 @@ impl TryFrom<PValue> for f16 {
295296
match value {
296297
PValue::U8(u) => Some(Self::from_bits(u as u16)),
297298
PValue::U16(u) => Some(Self::from_bits(u)),
299+
PValue::U32(u) => <Self as NumCast>::from(f32::from_bits(u)),
300+
PValue::U64(u) => <Self as NumCast>::from(f64::from_bits(u)),
298301
PValue::F16(u) => Some(u),
299302
PValue::F32(f) => <Self as NumCast>::from(f),
300303
PValue::F64(f) => <Self as NumCast>::from(f),

vortex-scalar/src/scalarvalue/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ impl ScalarValue {
4141
use vortex_error::VortexExpect;
4242

4343
let mut ser = flexbuffers::FlexbufferSerializer::new();
44-
self.0
45-
.serialize(&mut ser)
44+
self.serialize(&mut ser)
4645
.vortex_expect("Failed to serialize ScalarValue");
4746
let view = ser.view();
4847

vortex-scalar/src/serde/serde.rs

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ impl<'de> Deserialize<'de> for ScalarValue {
163163
where
164164
A: SeqAccess<'v>,
165165
{
166-
let mut elems = vec![];
166+
let mut elems = Vec::with_capacity(seq.size_hint().unwrap_or_default());
167167
while let Some(e) = seq.next_element::<ScalarValue>()? {
168168
elems.push(e);
169169
}
@@ -197,46 +197,57 @@ impl Serialize for PValue {
197197
}
198198
}
199199

200-
impl<'de> Deserialize<'de> for PValue {
201-
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
202-
where
203-
D: Deserializer<'de>,
204-
{
205-
ScalarValue::deserialize(deserializer)
206-
.and_then(|scalar| scalar.0.as_pvalue().map_err(Error::custom))
207-
.and_then(|pvalue| {
208-
pvalue.ok_or_else(|| Error::custom("Expected a non-null primitive scalar value"))
209-
})
210-
}
211-
}
212-
213200
#[cfg(test)]
214201
mod tests {
215-
use std::mem::discriminant;
216202
use std::sync::Arc;
217203

218204
use flexbuffers::{FlexbufferSerializer, Reader};
219205
use rstest::rstest;
220-
use vortex_dtype::{Nullability, PType};
206+
use vortex_dtype::half::f16;
207+
use vortex_dtype::{DType, FieldDType, Nullability, PType, StructDType};
221208

222209
use super::*;
223210
use crate::Scalar;
224211

225212
#[rstest]
226-
#[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable).into_value())]
227-
#[case(Scalar::utf8("hello", Nullability::NonNullable).into_value())]
228-
#[case(Scalar::primitive(1u8, Nullability::NonNullable).into_value())]
229-
#[case(Scalar::primitive(f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])), Nullability::NonNullable).into_value())]
230-
#[case(Scalar::list(Arc::new(PType::U8.into()), vec![Scalar::primitive(1u8, Nullability::NonNullable)], Nullability::NonNullable).into_value())]
231-
fn test_scalar_value_serde_roundtrip(#[case] scalar_value: ScalarValue) {
213+
#[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))]
214+
#[case(Scalar::utf8("hello", Nullability::NonNullable))]
215+
#[case(Scalar::primitive(1u8, Nullability::NonNullable))]
216+
#[case(Scalar::primitive(f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])), Nullability::NonNullable))]
217+
#[case(Scalar::list(Arc::new(PType::U8.into()), vec![Scalar::primitive(1u8, Nullability::NonNullable)], Nullability::NonNullable))]
218+
#[case(Scalar::struct_(DType::Struct(
219+
Arc::new(StructDType::from_iter([
220+
("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))),
221+
("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
222+
])),
223+
Nullability::NonNullable),
224+
vec![
225+
Scalar::primitive(23592960, Nullability::NonNullable),
226+
Scalar::primitive(f16::from_bits(0), Nullability::NonNullable),
227+
],
228+
))]
229+
#[case(Scalar::struct_(DType::Struct(
230+
Arc::new(StructDType::from_iter([
231+
("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))),
232+
("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))),
233+
("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
234+
])),
235+
Nullability::NonNullable),
236+
vec![
237+
Scalar::primitive(415118687234i64, Nullability::NonNullable),
238+
Scalar::primitive(0.0f32, Nullability::NonNullable),
239+
Scalar::primitive(f16::from_bits(0), Nullability::NonNullable),
240+
],
241+
))]
242+
fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) {
232243
let mut serializer = FlexbufferSerializer::new();
233-
scalar_value.serialize(&mut serializer).unwrap();
244+
scalar.value.serialize(&mut serializer).unwrap();
234245
let written = serializer.take_buffer();
235246
let reader = Reader::get_root(written.as_ref()).unwrap();
236247
let scalar_read_back = ScalarValue::deserialize(reader).unwrap();
237248
assert_eq!(
238-
discriminant(&scalar_value.0),
239-
discriminant(&scalar_read_back.0)
249+
scalar,
250+
Scalar::new(scalar.dtype().clone(), scalar_read_back)
240251
);
241252
}
242253
}

0 commit comments

Comments
 (0)