Skip to content

Commit 443217d

Browse files
committed
perf: list contains scalar scalar
Signed-off-by: Alexander Droste <[email protected]>
1 parent fb976a1 commit 443217d

File tree

1 file changed

+102
-17
lines changed

1 file changed

+102
-17
lines changed

vortex-array/src/expr/exprs/list_contains.rs

Lines changed: 102 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ use vortex_error::vortex_err;
2020
use vortex_mask::Mask;
2121
use vortex_vector::BoolDatum;
2222
use vortex_vector::Datum;
23-
use vortex_vector::ScalarOps;
2423
use vortex_vector::Vector;
25-
use vortex_vector::VectorMutOps;
2624
use vortex_vector::VectorOps;
25+
use vortex_vector::bool::BoolScalar;
2726
use vortex_vector::bool::BoolVector;
2827
use vortex_vector::listview::ListViewScalar;
2928
use vortex_vector::listview::ListViewVector;
29+
use vortex_vector::match_each_pvector;
30+
use vortex_vector::primitive::PScalar;
3031
use vortex_vector::primitive::PVector;
3132

3233
use crate::ArrayRef;
@@ -128,30 +129,34 @@ impl VTable for ListContains {
128129
.try_into()
129130
.map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;
130131

131-
let matches = match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) {
132+
match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) {
132133
(true, true) => {
134+
// Early return with Scalar to avoid allocating BitBuffer.
133135
let list = lhs.into_scalar().vortex_expect("scalar").into_list();
134136
let needle = rhs.into_scalar().vortex_expect("scalar");
135-
// Convert the needle scalar to a vector with row_count
136-
// elements and reuse constant_list_scalar_contains
137-
let needle_vector = needle.repeat(args.row_count).freeze();
138-
constant_list_scalar_contains(list, needle_vector)?
137+
let found = list_contains_scalar_scalar(&list, &needle)?;
138+
Ok(Datum::Scalar(BoolScalar::new(Some(found)).into()))
139+
}
140+
(true, false) => {
141+
let matches = constant_list_scalar_contains(
142+
lhs.into_scalar().vortex_expect("scalar").into_list(),
143+
rhs.into_vector().vortex_expect("vector"),
144+
)?;
145+
Ok(Datum::Vector(matches.into()))
146+
}
147+
(false, true) => {
148+
let matches = list_contains_scalar(
149+
lhs.unwrap_into_vector(args.row_count).into_list(),
150+
rhs.into_scalar().vortex_expect("scalar").into_list(),
151+
)?;
152+
Ok(Datum::Vector(matches.into()))
139153
}
140-
(true, false) => constant_list_scalar_contains(
141-
lhs.into_scalar().vortex_expect("scalar").into_list(),
142-
rhs.into_vector().vortex_expect("vector"),
143-
)?,
144-
(false, true) => list_contains_scalar(
145-
lhs.unwrap_into_vector(args.row_count).into_list(),
146-
rhs.into_scalar().vortex_expect("scalar").into_list(),
147-
)?,
148154
(false, false) => {
149155
vortex_bail!(
150156
"ListContains currently only supports constant needle (RHS) or constant list (LHS)"
151157
)
152158
}
153-
};
154-
Ok(Datum::Vector(matches.into()))
159+
}
155160
}
156161

157162
fn stat_falsification(
@@ -330,6 +335,32 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex
330335
Ok(result)
331336
}
332337

338+
/// Used when both needle and list are scalars.
339+
fn list_contains_scalar_scalar(
340+
list: &ListViewScalar,
341+
needle: &vortex_vector::Scalar,
342+
) -> VortexResult<bool> {
343+
assert!(false);
344+
let elements = list.value().elements();
345+
346+
// Downcast to `PVector` and access slice directly to avoid `scalar_at` overhead.
347+
let found = if let Vector::Primitive(prim) = &**elements {
348+
match_each_pvector!(prim, |pvec| {
349+
let slice: &[_] = pvec.as_ref();
350+
let validity = pvec.validity();
351+
slice
352+
.iter()
353+
.enumerate()
354+
.any(|(i, &elem)| needle == &PScalar::new(Some(elem)).into() && validity.value(i))
355+
})
356+
} else {
357+
// Fallback for non-primitive vectors
358+
(0..elements.len()).any(|i| needle == &elements.scalar_at(i))
359+
};
360+
361+
Ok(found)
362+
}
363+
333364
/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
334365
/// [`BoolArray`] of matches on the child elements array.
335366
///
@@ -366,6 +397,7 @@ where
366397
mod tests {
367398
use std::sync::Arc;
368399

400+
use rstest::rstest;
369401
use vortex_buffer::BitBuffer;
370402
use vortex_dtype::DType;
371403
use vortex_dtype::Field;
@@ -556,4 +588,57 @@ mod tests {
556588
let expr2 = list_contains(root(), lit(42));
557589
assert_eq!(expr2.to_string(), "contains($, 42i32)");
558590
}
591+
592+
#[rstest]
593+
#[case(vec![1i32, 2i32, 3i32], 1i32, true, "first_element")]
594+
#[case(vec![1i32, 2i32, 3i32], 2i32, true, "middle_element")]
595+
#[case(vec![1i32, 2i32, 3i32], 3i32, true, "last_element")]
596+
fn test_scalar_scalar_found(
597+
#[case] list_values: Vec<i32>,
598+
#[case] needle: i32,
599+
#[case] expected: bool,
600+
#[case] _description: &str,
601+
) {
602+
let expr = list_contains(
603+
lit(Scalar::list(
604+
Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
605+
list_values
606+
.into_iter()
607+
.map(|v| Scalar::primitive(v, Nullability::NonNullable))
608+
.collect(),
609+
Nullability::NonNullable,
610+
)),
611+
lit(needle),
612+
);
613+
let arr = test_array();
614+
let result = expr.evaluate(&arr).unwrap();
615+
assert_eq!(
616+
result.scalar_at(0),
617+
Scalar::bool(expected, Nullability::Nullable)
618+
);
619+
}
620+
621+
#[rstest]
622+
#[case(0i32, false, "empty_list")]
623+
#[case(1i32, false, "empty_list_seek_one")]
624+
#[case(100i32, false, "empty_list_seek_large")]
625+
fn test_scalar_scalar_not_found(
626+
#[case] needle: i32,
627+
#[case] expected: bool,
628+
#[case] _description: &str,
629+
) {
630+
let expr = list_contains(
631+
lit(Scalar::list_empty(
632+
Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
633+
Nullability::NonNullable,
634+
)),
635+
lit(needle),
636+
);
637+
let arr = test_array();
638+
let result = expr.evaluate(&arr).unwrap();
639+
assert_eq!(
640+
result.scalar_at(0),
641+
Scalar::bool(expected, Nullability::Nullable)
642+
);
643+
}
559644
}

0 commit comments

Comments
 (0)