Skip to content

Commit 93f8cb5

Browse files
authored
feat: eagerly compute IsConstant stat (#1838)
Re-attempt of #1492. Fixes a bug where compute functions that fall back to Arrow can erroneously return an array of length 1 if both inputs are constant.
1 parent bcea32a commit 93f8cb5

File tree

6 files changed

+155
-116
lines changed

6 files changed

+155
-116
lines changed

vortex-array/src/array/varbin/compute/compare.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use vortex_dtype::DType;
44
use vortex_error::{vortex_bail, VortexResult};
55

66
use crate::array::{VarBinArray, VarBinEncoding};
7-
use crate::arrow::{Datum, FromArrowArray};
7+
use crate::arrow::{from_arrow_array_with_len, Datum};
88
use crate::compute::{CompareFn, Operator};
9-
use crate::{ArrayDType, ArrayData, IntoArrayData};
9+
use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData};
1010

1111
// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
1212
impl CompareFn<VarBinArray> for VarBinEncoding {
@@ -18,8 +18,8 @@ impl CompareFn<VarBinArray> for VarBinEncoding {
1818
) -> VortexResult<Option<ArrayData>> {
1919
if let Some(rhs_const) = rhs.as_constant() {
2020
let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
21-
22-
let lhs = Datum::try_from(lhs.clone().into_array())?;
21+
let len = lhs.len();
22+
let lhs = unsafe { Datum::try_new(lhs.clone().into_array())? };
2323

2424
// TODO(robert): Handle LargeString/Binary arrays
2525
let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
@@ -48,7 +48,7 @@ impl CompareFn<VarBinArray> for VarBinEncoding {
4848
Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs)?,
4949
};
5050

51-
Ok(Some(ArrayData::from_arrow(&array, nullable)))
51+
Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
5252
} else {
5353
Ok(None)
5454
}

vortex-array/src/arrow/datum.rs

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use arrow_array::{Array, ArrayRef, Datum as ArrowDatum};
2-
use vortex_error::VortexError;
2+
use vortex_error::{vortex_panic, VortexResult};
33

4-
use crate::compute::slice;
5-
use crate::{ArrayData, IntoCanonical};
4+
use crate::array::ConstantArray;
5+
use crate::arrow::FromArrowArray;
6+
use crate::compute::{scalar_at, slice};
7+
use crate::{ArrayData, IntoArrayData, IntoCanonical};
68

79
/// A wrapper around a generic Arrow array that can be used as a Datum in Arrow compute.
810
#[derive(Debug)]
@@ -11,10 +13,18 @@ pub struct Datum {
1113
is_scalar: bool,
1214
}
1315

14-
impl TryFrom<ArrayData> for Datum {
15-
type Error = VortexError;
16-
17-
fn try_from(array: ArrayData) -> Result<Self, Self::Error> {
16+
impl Datum {
17+
/// Create a new [`Datum`] from an [`ArrayData`], which can then be passed to Arrow compute.
18+
/// This is unsafe because it does not preserve the length of the array.
19+
///
20+
/// # Safety
21+
/// The caller must ensure that the length of the array is preserved, and when processing
22+
/// the result of the Arrow compute, must check whether the result is a scalar (Arrow array of length 1),
23+
/// in which case it likely must be expanded to match the length of the original array.
24+
///
25+
/// The utility function [`from_arrow_array_with_len`] can be used to ensure that the length of the
26+
/// result of the Arrow compute matches the length of the original array.
27+
pub unsafe fn try_new(array: ArrayData) -> VortexResult<Self> {
1828
if array.is_constant() {
1929
Ok(Self {
2030
array: slice(array, 0, 1)?.into_arrow()?,
@@ -34,3 +44,29 @@ impl ArrowDatum for Datum {
3444
(&self.array, self.is_scalar)
3545
}
3646
}
47+
48+
/// Convert an Arrow array to an ArrayData with a specific length.
49+
/// This is useful for compute functions that delegate to Arrow using [Datum],
50+
/// which will return a scalar (length 1 Arrow array) if all of the input arrays are constant.
51+
///
52+
/// Panics if the length of the array is not 1 and also not equal to the expected length.
53+
pub fn from_arrow_array_with_len<A>(array: A, len: usize, nullable: bool) -> VortexResult<ArrayData>
54+
where
55+
ArrayData: FromArrowArray<A>,
56+
{
57+
let array = ArrayData::from_arrow(array, nullable);
58+
if array.len() == len {
59+
return Ok(array);
60+
}
61+
62+
if array.len() != 1 {
63+
vortex_panic!(
64+
"Array length mismatch, expected {} got {} for encoding {}",
65+
len,
66+
array.len(),
67+
array.encoding().id()
68+
);
69+
}
70+
71+
Ok(ConstantArray::new(scalar_at(&array, 0)?, len).into_array())
72+
}

vortex-array/src/compute/binary_numeric.rs

Lines changed: 36 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
use std::sync::Arc;
2-
3-
use arrow_array::ArrayRef;
41
use vortex_dtype::{DType, PType};
5-
use vortex_error::{vortex_bail, VortexError, VortexResult};
2+
use vortex_error::{vortex_bail, VortexError, VortexExpect, VortexResult};
63
use vortex_scalar::{BinaryNumericOperator, Scalar};
74

85
use crate::array::ConstantArray;
9-
use crate::arrow::{Datum, FromArrowArray};
6+
use crate::arrow::{from_arrow_array_with_len, Datum};
107
use crate::encoding::Encoding;
118
use crate::{ArrayDType, ArrayData, IntoArrayData as _};
129

@@ -121,43 +118,15 @@ pub fn binary_numeric(
121118
// Check if LHS supports the operation directly.
122119
if let Some(fun) = lhs.encoding().binary_numeric_fn() {
123120
if let Some(result) = fun.binary_numeric(lhs, rhs, op)? {
124-
debug_assert_eq!(
125-
result.len(),
126-
lhs.len(),
127-
"Numeric operation length mismatch {}",
128-
lhs.encoding().id()
129-
);
130-
debug_assert_eq!(
131-
result.dtype(),
132-
&DType::Primitive(
133-
PType::try_from(lhs.dtype())?,
134-
(lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
135-
),
136-
"Numeric operation dtype mismatch {}",
137-
lhs.encoding().id()
138-
);
121+
check_numeric_result(&result, lhs, rhs);
139122
return Ok(result);
140123
}
141124
}
142125

143126
// Check if RHS supports the operation directly.
144127
if let Some(fun) = rhs.encoding().binary_numeric_fn() {
145128
if let Some(result) = fun.binary_numeric(rhs, lhs, op.swap())? {
146-
debug_assert_eq!(
147-
result.len(),
148-
lhs.len(),
149-
"Numeric operation length mismatch {}",
150-
rhs.encoding().id()
151-
);
152-
debug_assert_eq!(
153-
result.dtype(),
154-
&DType::Primitive(
155-
PType::try_from(lhs.dtype())?,
156-
(lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
157-
),
158-
"Numeric operation dtype mismatch {}",
159-
rhs.encoding().id()
160-
);
129+
check_numeric_result(&result, lhs, rhs);
161130
return Ok(result);
162131
}
163132
}
@@ -183,20 +152,43 @@ fn arrow_numeric(
183152
operator: BinaryNumericOperator,
184153
) -> VortexResult<ArrayData> {
185154
let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
155+
let len = lhs.len();
186156

187-
let lhs = Datum::try_from(lhs)?;
188-
let rhs = Datum::try_from(rhs)?;
157+
let left = unsafe { Datum::try_new(lhs.clone())? };
158+
let right = unsafe { Datum::try_new(rhs.clone())? };
189159

190160
let array = match operator {
191-
BinaryNumericOperator::Add => arrow_arith::numeric::add(&lhs, &rhs)?,
192-
BinaryNumericOperator::Sub => arrow_arith::numeric::sub(&lhs, &rhs)?,
193-
BinaryNumericOperator::RSub => arrow_arith::numeric::sub(&rhs, &lhs)?,
194-
BinaryNumericOperator::Mul => arrow_arith::numeric::mul(&lhs, &rhs)?,
195-
BinaryNumericOperator::Div => arrow_arith::numeric::div(&lhs, &rhs)?,
196-
BinaryNumericOperator::RDiv => arrow_arith::numeric::div(&rhs, &lhs)?,
161+
BinaryNumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
162+
BinaryNumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
163+
BinaryNumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
164+
BinaryNumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
165+
BinaryNumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
166+
BinaryNumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
197167
};
198168

199-
Ok(ArrayData::from_arrow(Arc::new(array) as ArrayRef, nullable))
169+
let result = from_arrow_array_with_len(array, len, nullable)?;
170+
check_numeric_result(&result, &lhs, &rhs);
171+
Ok(result)
172+
}
173+
174+
#[inline(always)]
175+
fn check_numeric_result(result: &ArrayData, lhs: &ArrayData, rhs: &ArrayData) {
176+
debug_assert_eq!(
177+
result.len(),
178+
lhs.len(),
179+
"Numeric operation length mismatch {}",
180+
rhs.encoding().id()
181+
);
182+
debug_assert_eq!(
183+
result.dtype(),
184+
&DType::Primitive(
185+
PType::try_from(lhs.dtype())
186+
.vortex_expect("Numeric operation DType failed to convert to PType"),
187+
(lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
188+
),
189+
"Numeric operation dtype mismatch {}",
190+
rhs.encoding().id()
191+
);
200192
}
201193

202194
#[cfg(feature = "test-harness")]

vortex-array/src/compute/compare.rs

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use vortex_dtype::{DType, Nullability};
66
use vortex_error::{vortex_bail, VortexError, VortexResult};
77
use vortex_scalar::Scalar;
88

9-
use crate::arrow::{Datum, FromArrowArray};
9+
use crate::arrow::{from_arrow_array_with_len, Datum};
1010
use crate::encoding::Encoding;
1111
use crate::{ArrayDType, ArrayData, Canonical, IntoArrayData};
1212

@@ -130,18 +130,7 @@ pub fn compare(
130130
.and_then(|f| f.compare(left, right, operator).transpose())
131131
.transpose()?
132132
{
133-
debug_assert_eq!(
134-
result.len(),
135-
left.len(),
136-
"Compare length mismatch {}",
137-
left.encoding().id()
138-
);
139-
debug_assert_eq!(
140-
result.dtype(),
141-
&DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into()),
142-
"Compare dtype mismatch {}",
143-
left.encoding().id()
144-
);
133+
check_compare_result(&result, left, right);
145134
return Ok(result);
146135
}
147136

@@ -151,18 +140,7 @@ pub fn compare(
151140
.and_then(|f| f.compare(right, left, operator.swap()).transpose())
152141
.transpose()?
153142
{
154-
debug_assert_eq!(
155-
result.len(),
156-
left.len(),
157-
"Compare length mismatch {}",
158-
right.encoding().id()
159-
);
160-
debug_assert_eq!(
161-
result.dtype(),
162-
&result_dtype,
163-
"Compare dtype mismatch {}",
164-
right.encoding().id()
165-
);
143+
check_compare_result(&result, left, right);
166144
return Ok(result);
167145
}
168146

@@ -178,18 +156,20 @@ pub fn compare(
178156
}
179157

180158
// Fallback to arrow on canonical types
181-
arrow_compare(left, right, operator)
159+
let result = arrow_compare(left, right, operator)?;
160+
check_compare_result(&result, left, right);
161+
Ok(result)
182162
}
183163

184164
/// Implementation of `CompareFn` using the Arrow crate.
185-
pub(crate) fn arrow_compare(
186-
lhs: &ArrayData,
187-
rhs: &ArrayData,
165+
fn arrow_compare(
166+
left: &ArrayData,
167+
right: &ArrayData,
188168
operator: Operator,
189169
) -> VortexResult<ArrayData> {
190-
let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
191-
let lhs = Datum::try_from(lhs.clone())?;
192-
let rhs = Datum::try_from(rhs.clone())?;
170+
let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
171+
let lhs = unsafe { Datum::try_new(left.clone())? };
172+
let rhs = unsafe { Datum::try_new(right.clone())? };
193173

194174
let array = match operator {
195175
Operator::Eq => cmp::eq(&lhs, &rhs)?,
@@ -199,8 +179,29 @@ pub(crate) fn arrow_compare(
199179
Operator::Lt => cmp::lt(&lhs, &rhs)?,
200180
Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
201181
};
182+
from_arrow_array_with_len(&array, left.len(), nullable)
183+
}
202184

203-
Ok(ArrayData::from_arrow(&array, nullable))
185+
#[inline(always)]
186+
fn check_compare_result(result: &ArrayData, lhs: &ArrayData, rhs: &ArrayData) {
187+
debug_assert_eq!(
188+
result.len(),
189+
lhs.len(),
190+
"CompareFn result length ({}) mismatch for left encoding {}, left len {}, right encoding {}, right len {}",
191+
result.len(),
192+
lhs.encoding().id(),
193+
lhs.len(),
194+
rhs.encoding().id(),
195+
rhs.len()
196+
);
197+
debug_assert_eq!(
198+
result.dtype(),
199+
&DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
200+
"CompareFn result dtype ({}) mismatch for left encoding {}, right encoding {}",
201+
result.dtype(),
202+
lhs.encoding().id(),
203+
rhs.encoding().id(),
204+
);
204205
}
205206

206207
pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
@@ -312,7 +313,12 @@ mod tests {
312313
let left = ConstantArray::new(Scalar::from(2u32), 10);
313314
let right = ConstantArray::new(Scalar::from(10u32), 10);
314315

315-
let compare = compare(left, right, Operator::Gt).unwrap();
316+
let compare = compare(left.clone(), right.clone(), Operator::Gt).unwrap();
317+
let res = compare.as_constant().unwrap();
318+
assert_eq!(res.as_bool().value(), Some(false));
319+
assert_eq!(compare.len(), 10);
320+
321+
let compare = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
316322
let res = compare.as_constant().unwrap();
317323
assert_eq!(res.as_bool().value(), Some(false));
318324
assert_eq!(compare.len(), 10);

0 commit comments

Comments
 (0)