Skip to content

Commit 7698e3c

Browse files
committed
perf: list contains scalar scalar
Signed-off-by: Alexander Droste <[email protected]>
1 parent fb976a1 commit 7698e3c

File tree

1 file changed

+101
-17
lines changed

1 file changed

+101
-17
lines changed

vortex-array/src/expr/exprs/list_contains.rs

Lines changed: 101 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ use vortex_error::vortex_err;
2020
use vortex_mask::Mask;
2121
use vortex_vector::BoolDatum;
2222
use vortex_vector::Datum;
23-
use vortex_vector::ScalarOps;
2423
use vortex_vector::Vector;
25-
use vortex_vector::VectorMutOps;
2624
use vortex_vector::VectorOps;
25+
use vortex_vector::bool::BoolScalar;
2726
use vortex_vector::bool::BoolVector;
2827
use vortex_vector::listview::ListViewScalar;
2928
use vortex_vector::listview::ListViewVector;
29+
use vortex_vector::match_each_pvector;
30+
use vortex_vector::primitive::PScalar;
3031
use vortex_vector::primitive::PVector;
3132

3233
use crate::ArrayRef;
@@ -128,30 +129,34 @@ impl VTable for ListContains {
128129
.try_into()
129130
.map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;
130131

131-
let matches = match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) {
132+
match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) {
132133
(true, true) => {
134+
// Early return with Scalar to avoid allocating BitBuffer.
133135
let list = lhs.into_scalar().vortex_expect("scalar").into_list();
134136
let needle = rhs.into_scalar().vortex_expect("scalar");
135-
// Convert the needle scalar to a vector with row_count
136-
// elements and reuse constant_list_scalar_contains
137-
let needle_vector = needle.repeat(args.row_count).freeze();
138-
constant_list_scalar_contains(list, needle_vector)?
137+
let found = list_contains_scalar_scalar(&list, &needle)?;
138+
Ok(Datum::Scalar(BoolScalar::new(Some(found)).into()))
139+
}
140+
(true, false) => {
141+
let matches = constant_list_scalar_contains(
142+
lhs.into_scalar().vortex_expect("scalar").into_list(),
143+
rhs.into_vector().vortex_expect("vector"),
144+
)?;
145+
Ok(Datum::Vector(matches.into()))
146+
}
147+
(false, true) => {
148+
let matches = list_contains_scalar(
149+
lhs.unwrap_into_vector(args.row_count).into_list(),
150+
rhs.into_scalar().vortex_expect("scalar").into_list(),
151+
)?;
152+
Ok(Datum::Vector(matches.into()))
139153
}
140-
(true, false) => constant_list_scalar_contains(
141-
lhs.into_scalar().vortex_expect("scalar").into_list(),
142-
rhs.into_vector().vortex_expect("vector"),
143-
)?,
144-
(false, true) => list_contains_scalar(
145-
lhs.unwrap_into_vector(args.row_count).into_list(),
146-
rhs.into_scalar().vortex_expect("scalar").into_list(),
147-
)?,
148154
(false, false) => {
149155
vortex_bail!(
150156
"ListContains currently only supports constant needle (RHS) or constant list (LHS)"
151157
)
152158
}
153-
};
154-
Ok(Datum::Vector(matches.into()))
159+
}
155160
}
156161

157162
fn stat_falsification(
@@ -330,6 +335,31 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex
330335
Ok(result)
331336
}
332337

338+
/// Used when both needle and list are scalars.
339+
fn list_contains_scalar_scalar(
340+
list: &ListViewScalar,
341+
needle: &vortex_vector::Scalar,
342+
) -> VortexResult<bool> {
343+
let elements = list.value().elements();
344+
345+
// Downcast to `PVector` and access slice directly to avoid `scalar_at` overhead.
346+
let found = if let Vector::Primitive(prim) = &**elements {
347+
match_each_pvector!(prim, |pvec| {
348+
let slice: &[_] = pvec.as_ref();
349+
let validity = pvec.validity();
350+
slice
351+
.iter()
352+
.enumerate()
353+
.any(|(i, &elem)| needle == &PScalar::new(Some(elem)).into() && validity.value(i))
354+
})
355+
} else {
356+
// Fallback for non-primitive vectors
357+
(0..elements.len()).any(|i| needle == &elements.scalar_at(i))
358+
};
359+
360+
Ok(found)
361+
}
362+
333363
/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
334364
/// [`BoolArray`] of matches on the child elements array.
335365
///
@@ -366,6 +396,7 @@ where
366396
mod tests {
367397
use std::sync::Arc;
368398

399+
use rstest::rstest;
369400
use vortex_buffer::BitBuffer;
370401
use vortex_dtype::DType;
371402
use vortex_dtype::Field;
@@ -556,4 +587,57 @@ mod tests {
556587
let expr2 = list_contains(root(), lit(42));
557588
assert_eq!(expr2.to_string(), "contains($, 42i32)");
558589
}
590+
591+
#[rstest]
592+
#[case(vec![1i32, 2i32, 3i32], 1i32, true, "first_element")]
593+
#[case(vec![1i32, 2i32, 3i32], 2i32, true, "middle_element")]
594+
#[case(vec![1i32, 2i32, 3i32], 3i32, true, "last_element")]
595+
fn test_scalar_scalar_found(
596+
#[case] list_values: Vec<i32>,
597+
#[case] needle: i32,
598+
#[case] expected: bool,
599+
#[case] _description: &str,
600+
) {
601+
let expr = list_contains(
602+
lit(Scalar::list(
603+
Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
604+
list_values
605+
.into_iter()
606+
.map(|v| Scalar::primitive(v, Nullability::NonNullable))
607+
.collect(),
608+
Nullability::NonNullable,
609+
)),
610+
lit(needle),
611+
);
612+
let arr = test_array();
613+
let result = expr.evaluate(&arr).unwrap();
614+
assert_eq!(
615+
result.scalar_at(0),
616+
Scalar::bool(expected, Nullability::Nullable)
617+
);
618+
}
619+
620+
#[rstest]
621+
#[case(0i32, false, "empty_list")]
622+
#[case(1i32, false, "empty_list_seek_one")]
623+
#[case(100i32, false, "empty_list_seek_large")]
624+
fn test_scalar_scalar_not_found(
625+
#[case] needle: i32,
626+
#[case] expected: bool,
627+
#[case] _description: &str,
628+
) {
629+
let expr = list_contains(
630+
lit(Scalar::list_empty(
631+
Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
632+
Nullability::NonNullable,
633+
)),
634+
lit(needle),
635+
);
636+
let arr = test_array();
637+
let result = expr.evaluate(&arr).unwrap();
638+
assert_eq!(
639+
result.scalar_at(0),
640+
Scalar::bool(expected, Nullability::Nullable)
641+
);
642+
}
559643
}

0 commit comments

Comments
 (0)