Skip to content

Commit 5486c92

Browse files
authored
fix: Dict LikeFn length mismatch (#2043)
Before this PR, `LikeFn` took a `pattern: &ArrayData` argument, but the `DictArray` implementation would only (accidentally) return a correct result if `pattern` was a `ConstantArray`, because it applied `pattern` directly to the dictionary's values rather than to the array's elements.
1 parent 047228b commit 5486c92

File tree

2 files changed

+41
-15
lines changed

2 files changed

+41
-15
lines changed

encodings/dict/src/compute/like.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use vortex_array::array::ConstantArray;
12
use vortex_array::compute::{like, LikeFn, LikeOptions};
23
use vortex_array::{ArrayData, IntoArrayData};
34
use vortex_error::VortexResult;
@@ -10,8 +11,15 @@ impl LikeFn<DictArray> for DictEncoding {
1011
array: DictArray,
1112
pattern: &ArrayData,
1213
options: LikeOptions,
13-
) -> VortexResult<ArrayData> {
14-
let values = like(array.values(), pattern, options)?;
15-
Ok(DictArray::try_new(array.codes(), values)?.into_array())
14+
) -> VortexResult<Option<ArrayData>> {
15+
if let Some(pattern) = pattern.as_constant() {
16+
let pattern = ConstantArray::new(pattern, array.values().len()).into_array();
17+
let values = like(array.values(), &pattern, options)?;
18+
Ok(Some(
19+
DictArray::try_new(array.codes(), values)?.into_array(),
20+
))
21+
} else {
22+
Ok(None)
23+
}
1624
}
1725
}

vortex-array/src/compute/like.rs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub trait LikeFn<Array> {
1111
array: Array,
1212
pattern: &ArrayData,
1313
options: LikeOptions,
14-
) -> VortexResult<ArrayData>;
14+
) -> VortexResult<Option<ArrayData>>;
1515
}
1616

1717
impl<E: Encoding> LikeFn<ArrayData> for E
@@ -24,7 +24,7 @@ where
2424
array: ArrayData,
2525
pattern: &ArrayData,
2626
options: LikeOptions,
27-
) -> VortexResult<ArrayData> {
27+
) -> VortexResult<Option<ArrayData>> {
2828
let encoding = array
2929
.encoding()
3030
.as_any()
@@ -58,20 +58,32 @@ pub fn like(
5858
if !matches!(pattern.dtype(), DType::Utf8(..)) {
5959
vortex_bail!("Expected utf8 pattern, got {}", array.dtype());
6060
}
61+
if array.len() != pattern.len() {
62+
vortex_bail!(
63+
"Length mismatch lhs len {} ({}) != rhs len {} ({})",
64+
array.len(),
65+
array.encoding().id(),
66+
pattern.len(),
67+
pattern.encoding().id()
68+
);
69+
}
70+
6171
let expected_dtype =
6272
DType::Bool((array.dtype().is_nullable() || pattern.dtype().is_nullable()).into());
6373
let array_encoding = array.encoding().id();
6474

65-
let result = if let Some(f) = array.encoding().like_fn() {
66-
f.like(array, pattern, options)
67-
} else {
68-
// Otherwise, we canonicalize into a UTF8 array.
69-
log::debug!(
70-
"No like implementation found for encoding {}",
71-
array.encoding().id(),
72-
);
73-
arrow_like(array, pattern, options)
74-
}?;
75+
let result = array
76+
.encoding()
77+
.like_fn()
78+
.and_then(|f| f.like(array.clone(), pattern, options).transpose())
79+
.unwrap_or_else(|| {
80+
// Otherwise, we canonicalize into a UTF8 array.
81+
log::debug!(
82+
"No like implementation found for encoding {}",
83+
array.encoding().id(),
84+
);
85+
arrow_like(array, pattern, options)
86+
})?;
7587

7688
debug_assert_eq!(
7789
result.len(),
@@ -97,6 +109,12 @@ pub(crate) fn arrow_like(
97109
) -> VortexResult<ArrayData> {
98110
let nullable = array.dtype().is_nullable();
99111
let len = array.len();
112+
debug_assert_eq!(
113+
array.len(),
114+
pattern.len(),
115+
"Arrow Like: length mismatch for {}",
116+
array.encoding().id()
117+
);
100118
let lhs = unsafe { Datum::try_new(array)? };
101119
let rhs = unsafe { Datum::try_new(pattern.clone())? };
102120

0 commit comments

Comments
 (0)