Skip to content

Commit 965ed94

Browse files
fix[vortex-array]: check NaN for inferring is constant with min/max (#3663)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent b9aad17 commit 965ed94

File tree

1 file changed

+48
-9
lines changed

1 file changed

+48
-9
lines changed

vortex-array/src/compute/is_constant.rs

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,16 @@ fn is_constant_impl(
136136
}
137137

138138
// We already know here that the array is all valid, so we check for min/max stats.
139-
let min = array
140-
.statistics()
141-
.get_scalar(Stat::Min, array.dtype())
142-
.and_then(|p| p.as_exact());
143-
let max = array
144-
.statistics()
145-
.get_scalar(Stat::Max, array.dtype())
146-
.and_then(|p| p.as_exact());
139+
let min = array.statistics().get_scalar(Stat::Min, array.dtype());
140+
let max = array.statistics().get_scalar(Stat::Max, array.dtype());
147141

148142
if let Some((min, max)) = min.zip(max) {
149-
if min == max {
143+
// min/max are equal and exact and there are no NaNs
144+
if min.is_exact()
145+
&& min == max
146+
&& (Stat::NaNCount.dtype(array.dtype()).is_none()
147+
|| array.statistics().get_as::<u64>(Stat::NaNCount) == Some(Precision::exact(0u64)))
148+
{
150149
return Ok(Some(true));
151150
}
152151
}
@@ -283,3 +282,43 @@ impl IsConstantOpts {
283282
self.cost == Cost::Negligible
284283
}
285284
}
285+
286+
#[cfg(test)]
287+
mod tests {
288+
use crate::arrays::PrimitiveArray;
289+
use crate::stats::Stat;
290+
291+
#[test]
292+
fn is_constant_min_max_no_nan() {
293+
let arr = PrimitiveArray::from_iter([0, 1]);
294+
arr.statistics()
295+
.compute_all(&[Stat::Min, Stat::Max])
296+
.unwrap();
297+
assert!(!arr.is_constant());
298+
299+
let arr = PrimitiveArray::from_iter([0, 0]);
300+
arr.statistics()
301+
.compute_all(&[Stat::Min, Stat::Max])
302+
.unwrap();
303+
assert!(arr.is_constant());
304+
305+
let arr = PrimitiveArray::from_option_iter([Some(0), Some(0)]);
306+
assert!(arr.is_constant());
307+
}
308+
309+
#[test]
310+
fn is_constant_min_max_with_nan() {
311+
let arr = PrimitiveArray::from_iter([0.0, 0.0, f32::NAN]);
312+
arr.statistics()
313+
.compute_all(&[Stat::Min, Stat::Max])
314+
.unwrap();
315+
assert!(!arr.is_constant());
316+
317+
let arr =
318+
PrimitiveArray::from_option_iter([Some(f32::NEG_INFINITY), Some(f32::NEG_INFINITY)]);
319+
arr.statistics()
320+
.compute_all(&[Stat::Min, Stat::Max])
321+
.unwrap();
322+
assert!(arr.is_constant());
323+
}
324+
}

0 commit comments

Comments
 (0)