Skip to content

Commit d17d830

Browse files
authored
list is_constant (#3637)
fixes #3622 --------- Signed-off-by: Onur Satici <[email protected]>
1 parent fe08571 commit d17d830

File tree

3 files changed

+200
-18
lines changed

3 files changed

+200
-18
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
use vortex_error::VortexResult;
2+
use vortex_scalar::NumericOperator;
3+
4+
use crate::arrays::{ListArray, ListVTable};
5+
use crate::compute::{IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, numeric};
6+
use crate::register_kernel;
7+
8+
const SMALL_ARRAY_THRESHOLD: usize = 64;
9+
10+
impl IsConstantKernel for ListVTable {
11+
fn is_constant(&self, array: &ListArray, opts: &IsConstantOpts) -> VortexResult<Option<bool>> {
12+
// At this point, we're guaranteed:
13+
// - Array has at least 2 elements
14+
// - All elements are valid (no nulls)
15+
16+
let manual_check_until = std::cmp::min(SMALL_ARRAY_THRESHOLD, array.len());
17+
18+
let first_list_len = array.offset_at(1) - array.offset_at(0);
19+
for i in 1..manual_check_until {
20+
let current_list_len = array.offset_at(i + 1) - array.offset_at(i);
21+
if current_list_len != first_list_len {
22+
return Ok(Some(false));
23+
}
24+
}
25+
26+
if opts.is_negligible_cost() {
27+
return Ok(None);
28+
}
29+
30+
if array.len() > SMALL_ARRAY_THRESHOLD {
31+
// check the rest of the element lengths
32+
let start_offsets = array.offsets.slice(SMALL_ARRAY_THRESHOLD, array.len())?;
33+
let end_offsets = array
34+
.offsets
35+
.slice(SMALL_ARRAY_THRESHOLD + 1, array.len() + 1)?;
36+
let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?;
37+
38+
if !list_lengths.is_constant() {
39+
return Ok(Some(false));
40+
}
41+
}
42+
43+
// If all lists have the same length, compare the actual list contents
44+
let first_scalar = array.scalar_at(0)?;
45+
for i in 1..array.len() {
46+
let current_scalar = array.scalar_at(i)?;
47+
if current_scalar != first_scalar {
48+
return Ok(Some(false));
49+
}
50+
}
51+
52+
Ok(Some(true))
53+
}
54+
}
55+
56+
register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
57+
58+
#[cfg(test)]
59+
mod tests {
60+
61+
use rstest::rstest;
62+
use vortex_dtype::FieldNames;
63+
64+
use crate::IntoArray;
65+
use crate::arrays::{ListArray, PrimitiveArray, StructArray};
66+
use crate::compute::is_constant;
67+
use crate::validity::Validity;
68+
69+
#[test]
70+
fn test_is_constant_nested_list() {
71+
let xs = ListArray::try_new(
72+
PrimitiveArray::from_iter([0i32, 1, 0, 1]).into_array(),
73+
PrimitiveArray::from_iter([0u32, 2, 4]).into_array(),
74+
Validity::NonNullable,
75+
)
76+
.unwrap();
77+
78+
let struct_of_lists = StructArray::try_new(
79+
FieldNames::from(["xs".into()]),
80+
vec![xs.into_array()],
81+
2,
82+
Validity::NonNullable,
83+
)
84+
.unwrap();
85+
assert!(
86+
is_constant(&struct_of_lists.clone().into_array())
87+
.unwrap()
88+
.unwrap()
89+
);
90+
assert!(struct_of_lists.is_constant());
91+
}
92+
93+
#[rstest]
94+
#[case(
95+
// [1,2], [1, 2], [1, 2]
96+
vec![1i32, 2, 1, 2, 1, 2],
97+
vec![0u32, 2, 4, 6],
98+
true
99+
)]
100+
#[case(
101+
// [1, 2], [3], [4, 5]
102+
vec![1i32, 2, 3, 4, 5],
103+
vec![0u32, 2, 3, 5],
104+
false
105+
)]
106+
#[case(
107+
// [1, 2], [3, 4]
108+
vec![1i32, 2, 3, 4],
109+
vec![0u32, 2, 4],
110+
false
111+
)]
112+
#[case(
113+
// [], [], []
114+
vec![],
115+
vec![0u32, 0, 0, 0],
116+
true
117+
)]
118+
fn test_list_is_constant(
119+
#[case] elements: Vec<i32>,
120+
#[case] offsets: Vec<u32>,
121+
#[case] expected: bool,
122+
) {
123+
let list_array = ListArray::try_new(
124+
PrimitiveArray::from_iter(elements).into_array(),
125+
PrimitiveArray::from_iter(offsets).into_array(),
126+
Validity::NonNullable,
127+
)
128+
.unwrap();
129+
130+
let result = is_constant(&list_array.into_array()).unwrap();
131+
assert_eq!(result.unwrap(), expected);
132+
}
133+
134+
#[test]
135+
fn test_list_is_constant_nested_lists() {
136+
let inner_elements = PrimitiveArray::from_iter([1i32, 2, 1, 2]).into_array();
137+
let inner_offsets = PrimitiveArray::from_iter([0u32, 1, 2, 3, 4]).into_array();
138+
let inner_lists =
139+
ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable).unwrap();
140+
141+
let outer_offsets = PrimitiveArray::from_iter([0u32, 2, 4]).into_array();
142+
let outer_list = ListArray::try_new(
143+
inner_lists.into_array(),
144+
outer_offsets,
145+
Validity::NonNullable,
146+
)
147+
.unwrap();
148+
149+
// Both outer lists contain [[1], [2]], so should be constant
150+
assert!(is_constant(&outer_list.into_array()).unwrap().unwrap());
151+
}
152+
153+
#[rstest]
154+
#[case(
155+
// 100 identical [1, 2] lists
156+
[1i32, 2].repeat(100),
157+
(0..101).map(|i| (i * 2) as u32).collect(),
158+
true
159+
)]
160+
#[case(
161+
// Difference after threshold: 64 identical [1, 2] + one [3, 4]
162+
{
163+
let mut elements = [1i32, 2].repeat(64);
164+
elements.extend_from_slice(&[3, 4]);
165+
elements
166+
},
167+
(0..66).map(|i| (i * 2) as u32).collect(),
168+
false
169+
)]
170+
#[case(
171+
// Difference in first 64: first 63 identical [1, 2] + one [3, 4] + rest identical [1, 2]
172+
{
173+
let mut elements = [1i32, 2].repeat(63);
174+
elements.extend_from_slice(&[3, 4]);
175+
elements.extend([1i32, 2].repeat(37));
176+
elements
177+
},
178+
(0..101).map(|i| (i * 2) as u32).collect(),
179+
false
180+
)]
181+
fn test_large_list_is_constant(
182+
#[case] elements: Vec<i32>,
183+
#[case] offsets: Vec<u32>,
184+
#[case] expected: bool,
185+
) {
186+
let list_array = ListArray::try_new(
187+
PrimitiveArray::from_iter(elements).into_array(),
188+
PrimitiveArray::from_iter(offsets).into_array(),
189+
Validity::NonNullable,
190+
)
191+
.unwrap();
192+
193+
let result = is_constant(&list_array.into_array()).unwrap();
194+
assert_eq!(result.unwrap(), expected);
195+
}
196+
}

