@@ -7,15 +7,22 @@ use std::ops::Deref;
77
88use arrow_buffer:: bit_iterator:: BitIndexIterator ;
99use vortex_buffer:: BitBuffer ;
10+ use vortex_compute:: comparison:: Compare ;
11+ use vortex_compute:: comparison:: Equal ;
12+ use vortex_compute:: logical:: LogicalOr ;
1013use vortex_dtype:: DType ;
1114use vortex_dtype:: IntegerPType ;
1215use vortex_dtype:: Nullability ;
1316use vortex_dtype:: PTypeDowncastExt ;
1417use vortex_dtype:: match_each_integer_ptype;
18+ use vortex_error:: VortexExpect ;
1519use vortex_error:: VortexResult ;
1620use vortex_error:: vortex_bail;
1721use vortex_error:: vortex_err;
22+ use vortex_mask:: Mask ;
23+ use vortex_vector:: BoolDatum ;
1824use vortex_vector:: Datum ;
25+ use vortex_vector:: Vector ;
1926use vortex_vector:: VectorOps ;
2027use vortex_vector:: bool:: BoolVector ;
2128use vortex_vector:: listview:: ListViewScalar ;
@@ -121,15 +128,26 @@ impl VTable for ListContains {
121128 . try_into ( )
122129 . map_err ( |_| vortex_err ! ( "Wrong number of arguments for ListContains expression" ) ) ?;
123130
124- let rhs = rhs
125- . into_scalar ( )
126- . ok_or_else ( || vortex_err ! ( "Only supports constant RHS" ) ) ?;
131+ let matches = match ( lhs. as_scalar ( ) . is_some ( ) , rhs. as_scalar ( ) . is_some ( ) ) {
132+ ( true , true ) => {
133+ todo ! ( "Implement ListContains for two scalars" )
134+ }
135+ ( true , false ) => constant_list_scalar_contains (
136+ lhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
137+ rhs. into_vector ( ) . vortex_expect ( "vector" ) ,
138+ ) ,
139+ ( false , true ) => list_contains_scalar (
140+ lhs. ensure_vector ( args. row_count ) . into_list ( ) ,
141+ rhs. into_scalar ( ) . vortex_expect ( "scalar" ) . into_list ( ) ,
142+ ) ,
143+ ( false , false ) => {
144+ vortex_bail ! (
145+ "ListContains currently only supports constant needle (RHS) or constant list (LHS)"
146+ )
147+ }
148+ } ?;
127149
128- let result = list_contains_scalar (
129- lhs. ensure_vector ( args. row_count ) . into_list ( ) ,
130- rhs. into_list ( ) ,
131- ) ?;
132- Ok ( Datum :: Vector ( result. into ( ) ) )
150+ Ok ( Datum :: Vector ( matches. into ( ) ) )
133151 }
134152
135153 fn stat_falsification (
@@ -190,7 +208,8 @@ pub fn list_contains(list: Expression, value: Expression) -> Expression {
190208 ListContains . new_expr ( EmptyOptions , [ list, value] )
191209}
192210
193- /// Returns a [`BoolArray`] where each bit represents if a list contains the scalar.
211+ /// Returns a [`BoolVector`] where each bit represents if a list contains the scalar.
212+ // FIXME(ngates): test implementation and move to vortex-compute
194213fn list_contains_scalar ( list : ListViewVector , value : ListViewScalar ) -> VortexResult < BoolVector > {
195214 // If the list array is constant, we perform a single comparison.
196215 // if list.len() > 1 && list.is_constant() {
@@ -269,6 +288,30 @@ fn list_contains_scalar(list: ListViewVector, value: ListViewScalar) -> VortexRe
269288 Ok ( BoolVector :: new ( list_matches, list. validity ( ) . clone ( ) ) )
270289}
271290
291+ // Then there is a constant list scalar (haystack) being compared to an array of needles.
292+ // FIXME(ngates): test implementation and move to vortex-compute
293+ fn constant_list_scalar_contains ( list : ListViewScalar , values : Vector ) -> VortexResult < BoolVector > {
294+ let elements = list. value ( ) . elements ( ) ;
295+
296+ // For each element in the list, we perform a full comparison over the values and OR
297+ // the results together.
298+ let mut result: BoolVector = BoolVector :: new (
299+ BitBuffer :: new_unset ( values. len ( ) ) ,
300+ Mask :: new ( values. len ( ) , false ) ,
301+ ) ;
302+ for i in 0 ..elements. len ( ) {
303+ let element = Datum :: Scalar ( elements. scalar_at ( i) ) ;
304+ let compared: BoolDatum = Compare :: < Equal > :: compare ( Datum :: Vector ( values. clone ( ) ) , element) ;
305+ let compared = Datum :: from ( compared)
306+ . ensure_vector ( values. len ( ) )
307+ . into_bool ( ) ;
308+
309+ result = LogicalOr :: or ( result, & compared) ;
310+ }
311+
312+ Ok ( result)
313+ }
314+
272315/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
273316/// [`BoolArray`] of matches on the child elements array.
274317///
0 commit comments