@@ -20,13 +20,14 @@ use vortex_error::vortex_err;
2020use vortex_mask:: Mask ;
2121use vortex_vector:: BoolDatum ;
2222use vortex_vector:: Datum ;
23- use vortex_vector:: ScalarOps ;
2423use vortex_vector:: Vector ;
25- use vortex_vector:: VectorMutOps ;
2624use vortex_vector:: VectorOps ;
25+ use vortex_vector:: bool:: BoolScalar ;
2726use vortex_vector:: bool:: BoolVector ;
2827use vortex_vector:: listview:: ListViewScalar ;
2928use vortex_vector:: listview:: ListViewVector ;
29+ use vortex_vector:: match_each_pvector;
30+ use vortex_vector:: primitive:: PScalar ;
3031use vortex_vector:: primitive:: PVector ;
3132
3233use crate :: ArrayRef ;
@@ -128,30 +129,34 @@ impl VTable for ListContains {
128129 . try_into ( )
129130 . map_err ( |_| vortex_err ! ( "Wrong number of arguments for ListContains expression" ) ) ?;
130131
131- let matches = match ( lhs. as_scalar ( ) . is_some ( ) , rhs. as_scalar ( ) . is_some ( ) ) {
132+ match ( lhs. as_scalar ( ) . is_some ( ) , rhs. as_scalar ( ) . is_some ( ) ) {
132133 ( true , true ) => {
134+ // Early return with Scalar to avoid allocating BitBuffer.
133135 let list = lhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ;
134136 let needle = rhs. into_scalar ( ) . vortex_expect ( "scalar" ) ;
135- // Convert the needle scalar to a vector with row_count
136- // elements and reuse constant_list_scalar_contains
137- let needle_vector = needle. repeat ( args. row_count ) . freeze ( ) ;
138- constant_list_scalar_contains ( list, needle_vector) ?
137+ let found = list_contains_scalar_scalar ( & list, & needle) ?;
138+ Ok ( Datum :: Scalar ( BoolScalar :: new ( Some ( found) ) . into ( ) ) )
139+ }
140+ ( true , false ) => {
141+ let matches = constant_list_scalar_contains (
142+ lhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
143+ rhs. into_vector ( ) . vortex_expect ( "vector" ) ,
144+ ) ?;
145+ Ok ( Datum :: Vector ( matches. into ( ) ) )
146+ }
147+ ( false , true ) => {
148+ let matches = list_contains_scalar (
149+ lhs. unwrap_into_vector ( args. row_count ) . into_list ( ) ,
150+ rhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
151+ ) ?;
152+ Ok ( Datum :: Vector ( matches. into ( ) ) )
139153 }
140- ( true , false ) => constant_list_scalar_contains (
141- lhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
142- rhs. into_vector ( ) . vortex_expect ( "vector" ) ,
143- ) ?,
144- ( false , true ) => list_contains_scalar (
145- lhs. unwrap_into_vector ( args. row_count ) . into_list ( ) ,
146- rhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
147- ) ?,
148154 ( false , false ) => {
149155 vortex_bail ! (
150156 "ListContains currently only supports constant needle (RHS) or constant list (LHS)"
151157 )
152158 }
153- } ;
154- Ok ( Datum :: Vector ( matches. into ( ) ) )
159+ }
155160 }
156161
157162 fn stat_falsification (
@@ -330,6 +335,32 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex
330335 Ok ( result)
331336}
332337
338+ /// Used when both needle and list are scalars.
339+ fn list_contains_scalar_scalar (
340+ list : & ListViewScalar ,
341+ needle : & vortex_vector:: Scalar ,
342+ ) -> VortexResult < bool > {
343+ assert ! ( false ) ;
344+ let elements = list. value ( ) . elements ( ) ;
345+
346+ // Downcast to `PVector` and access slice directly to avoid `scalar_at` overhead.
347+ let found = if let Vector :: Primitive ( prim) = & * * elements {
348+ match_each_pvector ! ( prim, |pvec| {
349+ let slice: & [ _] = pvec. as_ref( ) ;
350+ let validity = pvec. validity( ) ;
351+ slice
352+ . iter( )
353+ . enumerate( )
354+ . any( |( i, & elem) | needle == & PScalar :: new( Some ( elem) ) . into( ) && validity. value( i) )
355+ } )
356+ } else {
357+ // Fallback for non-primitive vectors
358+ ( 0 ..elements. len ( ) ) . any ( |i| needle == & elements. scalar_at ( i) )
359+ } ;
360+
361+ Ok ( found)
362+ }
363+
333364/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
334365/// [`BoolArray`] of matches on the child elements array.
335366///
@@ -366,6 +397,7 @@ where
366397mod tests {
367398 use std:: sync:: Arc ;
368399
400+ use rstest:: rstest;
369401 use vortex_buffer:: BitBuffer ;
370402 use vortex_dtype:: DType ;
371403 use vortex_dtype:: Field ;
@@ -556,4 +588,57 @@ mod tests {
556588 let expr2 = list_contains ( root ( ) , lit ( 42 ) ) ;
557589 assert_eq ! ( expr2. to_string( ) , "contains($, 42i32)" ) ;
558590 }
591+
592+ #[ rstest]
593+ #[ case( vec![ 1i32 , 2i32 , 3i32 ] , 1i32 , true , "first_element" ) ]
594+ #[ case( vec![ 1i32 , 2i32 , 3i32 ] , 2i32 , true , "middle_element" ) ]
595+ #[ case( vec![ 1i32 , 2i32 , 3i32 ] , 3i32 , true , "last_element" ) ]
596+ fn test_scalar_scalar_found (
597+ #[ case] list_values : Vec < i32 > ,
598+ #[ case] needle : i32 ,
599+ #[ case] expected : bool ,
600+ #[ case] _description : & str ,
601+ ) {
602+ let expr = list_contains (
603+ lit ( Scalar :: list (
604+ Arc :: new ( DType :: Primitive ( I32 , Nullability :: NonNullable ) ) ,
605+ list_values
606+ . into_iter ( )
607+ . map ( |v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
608+ . collect ( ) ,
609+ Nullability :: NonNullable ,
610+ ) ) ,
611+ lit ( needle) ,
612+ ) ;
613+ let arr = test_array ( ) ;
614+ let result = expr. evaluate ( & arr) . unwrap ( ) ;
615+ assert_eq ! (
616+ result. scalar_at( 0 ) ,
617+ Scalar :: bool ( expected, Nullability :: Nullable )
618+ ) ;
619+ }
620+
621+ #[ rstest]
622+ #[ case( 0i32 , false , "empty_list" ) ]
623+ #[ case( 1i32 , false , "empty_list_seek_one" ) ]
624+ #[ case( 100i32 , false , "empty_list_seek_large" ) ]
625+ fn test_scalar_scalar_not_found (
626+ #[ case] needle : i32 ,
627+ #[ case] expected : bool ,
628+ #[ case] _description : & str ,
629+ ) {
630+ let expr = list_contains (
631+ lit ( Scalar :: list_empty (
632+ Arc :: new ( DType :: Primitive ( I32 , Nullability :: NonNullable ) ) ,
633+ Nullability :: NonNullable ,
634+ ) ) ,
635+ lit ( needle) ,
636+ ) ;
637+ let arr = test_array ( ) ;
638+ let result = expr. evaluate ( & arr) . unwrap ( ) ;
639+ assert_eq ! (
640+ result. scalar_at( 0 ) ,
641+ Scalar :: bool ( expected, Nullability :: Nullable )
642+ ) ;
643+ }
559644}
0 commit comments