Skip to content

Commit 19220f8

Browse files
bug[vortex-array]: take compare nullable bugs (#3628)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 386556b commit 19220f8

File tree

4 files changed: 94 additions (+94) and 6 deletions (-6).

encodings/dict/src/compute/mod.rs

Lines changed: 24 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ mod like;
77
mod min_max;
88

99
use vortex_array::compute::{
10-
FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, filter, take,
10+
FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, cast, filter, take,
1111
};
1212
use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
1313
use vortex_error::VortexResult;
@@ -17,8 +17,13 @@ use crate::{DictArray, DictVTable};
1717

1818
impl TakeKernel for DictVTable {
1919
fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
20+
// TODO(joe): can we remove the cast and allow dict arrays to have nullable codes and values?
2021
let codes = take(array.codes(), indices)?;
21-
DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
22+
let values_dtype = array
23+
.values()
24+
.dtype()
25+
.union_nullability(codes.dtype().nullability());
26+
DictArray::try_new(codes, cast(array.values(), &values_dtype)?).map(|a| a.into_array())
2227
}
2328
}
2429

@@ -38,8 +43,9 @@ mod test {
3843
use vortex_array::accessor::ArrayAccessor;
3944
use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
4045
use vortex_array::compute::conformance::mask::test_mask;
41-
use vortex_array::compute::{Operator, compare};
46+
use vortex_array::compute::{Operator, compare, take};
4247
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
48+
use vortex_dtype::PType::I32;
4349
use vortex_dtype::{DType, Nullability};
4450
use vortex_scalar::Scalar;
4551

@@ -184,4 +190,19 @@ mod test {
184190
.unwrap();
185191
test_mask(array.as_ref());
186192
}
193+
194+
#[test]
195+
fn test_take_dict() {
196+
let array = dict_encode(PrimitiveArray::from_iter([1, 2]).as_ref()).unwrap();
197+
198+
assert_eq!(
199+
take(
200+
array.as_ref(),
201+
PrimitiveArray::from_option_iter([Option::<i32>::None]).as_ref()
202+
)
203+
.unwrap()
204+
.dtype(),
205+
&DType::Primitive(I32, Nullability::Nullable)
206+
);
207+
}
187208
}

encodings/fsst/src/compute/mod.rs

Lines changed: 34 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,10 @@ impl TakeKernel for FSSTVTable {
1313
// Take on an FSSTArray is a simple take on the codes array.
1414
fn take(&self, array: &FSSTArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
1515
Ok(FSSTArray::try_new(
16-
array.dtype().clone(),
16+
array
17+
.dtype()
18+
.clone()
19+
.union_nullability(indices.dtype().nullability()),
1720
array.symbols().clone(),
1821
array.symbol_lengths().clone(),
1922
take(array.codes().as_ref(), indices)?
@@ -32,3 +35,33 @@ impl TakeKernel for FSSTVTable {
3235
}
3336

3437
register_kernel!(TakeKernelAdapter(FSSTVTable).lift());
38+
39+
#[cfg(test)]
40+
mod tests {
41+
use vortex_array::arrays::{PrimitiveArray, VarBinArray};
42+
use vortex_array::compute::take;
43+
use vortex_dtype::{DType, Nullability};
44+
45+
use crate::{fsst_compress, fsst_train_compressor};
46+
47+
#[test]
48+
fn test_take_null() {
49+
let arr = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));
50+
let compr = fsst_train_compressor(arr.as_ref()).unwrap();
51+
let fsst = fsst_compress(arr.as_ref(), &compr).unwrap();
52+
53+
let idx1: PrimitiveArray = (0..1).collect();
54+
55+
assert_eq!(
56+
take(fsst.as_ref(), idx1.as_ref()).unwrap().dtype(),
57+
&DType::Utf8(Nullability::NonNullable)
58+
);
59+
60+
let idx2: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
61+
62+
assert_eq!(
63+
take(fsst.as_ref(), idx2.as_ref()).unwrap().dtype(),
64+
&DType::Utf8(Nullability::Nullable)
65+
);
66+
}
67+
}

vortex-array/src/arrays/varbin/compute/compare.rs

Lines changed: 31 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,13 @@ impl CompareKernel for VarBinVTable {
5252
};
5353

5454
return Ok(Some(
55-
BoolArray::new(buffer, lhs.validity().clone()).into_array(),
55+
BoolArray::new(
56+
buffer,
57+
lhs.validity()
58+
.clone()
59+
.union_nullability(rhs.dtype().nullability()),
60+
)
61+
.into_array(),
5662
));
5763
}
5864

@@ -177,3 +183,27 @@ mod test {
177183
);
178184
}
179185
}
186+
187+
#[cfg(test)]
188+
mod tests {
189+
use vortex_dtype::{DType, Nullability};
190+
use vortex_scalar::Scalar;
191+
192+
use crate::Array;
193+
use crate::arrays::{ConstantArray, VarBinArray};
194+
use crate::compute::{Operator, compare};
195+
196+
#[test]
197+
fn test_null_compare() {
198+
let arr = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));
199+
200+
let const_ = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1);
201+
202+
assert_eq!(
203+
compare(arr.as_ref(), const_.as_ref(), Operator::Eq)
204+
.unwrap()
205+
.dtype(),
206+
&DType::Bool(Nullability::Nullable)
207+
);
208+
}
209+
}

vortex-array/src/compute/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -113,10 +113,14 @@ impl ComputeFn {
113113

114114
if output.dtype() != &expected_dtype {
115115
vortex_bail!(
116-
"Internal error: compute function {} returned a result of type {} but expected {}",
116+
"Internal error: compute function {} returned a result of type {} but expected {}\n{}",
117117
self.id,
118118
output.dtype(),
119119
&expected_dtype,
120+
args.inputs
121+
.iter()
122+
.filter_map(|input| input.array())
123+
.format_with(",", |array, f| f(&array.tree_display()))
120124
);
121125
}
122126
if output.len() != expected_len {

0 commit comments

Comments
 (0)