@@ -14,10 +14,10 @@ use arrow_array::{
1414use arrow_buffer:: { ScalarBuffer , i256} ;
1515use arrow_schema:: { DataType , Field , FieldRef , Fields } ;
1616use itertools:: Itertools ;
17- use num_traits:: AsPrimitive ;
17+ use num_traits:: { AsPrimitive , ToPrimitive } ;
1818use vortex_buffer:: Buffer ;
1919use vortex_dtype:: { DType , NativePType , PType } ;
20- use vortex_error:: { VortexExpect , VortexResult , vortex_bail} ;
20+ use vortex_error:: { VortexExpect , VortexResult , vortex_bail, vortex_err } ;
2121use vortex_scalar:: DecimalValueType ;
2222
2323use crate :: arrays:: {
@@ -104,8 +104,34 @@ impl Kernel for ToArrowCanonical {
104104 {
105105 to_arrow_primitive :: < Float64Type > ( array)
106106 }
107- ( Canonical :: Decimal ( array) , DataType :: Decimal128 ( ..) ) => to_arrow_decimal128 ( array) ,
108- ( Canonical :: Decimal ( array) , DataType :: Decimal256 ( ..) ) => to_arrow_decimal256 ( array) ,
107+ ( Canonical :: Decimal ( array) , DataType :: Decimal128 ( precision, scale) ) => {
108+ if array. decimal_dtype ( ) . precision ( ) != * precision
109+ || array. decimal_dtype ( ) . scale ( ) != * scale
110+ {
111+ vortex_bail ! (
112+ "ToArrowCanonical: target precision/scale {}/{} does not match array precision/scale {}/{}" ,
113+ precision,
114+ scale,
115+ array. decimal_dtype( ) . precision( ) ,
116+ array. decimal_dtype( ) . scale( )
117+ ) ;
118+ }
119+ to_arrow_decimal128 ( array)
120+ }
121+ ( Canonical :: Decimal ( array) , DataType :: Decimal256 ( precision, scale) ) => {
122+ if array. decimal_dtype ( ) . precision ( ) != * precision
123+ || array. decimal_dtype ( ) . scale ( ) != * scale
124+ {
125+ vortex_bail ! (
126+ "ToArrowCanonical: target precision/scale {}/{} does not match array precision/scale {}/{}" ,
127+ precision,
128+ scale,
129+ array. decimal_dtype( ) . precision( ) ,
130+ array. decimal_dtype( ) . scale( )
131+ ) ;
132+ }
133+ to_arrow_decimal256 ( array)
134+ }
109135 ( Canonical :: Struct ( array) , DataType :: Struct ( fields) ) => {
110136 to_arrow_struct ( array, fields. as_ref ( ) )
111137 }
@@ -188,9 +214,14 @@ fn to_arrow_decimal128(array: DecimalArray) -> VortexResult<ArrowArrayRef> {
188214 DecimalValueType :: I32 => array. buffer :: < i32 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
189215 DecimalValueType :: I64 => array. buffer :: < i64 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
190216 DecimalValueType :: I128 => array. buffer :: < i128 > ( ) ,
191- DecimalValueType :: I256 => {
192- vortex_bail ! ( "i256 decimals cannot be converted to Arrow i128 decimal" )
193- }
217+ DecimalValueType :: I256 => array
218+ . buffer :: < vortex_scalar:: i256 > ( )
219+ . into_iter ( )
220+ . map ( |x| {
221+ x. to_i128 ( )
222+ . ok_or_else ( || vortex_err ! ( "i256 to i128 narrowing cannot be done safely" ) )
223+ } )
224+ . try_collect ( ) ?,
194225 _ => vortex_bail ! ( "unknown value type {:?}" , array. values_type( ) ) ,
195226 } ;
196227 Ok ( Arc :: new (
@@ -206,10 +237,14 @@ fn to_arrow_decimal256(array: DecimalArray) -> VortexResult<ArrowArrayRef> {
206237 let null_buffer = array. validity_mask ( ) ?. to_null_buffer ( ) ;
207238 let buffer: Buffer < i256 > = match array. values_type ( ) {
208239 DecimalValueType :: I8 => array. buffer :: < i8 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
209- DecimalValueType :: I16 => array. buffer :: < i8 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
210- DecimalValueType :: I32 => array. buffer :: < i8 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
211- DecimalValueType :: I64 => array. buffer :: < i8 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
212- DecimalValueType :: I128 => array. buffer :: < i8 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
240+ DecimalValueType :: I16 => array. buffer :: < i16 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
241+ DecimalValueType :: I32 => array. buffer :: < i32 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
242+ DecimalValueType :: I64 => array. buffer :: < i64 > ( ) . into_iter ( ) . map ( |x| x. as_ ( ) ) . collect ( ) ,
243+ DecimalValueType :: I128 => array
244+ . buffer :: < i128 > ( )
245+ . into_iter ( )
246+ . map ( |x| vortex_scalar:: i256:: from_i128 ( x) . into ( ) )
247+ . collect ( ) ,
213248 DecimalValueType :: I256 => Buffer :: < i256 > :: from_byte_buffer ( array. byte_buffer ( ) ) ,
214249 _ => vortex_bail ! ( "unknown type {:?}" , array. values_type( ) ) ,
215250 } ;
@@ -334,15 +369,19 @@ where
334369
335370#[ cfg( test) ]
336371mod tests {
337- use arrow_array:: Decimal128Array ;
372+ use arrow_array:: { Array , Decimal128Array , Decimal256Array } ;
373+ use arrow_buffer:: i256;
338374 use arrow_schema:: { DataType , Field } ;
375+ use rstest:: rstest;
339376 use vortex_buffer:: buffer;
340377 use vortex_dtype:: { DecimalDType , FieldNames } ;
378+ use vortex_scalar:: NativeDecimalType ;
341379
342380 use crate :: IntoArray ;
343381 use crate :: arrays:: { DecimalArray , PrimitiveArray , StructArray } ;
344382 use crate :: arrow:: IntoArrowArray ;
345383 use crate :: arrow:: compute:: to_arrow;
384+ use crate :: builders:: { ArrayBuilder , DecimalBuilder } ;
346385 use crate :: validity:: Validity ;
347386
348387 #[ test]
@@ -398,4 +437,54 @@ mod tests {
398437
399438 assert ! ( struct_a. into_array( ) . into_arrow( & arrow_dt) . is_err( ) ) ;
400439 }
440+
441+ #[ rstest]
442+ #[ case( 0i8 ) ]
443+ #[ case( 0i16 ) ]
444+ #[ case( 0i32 ) ]
445+ #[ case( 0i64 ) ]
446+ #[ case( 0i128 ) ]
447+ #[ case( vortex_scalar:: i256:: ZERO ) ]
448+ fn to_arrow_decimal128 < T : NativeDecimalType > ( #[ case] _decimal_type : T ) {
449+ let mut decimal = DecimalBuilder :: new :: < T > ( 2 , 1 , false . into ( ) ) ;
450+ decimal. append_value ( 10 ) ;
451+ decimal. append_value ( 11 ) ;
452+ decimal. append_value ( 12 ) ;
453+
454+ let decimal = decimal. finish ( ) ;
455+
456+ let arrow_array = decimal. into_arrow ( & DataType :: Decimal128 ( 2 , 1 ) ) . unwrap ( ) ;
457+ let arrow_decimal = arrow_array
458+ . as_any ( )
459+ . downcast_ref :: < Decimal128Array > ( )
460+ . unwrap ( ) ;
461+ assert_eq ! ( arrow_decimal. value( 0 ) , 10 ) ;
462+ assert_eq ! ( arrow_decimal. value( 1 ) , 11 ) ;
463+ assert_eq ! ( arrow_decimal. value( 2 ) , 12 ) ;
464+ }
465+
466+ #[ rstest]
467+ #[ case( 0i8 ) ]
468+ #[ case( 0i16 ) ]
469+ #[ case( 0i32 ) ]
470+ #[ case( 0i64 ) ]
471+ #[ case( 0i128 ) ]
472+ #[ case( vortex_scalar:: i256:: ZERO ) ]
473+ fn to_arrow_decimal256 < T : NativeDecimalType > ( #[ case] _decimal_type : T ) {
474+ let mut decimal = DecimalBuilder :: new :: < T > ( 2 , 1 , false . into ( ) ) ;
475+ decimal. append_value ( 10 ) ;
476+ decimal. append_value ( 11 ) ;
477+ decimal. append_value ( 12 ) ;
478+
479+ let decimal = decimal. finish ( ) ;
480+
481+ let arrow_array = decimal. into_arrow ( & DataType :: Decimal256 ( 2 , 1 ) ) . unwrap ( ) ;
482+ let arrow_decimal = arrow_array
483+ . as_any ( )
484+ . downcast_ref :: < Decimal256Array > ( )
485+ . unwrap ( ) ;
486+ assert_eq ! ( arrow_decimal. value( 0 ) , i256:: from_i128( 10 ) ) ;
487+ assert_eq ! ( arrow_decimal. value( 1 ) , i256:: from_i128( 11 ) ) ;
488+ assert_eq ! ( arrow_decimal. value( 2 ) , i256:: from_i128( 12 ) ) ;
489+ }
401490}
0 commit comments