vortex-array/src/arrays/list/compute/mod.rs

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,15 @@
11
mod filter;
2+
mod is_constant;
23
mod mask;
34

45
use vortex_error::VortexResult;
56

67
use crate::arrays::{ListArray, ListVTable};
78
use crate::compute::{
8-
IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, IsSortedKernel,
9-
IsSortedKernelAdapter, MinMaxKernel, MinMaxKernelAdapter, MinMaxResult,
9+
IsSortedKernel, IsSortedKernelAdapter, MinMaxKernel, MinMaxKernelAdapter, MinMaxResult,
1010
};
1111
use crate::register_kernel;
1212

13-
impl IsConstantKernel for ListVTable {
14-
fn is_constant(
15-
&self,
16-
_array: &ListArray,
17-
_opts: &IsConstantOpts,
18-
) -> VortexResult<Option<bool>> {
19-
// TODO(adam): Do we want to fallback to arrow here?
20-
Ok(None)
21-
}
22-
}
23-
24-
register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
25-
2613
impl MinMaxKernel for ListVTable {
2714
fn min_max(&self, _array: &ListArray) -> VortexResult<Option<MinMaxResult>> {
2815
// TODO(joe): Implement list min max

vortex-array/src/compute/is_constant.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
77
use vortex_scalar::Scalar;
88

99
use crate::Array;
10-
use crate::arrays::{ConstantVTable, ListVTable, NullVTable};
10+
use crate::arrays::{ConstantVTable, NullVTable};
1111
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
1212
use crate::stats::{Precision, Stat, StatsProviderExt};
1313
use crate::vtable::VTable;
@@ -72,8 +72,7 @@ impl ComputeFnVTable for IsConstant {
7272

7373
let value = is_constant_impl(array, options, kernels)?;
7474

75-
// TODO(joe): add is_constant for ListArray
76-
if options.cost == Cost::Canonicalize && !array.is::<ListVTable>() {
75+
if options.cost == Cost::Canonicalize {
7776
// When we run linear canonicalize, there we must always return an exact answer.
7877
assert!(
7978
value.is_some(),

0 commit comments

Comments
 (0)