@@ -13,17 +13,15 @@ use vortex_dtype::IntegerPType;
1313use vortex_dtype:: Nullability ;
1414use vortex_dtype:: PTypeDowncastExt ;
1515use vortex_dtype:: match_each_integer_ptype;
16- use vortex_error:: VortexExpect ;
1716use vortex_error:: VortexResult ;
1817use vortex_error:: vortex_bail;
1918use vortex_error:: vortex_err;
2019use vortex_mask:: Mask ;
2120use vortex_vector:: BoolDatum ;
2221use vortex_vector:: Datum ;
23- use vortex_vector:: ScalarOps ;
2422use vortex_vector:: Vector ;
25- use vortex_vector:: VectorMutOps ;
2623use vortex_vector:: VectorOps ;
24+ use vortex_vector:: bool:: BoolScalar ;
2725use vortex_vector:: bool:: BoolVector ;
2826use vortex_vector:: listview:: ListViewScalar ;
2927use vortex_vector:: listview:: ListViewVector ;
@@ -128,30 +126,28 @@ impl VTable for ListContains {
128126 . try_into ( )
129127 . map_err ( |_| vortex_err ! ( "Wrong number of arguments for ListContains expression" ) ) ?;
130128
131- let matches = match ( lhs. as_scalar ( ) . is_some ( ) , rhs. as_scalar ( ) . is_some ( ) ) {
132- ( true , true ) => {
133- let list = lhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ;
134- let needle = rhs. into_scalar ( ) . vortex_expect ( "scalar" ) ;
135- // Convert the needle scalar to a vector with row_count
136- // elements and reuse constant_list_scalar_contains
137- let needle_vector = needle. repeat ( args. row_count ) . freeze ( ) ;
138- constant_list_scalar_contains ( list, needle_vector) ?
129+ match ( lhs, rhs) {
130+ ( Datum :: Scalar ( list_scalar) , Datum :: Scalar ( needle_scalar) ) => {
131+ let list = list_scalar. into_list ( ) ;
132+ let found = list_contains_scalar_scalar ( & list, & needle_scalar) ?;
133+ Ok ( Datum :: Scalar ( BoolScalar :: new ( Some ( found) ) . into ( ) ) )
139134 }
140- ( true , false ) => constant_list_scalar_contains (
141- lhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
142- rhs. into_vector ( ) . vortex_expect ( "vector" ) ,
143- ) ?,
144- ( false , true ) => list_contains_scalar (
145- lhs. unwrap_into_vector ( args. row_count ) . into_list ( ) ,
146- rhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
147- ) ?,
148- ( false , false ) => {
135+ ( Datum :: Scalar ( list_scalar) , Datum :: Vector ( needle_vector) ) => {
136+ let matches =
137+ constant_list_scalar_contains ( list_scalar. into_list ( ) , needle_vector) ?;
138+ Ok ( Datum :: Vector ( matches. into ( ) ) )
139+ }
140+ ( Datum :: Vector ( list_vector) , Datum :: Scalar ( needle_scalar) ) => {
141+ let matches =
142+ list_contains_scalar ( list_vector. into_list ( ) , needle_scalar. into_list ( ) ) ?;
143+ Ok ( Datum :: Vector ( matches. into ( ) ) )
144+ }
145+ ( Datum :: Vector ( _) , Datum :: Vector ( _) ) => {
149146 vortex_bail ! (
150147 "ListContains currently only supports constant needle (RHS) or constant list (LHS)"
151148 )
152149 }
153- } ;
154- Ok ( Datum :: Vector ( matches. into ( ) ) )
150+ }
155151 }
156152
157153 fn stat_falsification (
@@ -330,6 +326,35 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex
330326 Ok ( result)
331327}
332328
329+ /// Used when the needle is a scalar checked for containment in a single list.
330+ fn list_contains_scalar_scalar (
331+ list : & ListViewScalar ,
332+ needle : & vortex_vector:: Scalar ,
333+ ) -> VortexResult < bool > {
334+ let elements = list. value ( ) . elements ( ) ;
335+
336+ // Note: If the comparison becomes a bottleneck, look into faster ways to check for list
337+ // containment. `execute` allocates the returned vector on the heap. Further, the `eq`
338+ // comparison does not short-circuit on the first match found.
339+ let found = Binary
340+ . bind ( operators:: Operator :: Eq )
341+ . execute ( ExecutionArgs {
342+ datums : vec ! [
343+ Datum :: Vector ( elements. deref( ) . clone( ) ) ,
344+ Datum :: Scalar ( needle. clone( ) ) ,
345+ ] ,
346+ dtypes : vec ! [ ] ,
347+ row_count : elements. len ( ) ,
348+ return_dtype : DType :: Bool ( Nullability :: Nullable ) ,
349+ } ) ?
350+ . unwrap_into_vector ( elements. len ( ) )
351+ . into_bool ( )
352+ . into_bits ( ) ;
353+
354+ let mut true_bits = BitIndexIterator :: new ( found. inner ( ) . as_ref ( ) , 0 , found. len ( ) ) ;
355+ Ok ( true_bits. next ( ) . is_some ( ) )
356+ }
357+
333358/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
334359/// [`BoolArray`] of matches on the child elements array.
335360///
0 commit comments