@@ -11,13 +11,22 @@ impl TakeFn<&ConstantArray> for ConstantEncoding {
1111 fn take ( & self , array : & ConstantArray , indices : & dyn Array ) -> VortexResult < ArrayRef > {
1212 match indices. validity_mask ( ) ?. boolean_buffer ( ) {
1313 AllOr :: All => {
14- Ok ( ConstantArray :: new ( array. scalar ( ) . clone ( ) , indices. len ( ) ) . into_array ( ) )
14+ let nullability = array. dtype ( ) . nullability ( ) | indices. dtype ( ) . nullability ( ) ;
15+ let scalar = Scalar :: new (
16+ array. scalar ( ) . dtype ( ) . with_nullability ( nullability) ,
17+ array. scalar ( ) . value ( ) . clone ( ) ,
18+ ) ;
19+ Ok ( ConstantArray :: new ( scalar, indices. len ( ) ) . into_array ( ) )
20+ }
21+ AllOr :: None => {
22+ Ok ( ConstantArray :: new (
23+ Scalar :: null ( array. dtype ( ) . with_nullability (
24+ array. dtype ( ) . nullability ( ) | indices. dtype ( ) . nullability ( ) ,
25+ ) ) ,
26+ indices. len ( ) ,
27+ )
28+ . into_array ( ) )
1529 }
16- AllOr :: None => Ok ( ConstantArray :: new (
17- Scalar :: null ( array. dtype ( ) . clone ( ) ) ,
18- indices. len ( ) ,
19- )
20- . into_array ( ) ) ,
2130 AllOr :: Some ( v) => {
2231 let arr = ConstantArray :: new ( array. scalar ( ) . clone ( ) , indices. len ( ) ) . into_array ( ) ;
2332
@@ -38,6 +47,7 @@ impl TakeFn<&ConstantArray> for ConstantEncoding {
3847#[ cfg( test) ]
3948mod tests {
4049 use vortex_buffer:: buffer;
50+ use vortex_dtype:: Nullability ;
4151 use vortex_mask:: AllOr ;
4252
4353 use crate :: arrays:: { ConstantArray , PrimitiveArray } ;
@@ -58,6 +68,10 @@ mod tests {
5868 )
5969 . unwrap ( ) ;
6070 let valid_indices: & [ usize ] = & [ 1usize ] ;
71+ assert_eq ! (
72+ & array. dtype( ) . with_nullability( Nullability :: Nullable ) ,
73+ taken. dtype( )
74+ ) ;
6175 assert_eq ! (
6276 taken. to_primitive( ) . unwrap( ) . as_slice:: <i32 >( ) ,
6377 & [ 42 , 42 , 42 ]
@@ -67,4 +81,23 @@ mod tests {
6781 AllOr :: Some ( valid_indices)
6882 ) ;
6983 }
84+
85+ #[ test]
86+ fn take_all_valid_indices ( ) {
87+ let array = ConstantArray :: new ( 42 , 10 ) . to_array ( ) ;
88+ let taken = take (
89+ & array,
90+ & PrimitiveArray :: new ( buffer ! [ 0 , 5 , 7 ] , Validity :: AllValid ) . into_array ( ) ,
91+ )
92+ . unwrap ( ) ;
93+ assert_eq ! (
94+ & array. dtype( ) . with_nullability( Nullability :: Nullable ) ,
95+ taken. dtype( )
96+ ) ;
97+ assert_eq ! (
98+ taken. to_primitive( ) . unwrap( ) . as_slice:: <i32 >( ) ,
99+ & [ 42 , 42 , 42 ]
100+ ) ;
101+ assert_eq ! ( taken. validity_mask( ) . unwrap( ) . indices( ) , AllOr :: All ) ;
102+ }
70103}
0 commit comments