@@ -46,7 +46,7 @@ use crate::{Array, ArrayRef, IntoArray, ToCanonical};
4646/// assert_eq!(to_vec, vec![false, true, false]);
4747/// ```
4848pub fn list_contains ( array : & dyn Array , value : Scalar ) -> VortexResult < ArrayRef > {
49- let DType :: List ( elem_dtype, _nullability ) = array. dtype ( ) else {
49+ let DType :: List ( elem_dtype, nullability ) = array. dtype ( ) else {
5050 vortex_bail ! ( "Array must be of List type" ) ;
5151 } ;
5252 if & * * elem_dtype != value. dtype ( ) {
@@ -68,20 +68,45 @@ pub fn list_contains(array: &dyn Array, value: Scalar) -> VortexResult<ArrayRef>
6868 }
6969
7070 let elems = list_array. elements ( ) ;
71- let ends = list_array. offsets ( ) . to_primitive ( ) ?;
71+ if elems. is_empty ( ) {
72+ // Must return false when a list is empty (but valid), or null when the list itself is null.
73+ return list_false_or_null ( & list_array) ;
74+ }
7275
7376 let rhs = ConstantArray :: new ( value, elems. len ( ) ) ;
7477 let matching_elements = compare ( elems, rhs. as_ref ( ) , Operator :: Eq ) ?;
7578 let matches = matching_elements. to_bool ( ) ?;
7679
7780 // Fast path: no elements match.
7881 if let Some ( pred) = matches. as_constant ( ) {
79- if matches ! ( pred. as_bool( ) . value( ) , None | Some ( false ) ) {
80- // TODO(aduffy): how do we handle null?
81- return Ok ( ConstantArray :: new :: < bool > ( false , list_array. len ( ) ) . into_array ( ) ) ;
82- }
82+ return match pred. as_bool ( ) . value ( ) {
83+ // All comparisons are invalid (result in `null`), and search is not null because
84+ // we already checked for null above.
85+ None => {
86+ assert ! (
87+ !rhs. scalar( ) . is_null( ) ,
88+ "Search value must not be null here"
89+ ) ;
90+ // False, unless the list itself is null in which case we return null.
91+ list_false_or_null ( & list_array)
92+ }
93+ // No elements match, and all comparisons are valid (result in `false`).
94+ Some ( false ) => {
95+ // False, but match the nullability to the input list array.
96+ Ok (
97+ ConstantArray :: new ( Scalar :: bool ( false , * nullability) , list_array. len ( ) )
98+ . into_array ( ) ,
99+ )
100+ }
101+ // All elements match, and all comparisons are valid (result in `true`).
102+ Some ( true ) => {
103+ // True, unless the list itself is empty or NULL.
104+ list_is_not_empty ( & list_array)
105+ }
106+ } ;
83107 }
84108
109+ let ends = list_array. offsets ( ) . to_primitive ( ) ?;
85110 match_each_integer_ptype ! ( ends. ptype( ) , |T | {
86111 Ok ( reduce_with_ends(
87112 ends. as_slice:: <T >( ) ,
@@ -99,28 +124,15 @@ fn list_contains_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
99124 // Check element validity. We need to intersect
100125 match elems. validity_mask ( ) ? {
101126 // No NULL elements
102- Mask :: AllTrue ( _) => match list_array. validity ( ) {
103- Validity :: NonNullable => {
104- Ok ( ConstantArray :: new :: < bool > ( false , list_array. len ( ) ) . into_array ( ) )
105- }
106- Validity :: AllValid => Ok ( ConstantArray :: new (
107- Scalar :: bool ( true , Nullability :: Nullable ) ,
108- list_array. len ( ) ,
109- )
110- . into_array ( ) ) ,
111- Validity :: AllInvalid => Ok ( ConstantArray :: new (
112- Scalar :: null ( DType :: Bool ( Nullability :: Nullable ) ) ,
113- list_array. len ( ) ,
114- )
115- . into_array ( ) ) ,
116- Validity :: Array ( list_mask) => {
117- // Create a new bool array with false, and the provided nulls
118- let buffer = BooleanBuffer :: new_unset ( list_array. len ( ) ) ;
119- Ok ( BoolArray :: new ( buffer, Validity :: Array ( list_mask. clone ( ) ) ) . into_array ( ) )
120- }
121- } ,
122- // All null elements
123- Mask :: AllFalse ( _) => Ok ( ConstantArray :: new :: < bool > ( true , list_array. len ( ) ) . into_array ( ) ) ,
127+ Mask :: AllTrue ( _) => {
128+ // False, unless the list itself is NULL.
129+ list_false_or_null ( list_array)
130+ }
131+ // All NULL elements.
132+ Mask :: AllFalse ( _) => {
133+ // True, unless the list itself is empty or NULL.
134+ list_is_not_empty ( list_array)
135+ }
124136 Mask :: Values ( mask) => {
125137 let nulls = invert ( & mask. into_array ( ) ) ?. to_bool ( ) ?;
126138 let ends = list_array. offsets ( ) . to_primitive ( ) ?;
@@ -135,6 +147,58 @@ fn list_contains_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
135147 }
136148}
137149
150+ /// Returns a `Bool` array with `false` for lists that are valid,
151+ /// or `NULL` if the list itself is null.
152+ fn list_false_or_null ( list_array : & ListArray ) -> VortexResult < ArrayRef > {
153+ match list_array. validity ( ) {
154+ Validity :: NonNullable => {
155+ // All false.
156+ Ok ( ConstantArray :: new :: < bool > ( false , list_array. len ( ) ) . into_array ( ) )
157+ }
158+ Validity :: AllValid => {
159+ // All false, but nullable.
160+ Ok (
161+ ConstantArray :: new ( Scalar :: bool ( false , Nullability :: Nullable ) , list_array. len ( ) )
162+ . into_array ( ) ,
163+ )
164+ }
165+ Validity :: AllInvalid => {
166+ // All nulls, must be nullable result.
167+ Ok ( ConstantArray :: new (
168+ Scalar :: null ( DType :: Bool ( Nullability :: Nullable ) ) ,
169+ list_array. len ( ) ,
170+ )
171+ . into_array ( ) )
172+ }
173+ Validity :: Array ( validity_array) => {
174+ // Create a new bool array with false, and the provided nulls
175+ let buffer = BooleanBuffer :: new_unset ( list_array. len ( ) ) ;
176+ Ok ( BoolArray :: new ( buffer, Validity :: Array ( validity_array. clone ( ) ) ) . into_array ( ) )
177+ }
178+ }
179+ }
180+
181+ /// Returns a `Bool` array with `true` for lists which are NOT empty, or `false` if they are empty,
182+ /// or `NULL` if the list itself is null.
183+ fn list_is_not_empty ( list_array : & ListArray ) -> VortexResult < ArrayRef > {
184+ // Short-circuit for all invalid.
185+ if matches ! ( list_array. validity( ) , Validity :: AllInvalid ) {
186+ return Ok ( ConstantArray :: new (
187+ Scalar :: null ( DType :: Bool ( Nullability :: Nullable ) ) ,
188+ list_array. len ( ) ,
189+ )
190+ . into_array ( ) ) ;
191+ }
192+
193+ let offsets = list_array. offsets ( ) . to_primitive ( ) ?;
194+ let buffer = match_each_integer_ptype ! ( offsets. ptype( ) , |T | {
195+ element_is_not_empty( offsets. as_slice:: <T >( ) )
196+ } ) ;
197+
198+ // Copy over the validity mask from the input.
199+ Ok ( BoolArray :: new ( buffer, list_array. validity ( ) . clone ( ) ) . into_array ( ) )
200+ }
201+
138202// Reduce the boolean values into a Mask that indicates which elements in the
139203// ListArray contain the matching value.
140204fn reduce_with_ends < T : NativePType + AsPrimitive < usize > > (
@@ -203,6 +267,10 @@ fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
203267 . collect ( )
204268}
205269
270+ fn element_is_not_empty < T : NativePType > ( values : & [ T ] ) -> BooleanBuffer {
271+ BooleanBuffer :: from_iter ( values. windows ( 2 ) . map ( |window| window[ 1 ] != window[ 0 ] ) )
272+ }
273+
206274#[ cfg( test) ]
207275mod tests {
208276 use std:: sync:: Arc ;
@@ -285,6 +353,18 @@ mod tests {
285353 Some ( "a" ) ,
286354 bool_array( vec![ false , false , false ] , None )
287355 ) ]
356+ // Case 6: list(utf8?) with empty + NULL elements and NULL search
357+ #[ case(
358+ null_strings( vec![ vec![ ] , vec![ None , None ] , vec![ None , None , None ] ] ) ,
359+ None ,
360+ bool_array( vec![ false , true , true ] , None )
361+ ) ]
362+ // Case 7: list(utf8?) with empty + NULL elements and search scalar
363+ #[ case(
364+ null_strings( vec![ vec![ ] , vec![ None , None ] , vec![ None , None , None ] ] ) ,
365+ Some ( "a" ) ,
366+ bool_array( vec![ false , false , false ] , None )
367+ ) ]
288368 fn test_contains_nullable (
289369 #[ case] list_array : ArrayRef ,
290370 #[ case] value : Option < & str > ,
@@ -328,4 +408,25 @@ mod tests {
328408 vec![ true , true ] ,
329409 ) ;
330410 }
411+
/// A constant array of null lists must produce a constant, all-null result.
#[test]
fn test_all_nulls() {
    // Five null lists of non-nullable i32 elements.
    let elem_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
    let null_list = Scalar::null(DType::List(elem_dtype, Nullability::Nullable));
    let list_array = ConstantArray::new(null_list, 5).into_array();

    let contains = list_contains(&list_array, 2i32.into()).unwrap();

    // The result should stay constant-encoded and keep the input length.
    assert!(contains.is::<ConstantVTable>(), "Expected constant result");
    assert_eq!(contains.len(), 5);

    // Every entry is null because every input list is null.
    let as_bool = contains.to_bool().unwrap();
    assert_eq!(as_bool.validity(), &Validity::AllInvalid);
}
331432}
0 commit comments