11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use num_traits:: NumCast ;
5+ use vortex_buffer:: Buffer ;
6+ use vortex_buffer:: BufferMut ;
47use vortex_dtype:: DType ;
58use vortex_dtype:: NativePType ;
9+ use vortex_dtype:: match_each_native_ptype;
610use vortex_error:: VortexResult ;
711use vortex_error:: vortex_bail;
12+ use vortex_error:: vortex_err;
13+ use vortex_mask:: AllOr ;
14+ use vortex_mask:: Mask ;
815use vortex_vector:: Scalar ;
916use vortex_vector:: ScalarOps ;
1017use vortex_vector:: Vector ;
1118use vortex_vector:: VectorOps ;
1219use vortex_vector:: primitive:: PScalar ;
1320use vortex_vector:: primitive:: PVector ;
21+ use vortex_vector:: primitive:: PrimitiveScalar ;
22+ use vortex_vector:: primitive:: PrimitiveVector ;
1423
1524use crate :: cast:: Cast ;
1625use crate :: cast:: try_cast_scalar_common;
@@ -19,26 +28,25 @@ use crate::cast::try_cast_vector_common;
1928impl < T : NativePType > Cast for PVector < T > {
2029 type Output = Vector ;
2130
22- /// Casts to Primitive (same PType identity) .
31+ /// Cast a primitive vector to a different primitive type .
2332 fn cast ( & self , target_dtype : & DType ) -> VortexResult < Vector > {
2433 if let Some ( result) = try_cast_vector_common ( self , target_dtype) ? {
2534 return Ok ( result) ;
2635 }
2736
2837 match target_dtype {
29- // We're already the correct PType, and we have compatible nullability.
38+ // We have the same ` PType` and we have compatible nullability.
3039 DType :: Primitive ( target_ptype, n)
3140 if * target_ptype == T :: PTYPE && ( n. is_nullable ( ) || self . validity ( ) . all_true ( ) ) =>
3241 {
3342 Ok ( self . clone ( ) . into ( ) )
3443 }
35- // We're not the correct PType, but we do have compatible nullability.
44+ // We can possibly convert to the target ` PType` and we have compatible nullability.
3645 DType :: Primitive ( target_ptype, n) if n. is_nullable ( ) || self . validity ( ) . all_true ( ) => {
37- vortex_bail ! (
38- "Casting PVector from PType {} to PType {} not yet implemented" ,
39- T :: PTYPE ,
40- target_ptype
41- ) ;
46+ match_each_native_ptype ! ( * target_ptype, |Dst | {
47+ let result = cast_pvector:: <T , Dst >( self ) ?;
48+ Ok ( PrimitiveVector :: from( result) . into( ) )
49+ } )
4250 }
4351 _ => {
4452 vortex_bail ! ( "Cannot cast PVector<{}> to {}" , T :: PTYPE , target_dtype) ;
@@ -47,33 +55,243 @@ impl<T: NativePType> Cast for PVector<T> {
4755 }
4856}
4957
58+ /// Cast a [`PVector<F>`] to a [`PVector<T>`] by converting each element.
59+ ///
60+ /// Returns an error if any valid element cannot be converted (e.g., overflow).
61+ fn cast_pvector < Src : NativePType , Dst : NativePType > (
62+ src : & PVector < Src > ,
63+ ) -> VortexResult < PVector < Dst > > {
64+ let elements: & [ Src ] = src. as_ref ( ) ;
65+ match src. validity ( ) . bit_buffer ( ) {
66+ AllOr :: All => {
67+ let mut buffer = BufferMut :: with_capacity ( elements. len ( ) ) ;
68+ for & item in elements {
69+ let converted = <Dst as NumCast >:: from ( item) . ok_or_else (
70+ || vortex_err ! ( ComputeError : "Failed to cast {} to {:?}" , item, Dst :: PTYPE ) ,
71+ ) ?;
72+ // SAFETY: We pre-allocated the required capacity.
73+ unsafe { buffer. push_unchecked ( converted) }
74+ }
75+ Ok ( PVector :: from ( buffer. freeze ( ) ) )
76+ }
77+ AllOr :: None => Ok ( PVector :: new (
78+ Buffer :: zeroed ( elements. len ( ) ) ,
79+ Mask :: new_false ( elements. len ( ) ) ,
80+ ) ) ,
81+ AllOr :: Some ( bit_buffer) => {
82+ let mut buffer = BufferMut :: with_capacity ( elements. len ( ) ) ;
83+ for ( & item, valid) in elements. iter ( ) . zip ( bit_buffer. iter ( ) ) {
84+ if valid {
85+ let converted = <Dst as NumCast >:: from ( item) . ok_or_else (
86+ || vortex_err ! ( ComputeError : "Failed to cast {} to {:?}" , item, Dst :: PTYPE ) ,
87+ ) ?;
88+ // SAFETY: We pre-allocated the required capacity.
89+ unsafe { buffer. push_unchecked ( converted) }
90+ } else {
91+ // SAFETY: We pre-allocated the required capacity.
92+ unsafe { buffer. push_unchecked ( Dst :: default ( ) ) }
93+ }
94+ }
95+ Ok ( PVector :: new ( buffer. freeze ( ) , src. validity ( ) . clone ( ) ) )
96+ }
97+ }
98+ }
99+
50100impl < T : NativePType > Cast for PScalar < T > {
51101 type Output = Scalar ;
52102
53- /// Casts to Primitive (same PType identity) .
103+ /// Cast a primitive scalar to a different primitive type .
54104 fn cast ( & self , target_dtype : & DType ) -> VortexResult < Scalar > {
55105 if let Some ( result) = try_cast_scalar_common ( self , target_dtype) ? {
56106 return Ok ( result) ;
57107 }
58108
59109 match target_dtype {
60- // We're already the correct PType, and we have compatible nullability.
110+ // We have the same ` PType` and we have compatible nullability.
61111 DType :: Primitive ( target_ptype, n)
62112 if * target_ptype == T :: PTYPE && ( n. is_nullable ( ) || self . is_valid ( ) ) =>
63113 {
64114 Ok ( self . clone ( ) . into ( ) )
65115 }
66- // We're not the correct PType, but we do have compatible nullability.
116+ // We can possibly convert to the target ` PType` and we have compatible nullability.
67117 DType :: Primitive ( target_ptype, n) if n. is_nullable ( ) || self . is_valid ( ) => {
68- vortex_bail ! (
69- "Casting PScalar from PType {} to PType {} not yet implemented" ,
70- T :: PTYPE ,
71- target_ptype
72- ) ;
118+ match_each_native_ptype ! ( * target_ptype, |Dst | {
119+ let result = match self . value( ) {
120+ None => PScalar :: null( ) ,
121+ Some ( v) => {
122+ let converted = <Dst as NumCast >:: from( v) . ok_or_else( || {
123+ vortex_err!( ComputeError : "Failed to cast {} to {:?}" , v, Dst :: PTYPE )
124+ } ) ?;
125+ PScalar :: new( Some ( converted) )
126+ }
127+ } ;
128+ Ok ( PrimitiveScalar :: from( result) . into( ) )
129+ } )
73130 }
74131 _ => {
75132 vortex_bail ! ( "Cannot cast PScalar<{}> to {}" , T :: PTYPE , target_dtype) ;
76133 }
77134 }
78135 }
79136}
137+
138+ #[ cfg( test) ]
139+ mod tests {
140+ use rstest:: rstest;
141+ use vortex_buffer:: BitBuffer ;
142+ use vortex_buffer:: buffer;
143+ use vortex_dtype:: DType ;
144+ use vortex_dtype:: Nullability ;
145+ use vortex_dtype:: PType ;
146+ use vortex_dtype:: PTypeDowncast ;
147+ use vortex_error:: VortexError ;
148+ use vortex_mask:: Mask ;
149+ use vortex_vector:: ScalarOps ;
150+ use vortex_vector:: VectorOps ;
151+ use vortex_vector:: primitive:: PScalar ;
152+ use vortex_vector:: primitive:: PVector ;
153+
154+ use crate :: cast:: Cast ;
155+
156+ #[ rstest]
157+ #[ case( PType :: U8 ) ]
158+ #[ case( PType :: U16 ) ]
159+ #[ case( PType :: U32 ) ]
160+ #[ case( PType :: U64 ) ]
161+ #[ case( PType :: I8 ) ]
162+ #[ case( PType :: I16 ) ]
163+ #[ case( PType :: I32 ) ]
164+ #[ case( PType :: I64 ) ]
165+ #[ case( PType :: F32 ) ]
166+ #[ case( PType :: F64 ) ]
167+ fn cast_u32_to_ptype ( #[ case] target : PType ) {
168+ // Use values that fit in all target types (including i8: -128..127).
169+ let vec: PVector < u32 > = buffer ! [ 0u32 , 10 , 100 ] . into ( ) ;
170+ let result = vec. cast ( & target. into ( ) ) . unwrap ( ) ;
171+ assert ! ( result. as_primitive( ) . validity( ) . all_true( ) ) ;
172+ assert_eq ! ( result. len( ) , 3 ) ;
173+ }
174+
175+ #[ test]
176+ fn cast_various_types_to_f64 ( ) {
177+ // Test casting from various primitive types to f64.
178+ let u8_vec: PVector < u8 > = buffer ! [ 0u8 , 1 , 2 , 3 , 255 ] . into ( ) ;
179+ assert ! ( u8_vec. cast( & PType :: F64 . into( ) ) . is_ok( ) ) ;
180+
181+ let u16_vec: PVector < u16 > = buffer ! [ 0u16 , 100 , 1000 ] . into ( ) ;
182+ assert ! ( u16_vec. cast( & PType :: F64 . into( ) ) . is_ok( ) ) ;
183+
184+ let u32_vec: PVector < u32 > = buffer ! [ 0u32 , 100 , 1000 , 1000000 ] . into ( ) ;
185+ assert ! ( u32_vec. cast( & PType :: F64 . into( ) ) . is_ok( ) ) ;
186+
187+ let i8_vec: PVector < i8 > = buffer ! [ 0i8 , -1 , 1 , 127 ] . into ( ) ;
188+ assert ! ( i8_vec. cast( & PType :: F64 . into( ) ) . is_ok( ) ) ;
189+
190+ let i32_vec: PVector < i32 > = buffer ! [ -1000000i32 , -1 , 0 , 1 , 1000000 ] . into ( ) ;
191+ assert ! ( i32_vec. cast( & PType :: F64 . into( ) ) . is_ok( ) ) ;
192+
193+ let f32_vec: PVector < f32 > = buffer ! [ 0.0f32 , 1.5 , -2.5 , 100.0 ] . into ( ) ;
194+ assert ! ( f32_vec. cast( & PType :: F64 . into( ) ) . is_ok( ) ) ;
195+ }
196+
197+ #[ test]
198+ fn cast_u32_u8 ( ) {
199+ let vec: PVector < u32 > = buffer ! [ 0u32 , 10 , 200 ] . into ( ) ;
200+
201+ // Cast from u32 to u8.
202+ let result = vec. cast ( & PType :: U8 . into ( ) ) . unwrap ( ) ;
203+ let p = result. into_primitive ( ) . into_u8 ( ) ;
204+ assert_eq ! ( p. as_ref( ) , & [ 0u8 , 10 , 200 ] ) ;
205+ assert ! ( p. validity( ) . all_true( ) ) ;
206+ }
207+
208+ #[ test]
209+ fn cast_u32_f32 ( ) {
210+ let vec: PVector < u32 > = buffer ! [ 0u32 , 10 , 200 ] . into ( ) ;
211+ let result = vec. cast ( & PType :: F32 . into ( ) ) . unwrap ( ) ;
212+ let p = result. into_primitive ( ) . into_f32 ( ) ;
213+ assert_eq ! ( p. as_ref( ) , & [ 0.0f32 , 10. , 200. ] ) ;
214+ }
215+
216+ #[ test]
217+ fn cast_i32_u32_overflow ( ) {
218+ let vec: PVector < i32 > = buffer ! [ -1i32 ] . into ( ) ;
219+ let error = vec. cast ( & PType :: U32 . into ( ) ) . err ( ) . unwrap ( ) ;
220+ let VortexError :: ComputeError ( s, _) = error else {
221+ unreachable ! ( )
222+ } ;
223+ assert_eq ! ( s. to_string( ) , "Failed to cast -1 to U32" ) ;
224+ }
225+
226+ #[ test]
227+ fn cast_with_invalid_nulls ( ) {
228+ // Create a vector with an invalid value at position 0 (which would overflow).
229+ let vec: PVector < i32 > = PVector :: new (
230+ buffer ! [ -1i32 , 0 , 10 ] ,
231+ Mask :: from ( BitBuffer :: from ( vec ! [ false , true , true ] ) ) ,
232+ ) ;
233+
234+ // Cast to nullable u32 should succeed because the invalid value is masked.
235+ let result = vec
236+ . cast ( & DType :: Primitive ( PType :: U32 , Nullability :: Nullable ) )
237+ . unwrap ( ) ;
238+ let p = result. into_primitive ( ) . into_u32 ( ) ;
239+ assert_eq ! ( p. as_ref( ) , & [ 0u32 , 0 , 10 ] ) ;
240+ assert_eq ! (
241+ * p. validity( ) ,
242+ Mask :: from( BitBuffer :: from( vec![ false , true , true ] ) )
243+ ) ;
244+ }
245+
246+ #[ test]
247+ fn cast_all_null_vector ( ) {
248+ let vec: PVector < i32 > = PVector :: new ( buffer ! [ -1i32 , -2 , -3 ] , Mask :: new_false ( 3 ) ) ;
249+
250+ // Cast to nullable u32 should succeed because all values are masked.
251+ let result = vec
252+ . cast ( & DType :: Primitive ( PType :: U32 , Nullability :: Nullable ) )
253+ . unwrap ( ) ;
254+ let p = result. into_primitive ( ) . into_u32 ( ) ;
255+ assert_eq ! ( p. as_ref( ) , & [ 0u32 , 0 , 0 ] ) ;
256+ assert ! ( p. validity( ) . all_false( ) ) ;
257+ }
258+
259+ #[ rstest]
260+ #[ case( 42i32 , PType :: U32 ) ]
261+ #[ case( 0i32 , PType :: U8 ) ]
262+ #[ case( 255i32 , PType :: U8 ) ]
263+ #[ case( 100i32 , PType :: F64 ) ]
264+ fn cast_scalar_valid ( #[ case] value : i32 , #[ case] target : PType ) {
265+ let scalar: PScalar < i32 > = PScalar :: new ( Some ( value) ) ;
266+ let result = scalar. cast ( & target. into ( ) ) . unwrap ( ) ;
267+ assert ! ( result. as_primitive( ) . is_valid( ) ) ;
268+ }
269+
270+ #[ test]
271+ fn cast_scalar_i32_u32_overflow ( ) {
272+ let scalar: PScalar < i32 > = PScalar :: new ( Some ( -1 ) ) ;
273+ let error = scalar. cast ( & PType :: U32 . into ( ) ) . err ( ) . unwrap ( ) ;
274+ let VortexError :: ComputeError ( s, _) = error else {
275+ unreachable ! ( )
276+ } ;
277+ assert_eq ! ( s. to_string( ) , "Failed to cast -1 to U32" ) ;
278+ }
279+
280+ #[ test]
281+ fn cast_scalar_null ( ) {
282+ let scalar: PScalar < i32 > = PScalar :: null ( ) ;
283+ let result = scalar
284+ . cast ( & DType :: Primitive ( PType :: U32 , Nullability :: Nullable ) )
285+ . unwrap ( ) ;
286+ let p = result. into_primitive ( ) . into_u32 ( ) ;
287+ assert_eq ! ( p. value( ) , None ) ;
288+ }
289+
290+ #[ test]
291+ fn cast_scalar_u32_f64 ( ) {
292+ let scalar: PScalar < u32 > = PScalar :: new ( Some ( 12345 ) ) ;
293+ let result = scalar. cast ( & PType :: F64 . into ( ) ) . unwrap ( ) ;
294+ let p = result. into_primitive ( ) . into_f64 ( ) ;
295+ assert_eq ! ( p. value( ) , Some ( 12345.0f64 ) ) ;
296+ }
297+ }
0 commit comments