Skip to content

Commit 6e2a2a0

Browse files
feat[array]: list contains scalar (#5713)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 1f61d99 commit 6e2a2a0

File tree

8 files changed

+38
-16
lines changed

8 files changed

+38
-16
lines changed

vortex-array/src/arrays/scalar_fn/kernel.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ impl Kernel for ScalarFnKernel {
5757
return_dtype: self.return_dtype,
5858
};
5959

60-
Ok(self.scalar_fn.execute(args)?.ensure_vector(self.row_count))
60+
Ok(self
61+
.scalar_fn
62+
.execute(args)?
63+
.unwrap_into_vector(self.row_count))
6164
}
6265

6366
fn push_down_filter(self: Box<Self>, selection: &Mask) -> VortexResult<PushDownResult> {

vortex-array/src/arrays/scalar_fn/vtable/canonical.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ impl CanonicalVTable<ScalarFnVTable> for ScalarFnVTable {
3737
.scalar_fn
3838
.execute(ctx)
3939
.vortex_expect("Canonicalize should be fallible")
40-
.ensure_vector(len);
40+
.unwrap_into_vector(len);
4141

4242
result_vector.into_array(&array.dtype).to_canonical()
4343
}

vortex-array/src/executor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,6 @@ impl VectorExecutor for ArrayRef {
5555

5656
fn execute_vector(&self, session: &VortexSession) -> VortexResult<Vector> {
5757
let len = self.len();
58-
Ok(self.execute_datum(session)?.ensure_vector(len))
58+
Ok(self.execute_datum(session)?.unwrap_into_vector(len))
5959
}
6060
}

vortex-array/src/expr/exprs/binary.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,14 @@ impl VTable for Binary {
137137
match op {
138138
Operator::And => {
139139
// FIXME(ngates): implement logical compute over datums
140-
let lhs = lhs.ensure_vector(args.row_count).into_bool().into_arrow()?;
141-
let rhs = rhs.ensure_vector(args.row_count).into_bool().into_arrow()?;
140+
let lhs = lhs
141+
.unwrap_into_vector(args.row_count)
142+
.into_bool()
143+
.into_arrow()?;
144+
let rhs = rhs
145+
.unwrap_into_vector(args.row_count)
146+
.into_bool()
147+
.into_arrow()?;
142148
return Ok(Datum::Vector(
143149
arrow_arith::boolean::and_kleene(&lhs, &rhs)?
144150
.into_vector()?
@@ -147,8 +153,14 @@ impl VTable for Binary {
147153
}
148154
Operator::Or => {
149155
// FIXME(ngates): implement logical compute over datums
150-
let lhs = lhs.ensure_vector(args.row_count).into_bool().into_arrow()?;
151-
let rhs = rhs.ensure_vector(args.row_count).into_bool().into_arrow()?;
156+
let lhs = lhs
157+
.unwrap_into_vector(args.row_count)
158+
.into_bool()
159+
.into_arrow()?;
160+
let rhs = rhs
161+
.unwrap_into_vector(args.row_count)
162+
.into_bool()
163+
.into_arrow()?;
152164
return Ok(Datum::Vector(
153165
arrow_arith::boolean::or_kleene(&lhs, &rhs)?
154166
.into_vector()?

vortex-array/src/expr/exprs/list_contains.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ use vortex_error::vortex_err;
2020
use vortex_mask::Mask;
2121
use vortex_vector::BoolDatum;
2222
use vortex_vector::Datum;
23+
use vortex_vector::ScalarOps;
2324
use vortex_vector::Vector;
25+
use vortex_vector::VectorMutOps;
2426
use vortex_vector::VectorOps;
2527
use vortex_vector::bool::BoolVector;
2628
use vortex_vector::listview::ListViewScalar;
@@ -128,14 +130,19 @@ impl VTable for ListContains {
128130

129131
let matches = match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) {
130132
(true, true) => {
131-
todo!("Implement ListContains for two scalars")
133+
let list = lhs.into_scalar().vortex_expect("scalar").into_list();
134+
let needle = rhs.into_scalar().vortex_expect("scalar");
135+
// Convert the needle scalar to a single-element vector and reuse
136+
// constant_list_scalar_contains
137+
let needle_vector = needle.repeat(1).freeze();
138+
constant_list_scalar_contains(list, needle_vector)
132139
}
133140
(true, false) => constant_list_scalar_contains(
134141
lhs.into_scalar().vortex_expect("scalar").into_list(),
135142
rhs.into_vector().vortex_expect("vector"),
136143
),
137144
(false, true) => list_contains_scalar(
138-
lhs.ensure_vector(args.row_count).into_list(),
145+
lhs.unwrap_into_vector(args.row_count).into_list(),
139146
rhs.into_scalar().vortex_expect("scalar").into_list(),
140147
),
141148
(false, false) => {
@@ -234,7 +241,7 @@ fn list_contains_scalar(list: ListViewVector, value: ListViewScalar) -> VortexRe
234241
row_count: elems.len(),
235242
return_dtype: DType::Bool(Nullability::Nullable),
236243
})?
237-
.ensure_vector(elems.len())
244+
.unwrap_into_vector(elems.len())
238245
.into_bool()
239246
.into_bits();
240247

@@ -311,7 +318,7 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex
311318
})?
312319
.into_bool();
313320
let compared = Datum::from(compared)
314-
.ensure_vector(values.len())
321+
.unwrap_into_vector(values.len())
315322
.into_bool();
316323

317324
result = LogicalOr::or(&result, &compared);

vortex-array/src/expr/exprs/pack.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ impl VTable for Pack {
154154
let fields: Box<[_]> = args
155155
.datums
156156
.into_iter()
157-
.map(|v| v.ensure_vector(args.row_count))
157+
.map(|v| v.unwrap_into_vector(args.row_count))
158158
.collect();
159159
return Ok(Datum::Vector(
160160
StructVector::try_new(Arc::new(fields), Mask::new_true(args.row_count))?.into(),

vortex-vector/src/datum.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ impl From<Vector> for Datum {
5757

5858
impl Datum {
5959
/// Ensure the datum is a vector by repeating the scalar value if necessary.
60-
pub fn ensure_vector(self, len: usize) -> Vector {
60+
pub fn unwrap_into_vector(self, len: usize) -> Vector {
6161
match self {
6262
Datum::Scalar(scalar) => scalar.repeat(len).freeze(),
6363
Datum::Vector(vector) => {

vortex-vector/src/decimal/generic.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,11 @@ impl<D: NativeDecimalType> DVector<D> {
9797
}
9898

9999
// We assert that each element is within bounds for the given precision/scale.
100-
if !elements.iter().all(|e| ps.is_valid(*e)) {
100+
if let Some(invalid) = elements.iter().find(|e| !ps.is_valid(**e)) {
101101
vortex_bail!(
102-
"One or more elements are out of bounds for precision {} and scale {}",
102+
"One or more elements (e.g. {invalid}) are out of bounds for precision {} and scale {}",
103103
ps.precision(),
104-
ps.scale()
104+
ps.scale(),
105105
);
106106
}
107107

0 commit comments

Comments
 (0)