@@ -20,13 +20,14 @@ use vortex_error::vortex_err;
2020use vortex_mask:: Mask ;
2121use vortex_vector:: BoolDatum ;
2222use vortex_vector:: Datum ;
23- use vortex_vector:: ScalarOps ;
2423use vortex_vector:: Vector ;
25- use vortex_vector:: VectorMutOps ;
2624use vortex_vector:: VectorOps ;
25+ use vortex_vector:: bool:: BoolScalar ;
2726use vortex_vector:: bool:: BoolVector ;
2827use vortex_vector:: listview:: ListViewScalar ;
2928use vortex_vector:: listview:: ListViewVector ;
29+ use vortex_vector:: match_each_pvector;
30+ use vortex_vector:: primitive:: PScalar ;
3031use vortex_vector:: primitive:: PVector ;
3132
3233use crate :: ArrayRef ;
@@ -128,30 +129,34 @@ impl VTable for ListContains {
128129 . try_into ( )
129130 . map_err ( |_| vortex_err ! ( "Wrong number of arguments for ListContains expression" ) ) ?;
130131
131- let matches = match ( lhs. as_scalar ( ) . is_some ( ) , rhs. as_scalar ( ) . is_some ( ) ) {
132+ match ( lhs. as_scalar ( ) . is_some ( ) , rhs. as_scalar ( ) . is_some ( ) ) {
132133 ( true , true ) => {
134+ // Early return with Scalar to avoid allocating BitBuffer.
133135 let list = lhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ;
134136 let needle = rhs. into_scalar ( ) . vortex_expect ( "scalar" ) ;
135- // Convert the needle scalar to a vector with row_count
136- // elements and reuse constant_list_scalar_contains
137- let needle_vector = needle. repeat ( args. row_count ) . freeze ( ) ;
138- constant_list_scalar_contains ( list, needle_vector) ?
137+ let found = list_contains_scalar_scalar ( & list, & needle) ?;
138+ Ok ( Datum :: Scalar ( BoolScalar :: new ( Some ( found) ) . into ( ) ) )
139+ }
140+ ( true , false ) => {
141+ let matches = constant_list_scalar_contains (
142+ lhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
143+ rhs. into_vector ( ) . vortex_expect ( "vector" ) ,
144+ ) ?;
145+ Ok ( Datum :: Vector ( matches. into ( ) ) )
146+ }
147+ ( false , true ) => {
148+ let matches = list_contains_scalar (
149+ lhs. unwrap_into_vector ( args. row_count ) . into_list ( ) ,
150+ rhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
151+ ) ?;
152+ Ok ( Datum :: Vector ( matches. into ( ) ) )
139153 }
140- ( true , false ) => constant_list_scalar_contains (
141- lhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
142- rhs. into_vector ( ) . vortex_expect ( "vector" ) ,
143- ) ?,
144- ( false , true ) => list_contains_scalar (
145- lhs. unwrap_into_vector ( args. row_count ) . into_list ( ) ,
146- rhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
147- ) ?,
148154 ( false , false ) => {
149155 vortex_bail ! (
150156 "ListContains currently only supports constant needle (RHS) or constant list (LHS)"
151157 )
152158 }
153- } ;
154- Ok ( Datum :: Vector ( matches. into ( ) ) )
159+ }
155160 }
156161
157162 fn stat_falsification (
@@ -330,6 +335,31 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex
330335 Ok ( result)
331336}
332337
338+ /// Used when both needle and list are scalars.
339+ fn list_contains_scalar_scalar (
340+ list : & ListViewScalar ,
341+ needle : & vortex_vector:: Scalar ,
342+ ) -> VortexResult < bool > {
343+ let elements = list. value ( ) . elements ( ) ;
344+
345+ // Downcast to `PVector` and access slice directly to avoid `scalar_at` overhead.
346+ let found = if let Vector :: Primitive ( prim) = & * * elements {
347+ match_each_pvector ! ( prim, |pvec| {
348+ let slice: & [ _] = pvec. as_ref( ) ;
349+ let validity = pvec. validity( ) ;
350+ slice
351+ . iter( )
352+ . enumerate( )
353+ . any( |( i, & elem) | needle == & PScalar :: new( Some ( elem) ) . into( ) && validity. value( i) )
354+ } )
355+ } else {
356+ // Fallback for non-primitive vectors
357+ ( 0 ..elements. len ( ) ) . any ( |i| needle == & elements. scalar_at ( i) )
358+ } ;
359+
360+ Ok ( found)
361+ }
362+
333363/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
334364/// [`BoolArray`] of matches on the child elements array.
335365///
@@ -366,6 +396,7 @@ where
366396mod tests {
367397 use std:: sync:: Arc ;
368398
399+ use rstest:: rstest;
369400 use vortex_buffer:: BitBuffer ;
370401 use vortex_dtype:: DType ;
371402 use vortex_dtype:: Field ;
@@ -556,4 +587,57 @@ mod tests {
556587 let expr2 = list_contains ( root ( ) , lit ( 42 ) ) ;
557588 assert_eq ! ( expr2. to_string( ) , "contains($, 42i32)" ) ;
558589 }
590+
591+ #[ rstest]
592+ #[ case( vec![ 1i32 , 2i32 , 3i32 ] , 1i32 , true , "first_element" ) ]
593+ #[ case( vec![ 1i32 , 2i32 , 3i32 ] , 2i32 , true , "middle_element" ) ]
594+ #[ case( vec![ 1i32 , 2i32 , 3i32 ] , 3i32 , true , "last_element" ) ]
595+ fn test_scalar_scalar_found (
596+ #[ case] list_values : Vec < i32 > ,
597+ #[ case] needle : i32 ,
598+ #[ case] expected : bool ,
599+ #[ case] _description : & str ,
600+ ) {
601+ let expr = list_contains (
602+ lit ( Scalar :: list (
603+ Arc :: new ( DType :: Primitive ( I32 , Nullability :: NonNullable ) ) ,
604+ list_values
605+ . into_iter ( )
606+ . map ( |v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
607+ . collect ( ) ,
608+ Nullability :: NonNullable ,
609+ ) ) ,
610+ lit ( needle) ,
611+ ) ;
612+ let arr = test_array ( ) ;
613+ let result = expr. evaluate ( & arr) . unwrap ( ) ;
614+ assert_eq ! (
615+ result. scalar_at( 0 ) ,
616+ Scalar :: bool ( expected, Nullability :: Nullable )
617+ ) ;
618+ }
619+
620+ #[ rstest]
621+ #[ case( 0i32 , false , "empty_list" ) ]
622+ #[ case( 1i32 , false , "empty_list_seek_one" ) ]
623+ #[ case( 100i32 , false , "empty_list_seek_large" ) ]
624+ fn test_scalar_scalar_not_found (
625+ #[ case] needle : i32 ,
626+ #[ case] expected : bool ,
627+ #[ case] _description : & str ,
628+ ) {
629+ let expr = list_contains (
630+ lit ( Scalar :: list_empty (
631+ Arc :: new ( DType :: Primitive ( I32 , Nullability :: NonNullable ) ) ,
632+ Nullability :: NonNullable ,
633+ ) ) ,
634+ lit ( needle) ,
635+ ) ;
636+ let arr = test_array ( ) ;
637+ let result = expr. evaluate ( & arr) . unwrap ( ) ;
638+ assert_eq ! (
639+ result. scalar_at( 0 ) ,
640+ Scalar :: bool ( expected, Nullability :: Nullable )
641+ ) ;
642+ }
559643}
0 commit comments