Skip to content
Merged
116 changes: 116 additions & 0 deletions src/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,33 @@ where
where
A: PartialOrd;

/// Finds the index of the minimum value of the array skipping NaN values.
///
/// Returns `None` if the array is empty or none of the values in the array
/// are non-NaN values.
///
/// Even if there are multiple (equal) elements that are minima, only one
/// index is returned. (Which one is returned is unspecified and may depend
/// on the memory layout of the array.)
///
/// # Example
///
/// ```
/// extern crate ndarray;
/// extern crate ndarray_stats;
///
/// use ndarray::array;
/// use ndarray_stats::QuantileExt;
///
/// let a = array![[::std::f64::NAN, 3., 5.],
/// [2., 0., 6.]];
/// assert_eq!(a.argmin_skipnan(), Some((1, 1)));
/// ```
fn argmin_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord;

/// Finds the elementwise minimum of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
Expand Down Expand Up @@ -269,6 +296,33 @@ where
where
A: PartialOrd;

/// Finds the index of the maximum value of the array skipping NaN values.
///
/// Returns `None` if the array is empty or none of the values in the array
/// are non-NaN values.
///
/// Even if there are multiple (equal) elements that are maxima, only one
/// index is returned. (Which one is returned is unspecified and may depend
/// on the memory layout of the array.)
///
/// # Example
///
/// ```
/// extern crate ndarray;
/// extern crate ndarray_stats;
///
/// use ndarray::array;
/// use ndarray_stats::QuantileExt;
///
/// let a = array![[::std::f64::NAN, 3., 5.],
/// [2., 0., 6.]];
/// assert_eq!(a.argmax_skipnan(), Some((1, 2)));
/// ```
fn argmax_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord;

/// Finds the elementwise maximum of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
Expand Down Expand Up @@ -369,6 +423,37 @@ where
Some(current_pattern_min)
}

fn argmin_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord,
{
let first = self.first().and_then(|v| v.try_as_not_nan());
let mut pattern_min = D::zeros(self.ndim()).into_pattern();

let min = self
.indexed_iter()
.fold(first, |current_min, (pattern, elem)| {
let elem_not_nan = elem.try_as_not_nan();

if elem_not_nan.is_some()
&& (current_min.is_none()
|| elem_not_nan.cmp(&current_min) == cmp::Ordering::Less)
{
pattern_min = pattern;
elem_not_nan
} else {
current_min
}
});

if min == None {
None
} else {
Some(pattern_min)
}
}

fn min(&self) -> Option<&A>
where
A: PartialOrd,
Expand Down Expand Up @@ -411,6 +496,37 @@ where
Some(current_pattern_max)
}

fn argmax_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord,
{
let first = self.first().and_then(|v| v.try_as_not_nan());
let mut pattern_max = D::zeros(self.ndim()).into_pattern();

let max = self
.indexed_iter()
.fold(first, |current_max, (pattern, elem)| {
let elem_not_nan = elem.try_as_not_nan();

if elem_not_nan.is_some()
&& (current_max.is_none()
|| elem_not_nan.cmp(&current_max) == cmp::Ordering::Greater)
{
pattern_max = pattern;
elem_not_nan
} else {
current_max
}
});

if max == None {
None
} else {
Some(pattern_max)
}
}

fn max(&self) -> Option<&A>
where
A: PartialOrd,
Expand Down
53 changes: 53 additions & 0 deletions tests/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,31 @@ quickcheck! {
}
}

#[test]
fn test_argmin_skipnan() {
let a = array![[1., 5., 3.], [2., 0., 6.]];
assert_eq!(a.argmin_skipnan(), Some((1, 1)));

let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.argmin_skipnan(), Some((0, 0)));

let a = array![[::std::f64::NAN, 5., 3.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.argmin_skipnan(), Some((1, 0)));

let a: Array2<f64> = array![[], []];
assert_eq!(a.argmin_skipnan(), None);

let a = arr2(&[[::std::f64::NAN; 2]; 2]);
assert_eq!(a.argmin_skipnan(), None);
}

quickcheck! {
fn argmin_skipnan_matches_min(data: Vec<f32>) -> bool {
let a = Array1::from(data);
a.argmin_skipnan().map(|i| a[i]) == a.min().cloned()
}
}

#[test]
fn test_min() {
let a = array![[1, 5, 3], [2, 0, 6]];
Expand Down Expand Up @@ -81,6 +106,34 @@ quickcheck! {
}
}

#[test]
fn test_argmax_skipnan() {
let a = array![[1., 5., 3.], [2., 0., 6.]];
assert_eq!(a.argmax_skipnan(), Some((1, 2)));

let a = array![[1., 5., 3.], [2., ::std::f64::NAN, ::std::f64::NAN]];
assert_eq!(a.argmax_skipnan(), Some((0, 1)));

let a = array![
[::std::f64::NAN, ::std::f64::NAN, 3.],
[2., ::std::f64::NAN, 6.]
];
assert_eq!(a.argmax_skipnan(), Some((1, 2)));

let a: Array2<f64> = array![[], []];
assert_eq!(a.argmax_skipnan(), None);

let a = arr2(&[[::std::f64::NAN; 2]; 2]);
assert_eq!(a.argmax_skipnan(), None);
}

quickcheck! {
fn argmax_skipnan_matches_max(data: Vec<f32>) -> bool {
let a = Array1::from(data);
a.argmax_skipnan().map(|i| a[i]) == a.max().cloned()
}
}

#[test]
fn test_max() {
let a = array![[1, 5, 7], [2, 0, 6]];
Expand Down