@@ -163,7 +163,7 @@ impl<'de> Deserialize<'de> for ScalarValue {
163163 where
164164 A : SeqAccess < ' v > ,
165165 {
166- let mut elems = vec ! [ ] ;
166+ let mut elems = Vec :: with_capacity ( seq . size_hint ( ) . unwrap_or_default ( ) ) ;
167167 while let Some ( e) = seq. next_element :: < ScalarValue > ( ) ? {
168168 elems. push ( e) ;
169169 }
@@ -197,46 +197,57 @@ impl Serialize for PValue {
197197 }
198198}
199199
200- impl < ' de > Deserialize < ' de > for PValue {
201- fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
202- where
203- D : Deserializer < ' de > ,
204- {
205- ScalarValue :: deserialize ( deserializer)
206- . and_then ( |scalar| scalar. 0 . as_pvalue ( ) . map_err ( Error :: custom) )
207- . and_then ( |pvalue| {
208- pvalue. ok_or_else ( || Error :: custom ( "Expected a non-null primitive scalar value" ) )
209- } )
210- }
211- }
212-
213200#[ cfg( test) ]
214201mod tests {
215- use std:: mem:: discriminant;
216202 use std:: sync:: Arc ;
217203
218204 use flexbuffers:: { FlexbufferSerializer , Reader } ;
219205 use rstest:: rstest;
220- use vortex_dtype:: { Nullability , PType } ;
206+ use vortex_dtype:: half:: f16;
207+ use vortex_dtype:: { DType , FieldDType , Nullability , PType , StructDType } ;
221208
222209 use super :: * ;
223210 use crate :: Scalar ;
224211
225212 #[ rstest]
226- #[ case( Scalar :: binary( ByteBuffer :: copy_from( b"hello" ) , Nullability :: NonNullable ) . into_value( ) ) ]
227- #[ case( Scalar :: utf8( "hello" , Nullability :: NonNullable ) . into_value( ) ) ]
228- #[ case( Scalar :: primitive( 1u8 , Nullability :: NonNullable ) . into_value( ) ) ]
229- #[ case( Scalar :: primitive( f32 :: from_bits( u32 :: from_le_bytes( [ 0xFFu8 , 0x8A , 0xF9 , 0xFF ] ) ) , Nullability :: NonNullable ) . into_value( ) ) ]
230- #[ case( Scalar :: list( Arc :: new( PType :: U8 . into( ) ) , vec![ Scalar :: primitive( 1u8 , Nullability :: NonNullable ) ] , Nullability :: NonNullable ) . into_value( ) ) ]
231- fn test_scalar_value_serde_roundtrip ( #[ case] scalar_value : ScalarValue ) {
213+ #[ case( Scalar :: binary( ByteBuffer :: copy_from( b"hello" ) , Nullability :: NonNullable ) ) ]
214+ #[ case( Scalar :: utf8( "hello" , Nullability :: NonNullable ) ) ]
215+ #[ case( Scalar :: primitive( 1u8 , Nullability :: NonNullable ) ) ]
216+ #[ case( Scalar :: primitive( f32 :: from_bits( u32 :: from_le_bytes( [ 0xFFu8 , 0x8A , 0xF9 , 0xFF ] ) ) , Nullability :: NonNullable ) ) ]
217+ #[ case( Scalar :: list( Arc :: new( PType :: U8 . into( ) ) , vec![ Scalar :: primitive( 1u8 , Nullability :: NonNullable ) ] , Nullability :: NonNullable ) ) ]
218+ #[ case( Scalar :: struct_( DType :: Struct (
219+ Arc :: new( StructDType :: from_iter( [
220+ ( "a" , FieldDType :: from( DType :: Primitive ( PType :: U32 , Nullability :: NonNullable ) ) ) ,
221+ ( "b" , FieldDType :: from( DType :: Primitive ( PType :: F16 , Nullability :: NonNullable ) ) ) ,
222+ ] ) ) ,
223+ Nullability :: NonNullable ) ,
224+ vec![
225+ Scalar :: primitive( 23592960 , Nullability :: NonNullable ) ,
226+ Scalar :: primitive( f16:: from_bits( 0 ) , Nullability :: NonNullable ) ,
227+ ] ,
228+ ) ) ]
229+ #[ case( Scalar :: struct_( DType :: Struct (
230+ Arc :: new( StructDType :: from_iter( [
231+ ( "a" , FieldDType :: from( DType :: Primitive ( PType :: U64 , Nullability :: NonNullable ) ) ) ,
232+ ( "b" , FieldDType :: from( DType :: Primitive ( PType :: F32 , Nullability :: NonNullable ) ) ) ,
233+ ( "c" , FieldDType :: from( DType :: Primitive ( PType :: F16 , Nullability :: NonNullable ) ) ) ,
234+ ] ) ) ,
235+ Nullability :: NonNullable ) ,
236+ vec![
237+ Scalar :: primitive( 415118687234i64 , Nullability :: NonNullable ) ,
238+ Scalar :: primitive( 0.0f32 , Nullability :: NonNullable ) ,
239+ Scalar :: primitive( f16:: from_bits( 0 ) , Nullability :: NonNullable ) ,
240+ ] ,
241+ ) ) ]
242+ fn test_scalar_value_serde_roundtrip ( #[ case] scalar : Scalar ) {
232243 let mut serializer = FlexbufferSerializer :: new ( ) ;
233- scalar_value . serialize ( & mut serializer) . unwrap ( ) ;
244+ scalar . value . serialize ( & mut serializer) . unwrap ( ) ;
234245 let written = serializer. take_buffer ( ) ;
235246 let reader = Reader :: get_root ( written. as_ref ( ) ) . unwrap ( ) ;
236247 let scalar_read_back = ScalarValue :: deserialize ( reader) . unwrap ( ) ;
237248 assert_eq ! (
238- discriminant ( & scalar_value . 0 ) ,
239- discriminant ( & scalar_read_back . 0 )
249+ scalar ,
250+ Scalar :: new ( scalar . dtype ( ) . clone ( ) , scalar_read_back )
240251 ) ;
241252 }
242253}
0 commit comments