Skip to content

Commit b12219f

Browse files
authored
perf: list contains scalar scalar (#5792)
Signed-off-by: Alexander Droste <[email protected]>
1 parent b6984f2 commit b12219f

File tree

1 file changed

+47
-22
lines changed

1 file changed

+47
-22
lines changed

vortex-array/src/expr/exprs/list_contains.rs

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,15 @@ use vortex_dtype::IntegerPType;
1313
use vortex_dtype::Nullability;
1414
use vortex_dtype::PTypeDowncastExt;
1515
use vortex_dtype::match_each_integer_ptype;
16-
use vortex_error::VortexExpect;
1716
use vortex_error::VortexResult;
1817
use vortex_error::vortex_bail;
1918
use vortex_error::vortex_err;
2019
use vortex_mask::Mask;
2120
use vortex_vector::BoolDatum;
2221
use vortex_vector::Datum;
23-
use vortex_vector::ScalarOps;
2422
use vortex_vector::Vector;
25-
use vortex_vector::VectorMutOps;
2623
use vortex_vector::VectorOps;
24+
use vortex_vector::bool::BoolScalar;
2725
use vortex_vector::bool::BoolVector;
2826
use vortex_vector::listview::ListViewScalar;
2927
use vortex_vector::listview::ListViewVector;
@@ -128,30 +126,28 @@ impl VTable for ListContains {
128126
.try_into()
129127
.map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;
130128

131-
let matches = match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) {
132-
(true, true) => {
133-
let list = lhs.into_scalar().vortex_expect("scalar").into_list();
134-
let needle = rhs.into_scalar().vortex_expect("scalar");
135-
// Convert the needle scalar to a vector with row_count
136-
// elements and reuse constant_list_scalar_contains
137-
let needle_vector = needle.repeat(args.row_count).freeze();
138-
constant_list_scalar_contains(list, needle_vector)?
129+
match (lhs, rhs) {
130+
(Datum::Scalar(list_scalar), Datum::Scalar(needle_scalar)) => {
131+
let list = list_scalar.into_list();
132+
let found = list_contains_scalar_scalar(&list, &needle_scalar)?;
133+
Ok(Datum::Scalar(BoolScalar::new(Some(found)).into()))
139134
}
140-
(true, false) => constant_list_scalar_contains(
141-
lhs.into_scalar().vortex_expect("scalar").into_list(),
142-
rhs.into_vector().vortex_expect("vector"),
143-
)?,
144-
(false, true) => list_contains_scalar(
145-
lhs.unwrap_into_vector(args.row_count).into_list(),
146-
rhs.into_scalar().vortex_expect("scalar").into_list(),
147-
)?,
148-
(false, false) => {
135+
(Datum::Scalar(list_scalar), Datum::Vector(needle_vector)) => {
136+
let matches =
137+
constant_list_scalar_contains(list_scalar.into_list(), needle_vector)?;
138+
Ok(Datum::Vector(matches.into()))
139+
}
140+
(Datum::Vector(list_vector), Datum::Scalar(needle_scalar)) => {
141+
let matches =
142+
list_contains_scalar(list_vector.into_list(), needle_scalar.into_list())?;
143+
Ok(Datum::Vector(matches.into()))
144+
}
145+
(Datum::Vector(_), Datum::Vector(_)) => {
149146
vortex_bail!(
150147
"ListContains currently only supports constant needle (RHS) or constant list (LHS)"
151148
)
152149
}
153-
};
154-
Ok(Datum::Vector(matches.into()))
150+
}
155151
}
156152

157153
fn stat_falsification(
@@ -330,6 +326,35 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex
330326
Ok(result)
331327
}
332328

329+
/// Used when the needle is a scalar checked for containment in a single list.
330+
fn list_contains_scalar_scalar(
331+
list: &ListViewScalar,
332+
needle: &vortex_vector::Scalar,
333+
) -> VortexResult<bool> {
334+
let elements = list.value().elements();
335+
336+
// Note: If the comparison becomes a bottleneck, look into faster ways to check for list
337+
// containment. `execute` allocates the returned vector on the heap. Further, the `eq`
338+
// comparison does not short-circuit on the first match found.
339+
let found = Binary
340+
.bind(operators::Operator::Eq)
341+
.execute(ExecutionArgs {
342+
datums: vec![
343+
Datum::Vector(elements.deref().clone()),
344+
Datum::Scalar(needle.clone()),
345+
],
346+
dtypes: vec![],
347+
row_count: elements.len(),
348+
return_dtype: DType::Bool(Nullability::Nullable),
349+
})?
350+
.unwrap_into_vector(elements.len())
351+
.into_bool()
352+
.into_bits();
353+
354+
let mut true_bits = BitIndexIterator::new(found.inner().as_ref(), 0, found.len());
355+
Ok(true_bits.next().is_some())
356+
}
357+
333358
/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
334359
/// [`BoolArray`] of matches on the child elements array.
335360
///

0 commit comments

Comments
 (0)