11use vortex_array:: compute:: { scalar_at, ScalarAtFn } ;
2+ use vortex_array:: ArrayDType ;
23use vortex_error:: VortexResult ;
34use vortex_scalar:: Scalar ;
45
@@ -26,11 +27,11 @@ impl ScalarAtFn<ALPRDArray> for ALPRDEncoding {
2627 if array. is_f32 ( ) {
2728 let right: u32 = scalar_at ( array. right_parts ( ) , index) ?. try_into ( ) ?;
2829 let packed = f32:: from_bits ( ( left as u32 ) << array. right_bit_width ( ) | right) ;
29- Ok ( packed. into ( ) )
30+ Ok ( Scalar :: primitive ( packed, array . dtype ( ) . nullability ( ) ) )
3031 } else {
3132 let right: u64 = scalar_at ( array. right_parts ( ) , index) ?. try_into ( ) ?;
3233 let packed = f64:: from_bits ( ( ( left as u64 ) << array. right_bit_width ( ) ) | right) ;
33- Ok ( packed. into ( ) )
34+ Ok ( Scalar :: primitive ( packed, array . dtype ( ) . nullability ( ) ) )
3435 }
3536 }
3637}
@@ -40,6 +41,7 @@ mod test {
4041 use rstest:: rstest;
4142 use vortex_array:: array:: PrimitiveArray ;
4243 use vortex_array:: compute:: scalar_at;
44+ use vortex_dtype:: Nullability ;
4345 use vortex_scalar:: Scalar ;
4446
4547 use crate :: { ALPRDFloat , RDEncoder } ;
@@ -65,4 +67,32 @@ mod test {
6567 // The right value hits the left_part_exceptions
6668 assert_eq ! ( scalar_at( encoded. as_ref( ) , 2 ) . unwrap( ) , outlier. into( ) ) ;
6769 }
70+
71+ #[ test]
72+ fn nullable_scalar_at ( ) {
73+ let a = 0.1f64 ;
74+ let b = 0.2f64 ;
75+ let outlier = 3e100f64 ;
76+ let array = PrimitiveArray :: from_option_iter ( [ Some ( a) , Some ( b) , Some ( outlier) ] ) ;
77+ let encoded = RDEncoder :: new ( & [ a, b] ) . encode ( & array) ;
78+
79+ // Make sure that we're testing the exception pathway.
80+ assert ! ( encoded. left_parts_patches( ) . is_some( ) ) ;
81+
82+ // The first two values need no patching
83+ assert_eq ! (
84+ scalar_at( encoded. as_ref( ) , 0 ) . unwrap( ) ,
85+ Scalar :: primitive( a, Nullability :: Nullable )
86+ ) ;
87+ assert_eq ! (
88+ scalar_at( encoded. as_ref( ) , 1 ) . unwrap( ) ,
89+ Scalar :: primitive( b, Nullability :: Nullable )
90+ ) ;
91+
92+ // The right value hits the left_part_exceptions
93+ assert_eq ! (
94+ scalar_at( encoded. as_ref( ) , 2 ) . unwrap( ) ,
95+ Scalar :: primitive( outlier, Nullability :: Nullable )
96+ ) ;
97+ }
6898}
0 commit comments