11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use std:: iter;
45use std:: ops:: Deref ;
56
67use num_traits:: AsPrimitive ;
78use vortex_buffer:: Buffer ;
89use vortex_dtype:: match_each_integer_ptype;
910use vortex_error:: VortexResult ;
11+ use vortex_mask:: AllOr ;
12+ use vortex_mask:: Mask ;
1013use vortex_vector:: binaryview:: BinaryView ;
1114
1215use crate :: Array ;
@@ -23,16 +26,16 @@ use crate::vtable::ValidityHelper;
2326/// Take involves creating a new array that references the old array, just with the given set of views.
2427impl TakeKernel for VarBinViewVTable {
2528 fn take ( & self , array : & VarBinViewArray , indices : & dyn Array ) -> VortexResult < ArrayRef > {
26- // Compute the new validity
27-
28- // This is valid since all elements (of all arrays) even null values must be inside
29- // min-max valid range.
29+ // Compute the new validity.
3030 let validity = array. validity ( ) . take ( indices) ?;
3131 let indices = indices. to_primitive ( ) ;
3232
3333 let views_buffer = match_each_integer_ptype ! ( indices. ptype( ) , |I | {
34- // This is valid since all elements even null values are inside the min-max valid range.
35- take_views( array. views( ) , indices. as_slice:: <I >( ) )
34+ take_views(
35+ array. views( ) ,
36+ indices. as_slice:: <I >( ) ,
37+ & indices. validity_mask( ) ,
38+ )
3639 } ) ;
3740
3841 // SAFETY: taking all components at same indices maintains invariants
@@ -55,15 +58,36 @@ register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
5558fn take_views < I : AsPrimitive < usize > > (
5659 views : & Buffer < BinaryView > ,
5760 indices : & [ I ] ,
61+ mask : & Mask ,
5862) -> Buffer < BinaryView > {
5963 // NOTE(ngates): this deref is not actually trivial, so we run it once.
6064 let views_ref = views. deref ( ) ;
61- Buffer :: < BinaryView > :: from_trusted_len_iter ( indices. iter ( ) . map ( |i| views_ref[ i. as_ ( ) ] ) )
65+ // We do not use iter_bools directly, since the resulting dyn iterator cannot
66+ // implement TrustedLen.
67+ match mask. bit_buffer ( ) {
68+ AllOr :: All => {
69+ Buffer :: < BinaryView > :: from_trusted_len_iter ( indices. iter ( ) . map ( |i| views_ref[ i. as_ ( ) ] ) )
70+ }
71+ AllOr :: None => Buffer :: < BinaryView > :: from_trusted_len_iter ( iter:: repeat_n (
72+ BinaryView :: default ( ) ,
73+ indices. len ( ) ,
74+ ) ) ,
75+ AllOr :: Some ( buffer) => Buffer :: < BinaryView > :: from_trusted_len_iter (
76+ buffer. iter ( ) . zip ( indices. iter ( ) ) . map ( |( valid, idx) | {
77+ if valid {
78+ views_ref[ idx. as_ ( ) ]
79+ } else {
80+ BinaryView :: default ( )
81+ }
82+ } ) ,
83+ ) ,
84+ }
6285}
6386
6487#[ cfg( test) ]
6588mod tests {
6689 use rstest:: rstest;
90+ use vortex_buffer:: BitBuffer ;
6791 use vortex_buffer:: buffer;
6892 use vortex_dtype:: DType ;
6993 use vortex_dtype:: Nullability :: NonNullable ;
@@ -76,6 +100,7 @@ mod tests {
76100 use crate :: canonical:: ToCanonical ;
77101 use crate :: compute:: conformance:: take:: test_take_conformance;
78102 use crate :: compute:: take;
103+ use crate :: validity:: Validity ;
79104
80105 #[ test]
81106 fn take_nullable ( ) {
@@ -103,11 +128,13 @@ mod tests {
103128 fn take_nullable_indices ( ) {
104129 let arr = VarBinViewArray :: from_iter ( [ "one" , "two" ] . map ( Some ) , DType :: Utf8 ( NonNullable ) ) ;
105130
106- let taken = take (
107- arr. as_ref ( ) ,
108- PrimitiveArray :: from_option_iter ( vec ! [ Some ( 1 ) , None ] ) . as_ref ( ) ,
109- )
110- . unwrap ( ) ;
131+ let indices = PrimitiveArray :: new (
132+ // Verify that garbage values at NULL indices are ignored.
133+ buffer ! [ 1u64 , 999 ] ,
134+ Validity :: from ( BitBuffer :: from ( vec ! [ true , false ] ) ) ,
135+ ) ;
136+
137+ let taken = take ( arr. as_ref ( ) , indices. as_ref ( ) ) . unwrap ( ) ;
111138
112139 assert ! ( taken. dtype( ) . is_nullable( ) ) ;
113140 assert_eq ! (
0 commit comments