-
Notifications
You must be signed in to change notification settings - Fork 21
More stable algorithm for variance, standard deviation #456
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 42 commits
0f29529
1fbf5f8
322f511
adab8e6
93cd9b3
2be4f74
edb655d
dd2e4b6
936ed1d
1968870
d036ebc
12bcb0f
6f5bece
b1f7b5d
cd9a8b8
27448e4
10214cc
a81b1a3
004fddc
4491ce9
c3a6d88
4dcd7c2
c101a2b
98e1b4e
d0d09df
1139a9c
569629c
50ad095
f88e231
77526fd
0f5d587
31f30c9
3b3369f
24fb532
177b8de
7deb84a
120fbf3
4541c46
aa4b9b3
d5c59e3
b721433
4f26ed8
d77c132
3cbe54c
d7d772c
9a51095
1373318
4f15495
591997c
bbc0be2
63d7e96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,8 @@ | |
from . import aggregate_flox, aggregate_npg, xrutils | ||
from . import xrdtypes as dtypes | ||
from .lib import dask_array_type, sparse_array_type | ||
from .multiarray import MultiArray | ||
from .xrutils import notnull | ||
|
||
if TYPE_CHECKING: | ||
FuncTuple = tuple[Callable | str, ...] | ||
|
@@ -161,8 +163,8 @@ def __init__( | |
self, | ||
name: str, | ||
*, | ||
numpy: str | None = None, | ||
chunk: str | FuncTuple | None, | ||
numpy: partial | str | None = None, | ||
chunk: partial | str | FuncTuple | None, | ||
combine: str | FuncTuple | None, | ||
preprocess: Callable | None = None, | ||
finalize: Callable | None = None, | ||
|
@@ -343,57 +345,171 @@ def _mean_finalize(sum_, count): | |
) | ||
|
||
|
||
# TODO: fix this for complex numbers | ||
def _var_finalize(sumsq, sum_, count, ddof=0): | ||
with np.errstate(invalid="ignore", divide="ignore"): | ||
result = (sumsq - (sum_**2 / count)) / (count - ddof) | ||
result[count <= ddof] = np.nan | ||
return result | ||
def var_chunk(
    group_idx, array, *, skipna: bool, engine: str, axis=-1, size=None, fill_value=None, dtype=None
):
    """Blockwise (per-chunk) step of the numerically stable variance algorithm.

    Computes three grouped reductions over this chunk:
      1. the number of observations per group (``nanlen``),
      2. the sum of observations per group,
      3. the sum of squared deviations from the chunk-local group mean.

    The triplet is bundled into a ``MultiArray`` so the pieces travel together
    through the chunked reduction and can be merged by ``_var_combine``.
    ``fill_value`` is accepted for interface compatibility but each component
    reduction uses 0, because the fill value is currently defined at the
    MultiArray level by the caller.
    """
    sum_func = "nansum" if skipna else "sum"

    def _grouped(values, func):
        # All three reductions share every argument except the data and `func`.
        return generic_aggregate(
            group_idx,
            values,
            func=func,
            engine=engine,
            axis=axis,
            size=size,
            fill_value=0,  # component-wise fill; MultiArray-level fill handled by caller
            dtype=dtype,
        )

    # Length and sum per group: needed for the chunk-local means below and for
    # the cross-chunk adjustment terms applied later in `_var_combine`.
    array_lens = _grouped(array, "nanlen")
    array_sums = _grouped(array, sum_func)

    # Sum of squared deviations about the chunk-local group mean — the main
    # running quantity of the variance computation.
    array_means = array_sums / array_lens
    sum_squared_deviations = _grouped((array - array_means[..., group_idx]) ** 2, sum_func)

    return MultiArray((sum_squared_deviations, array_sums, array_lens))
|
||
|
||
def _var_combine(array, axis, keepdims=True):
    """Merge per-chunk (sum-of-squared-deviations, sum, count) triplets along `axis`.

    Implements the pairwise/parallel variance combination: when chunks with
    different means are merged, the plain sum of their squared deviations must
    be corrected by an adjustment term that accounts for the difference in
    chunk means.
    """
    assert len(axis) == 1, "Assuming that the combine function is only in one direction at once"
    (ax,) = axis

    def _drop_last(arr, n=1):
        """Return `arr` without its last `n` elements along the combine axis.

        Purely included to tidy up the adj_terms line.
        """
        assert n > 0, "Clipping nothing off the end isn't implemented"
        index = [slice(None)] * arr.ndim
        index[ax] = slice(None, -n)
        return arr[tuple(index)]

    def _drop_first(arr, n=1):
        """Return `arr` without its first `n` elements along the combine axis.

        Purely included to tidy up the adj_terms line.
        """
        index = [slice(None)] * arr.ndim
        index[ax] = slice(n, None)
        return arr[tuple(index)]

    sum_deviations, sum_X, sum_len = array.arrays

    # Running totals let each prefix of chunks be merged with the next chunk in
    # a single vectorized ("cascading") step. The final element never needs to
    # be merged into anything, hence the clipping above.
    cumsum_X = np.cumsum(sum_X, axis=ax)
    cumsum_len = np.cumsum(sum_len, axis=ax)

    # When either side of a merge is empty the adjustment term becomes 0/0 and
    # would raise a divide-by-zero warning. Add a constant to the denominator
    # in exactly those places and rely on the zero numerator to keep the
    # adjustment term at its correct value of zero.
    zero_denominator = (_drop_last(cumsum_len) == 0) | (_drop_first(sum_len) == 0)

    # Adjustment terms that correct the summed squared deviations for the fact
    # that not every chunk has the same mean.
    adj_terms = (
        _drop_last(cumsum_len) * _drop_first(sum_X) - _drop_first(sum_len) * _drop_last(cumsum_X)
    ) ** 2 / (
        _drop_last(cumsum_len) * _drop_first(sum_len) * (_drop_last(cumsum_len) + _drop_first(sum_len))
        + zero_denominator.astype(int)
    )

    check = adj_terms * zero_denominator
    assert np.all(check[notnull(check)] == 0), (
        "Instances where we add something to the denominator must come out to zero"
    )

    return MultiArray(
        (
            np.sum(sum_deviations, axis=axis, keepdims=keepdims)
            + np.sum(adj_terms, axis=axis, keepdims=keepdims),  # sum of squared deviations
            np.sum(sum_X, axis=axis, keepdims=keepdims),  # sum of array items
            np.sum(sum_len, axis=axis, keepdims=keepdims),  # sum of array lengths
        )
    )
|
||
|
||
def is_var_chunk_reduction(agg: Callable) -> bool:
    """Return True when `agg` is (or wraps via functools.partial) one of the
    variance chunk reductions."""
    target = agg.func if isinstance(agg, partial) else agg
    return target is blockwise_or_numpy_var or target is var_chunk
|
||
|
||
def _var_finalize(multiarray, ddof=0): | ||
den = multiarray.arrays[2] - ddof | ||
# preserve nans for groups with 0 obs; so these values are -ddof | ||
den[den < 0] = 0 | ||
return multiarray.arrays[0] / den | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Am I correct that this will throw a divide by zero warning for groups with zero obs? Is that the intended behaviour? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes I was relying on it to set NaNs; but you're right; it's probably better to use a mask |
||
|
||
|
||
def _std_finalize(multiarray, ddof=0):
    """Finalize standard deviation: the square root of the finalized variance."""
    variance = _var_finalize(multiarray, ddof)
    return np.sqrt(variance)
|
||
|
||
def blockwise_or_numpy_var(*args, skipna: bool, ddof=0, std=False, **kwargs):
    """Single-pass variance for the blockwise / pure-numpy path.

    Runs the chunk reduction and finalization back to back; returns the
    standard deviation instead when ``std`` is True.
    """
    chunked = var_chunk(*args, skipna=skipna, **kwargs)
    result = _var_finalize(chunked, ddof)
    if std:
        return np.sqrt(result)
    return result
|
||
|
||
# var, std always promote to float, so we set nan | ||
var = Aggregation( | ||
"var", | ||
chunk=("sum_of_squares", "sum", "nanlen"), | ||
combine=("sum", "sum", "sum"), | ||
chunk=partial(var_chunk, skipna=False), | ||
numpy=partial(blockwise_or_numpy_var, skipna=False), | ||
combine=(_var_combine,), | ||
finalize=_var_finalize, | ||
fill_value=0, | ||
fill_value=((0, 0, 0),), | ||
final_fill_value=np.nan, | ||
dtypes=(None, None, np.intp), | ||
dtypes=(None,), | ||
final_dtype=np.floating, | ||
) | ||
|
||
nanvar = Aggregation( | ||
"nanvar", | ||
chunk=("nansum_of_squares", "nansum", "nanlen"), | ||
combine=("sum", "sum", "sum"), | ||
chunk=partial(var_chunk, skipna=True), | ||
numpy=partial(blockwise_or_numpy_var, skipna=True), | ||
combine=(_var_combine,), | ||
finalize=_var_finalize, | ||
fill_value=0, | ||
fill_value=((0, 0, 0),), | ||
final_fill_value=np.nan, | ||
dtypes=(None, None, np.intp), | ||
dtypes=(None,), | ||
final_dtype=np.floating, | ||
) | ||
|
||
std = Aggregation( | ||
"std", | ||
chunk=("sum_of_squares", "sum", "nanlen"), | ||
combine=("sum", "sum", "sum"), | ||
chunk=partial(var_chunk, skipna=False), | ||
numpy=partial(blockwise_or_numpy_var, skipna=False, std=True), | ||
combine=(_var_combine,), | ||
finalize=_std_finalize, | ||
fill_value=0, | ||
fill_value=((0, 0, 0),), | ||
final_fill_value=np.nan, | ||
dtypes=(None, None, np.intp), | ||
dtypes=(None,), | ||
final_dtype=np.floating, | ||
) | ||
nanstd = Aggregation( | ||
"nanstd", | ||
chunk=("nansum_of_squares", "nansum", "nanlen"), | ||
combine=("sum", "sum", "sum"), | ||
chunk=partial(var_chunk, skipna=True), | ||
numpy=partial(blockwise_or_numpy_var, skipna=True, std=True), | ||
combine=(_var_combine,), | ||
finalize=_std_finalize, | ||
fill_value=0, | ||
fill_value=((0, 0, 0),), | ||
final_fill_value=np.nan, | ||
dtypes=(None, None, np.intp), | ||
dtypes=(None,), | ||
final_dtype=np.floating, | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from collections.abc import Callable | ||
from typing import Self | ||
|
||
import numpy as np | ||
|
||
MULTIARRAY_HANDLED_FUNCTIONS: dict[Callable, Callable] = {} | ||
|
||
|
||
class MultiArray: | ||
arrays: tuple[np.ndarray, ...] | ||
|
||
def __init__(self, arrays): | ||
self.arrays = arrays | ||
assert all(arrays[0].shape == a.shape for a in arrays), "Expect all arrays to have the same shape" | ||
|
||
def astype(self, dt, **kwargs) -> Self: | ||
return type(self)(tuple(array.astype(dt, **kwargs) for array in self.arrays)) | ||
|
||
def reshape(self, shape, **kwargs) -> Self: | ||
return type(self)(tuple(array.reshape(shape, **kwargs) for array in self.arrays)) | ||
|
||
def squeeze(self, axis=None) -> Self: | ||
return type(self)(tuple(array.squeeze(axis) for array in self.arrays)) | ||
|
||
def __setitem__(self, key, value) -> None: | ||
assert len(value) == len(self.arrays) | ||
for array, val in zip(self.arrays, value): | ||
array[key] = val | ||
|
||
def __array_function__(self, func, types, args, kwargs): | ||
if func not in MULTIARRAY_HANDLED_FUNCTIONS: | ||
return NotImplemented | ||
# Note: this allows subclasses that don't override | ||
# __array_function__ to handle MyArray objects | ||
# if not all(issubclass(t, MyArray) for t in types): # I can't see this being relevant at all for this code, but maybe it's safer to leave it in? | ||
# return NotImplemented | ||
return MULTIARRAY_HANDLED_FUNCTIONS[func](*args, **kwargs) | ||
|
||
# Shape is needed, seems likely that the other two might be | ||
# Making some strong assumptions here that all the arrays are the same shape, and I don't really like this | ||
@property | ||
def dtype(self) -> np.dtype: | ||
return self.arrays[0].dtype | ||
|
||
@property | ||
def shape(self) -> tuple[int, ...]: | ||
return self.arrays[0].shape | ||
|
||
@property | ||
def ndim(self) -> int: | ||
return self.arrays[0].ndim | ||
|
||
def __getitem__(self, key) -> Self: | ||
return type(self)([array[key] for array in self.arrays]) | ||
|
||
|
||
def implements(numpy_function):
    """Decorator factory registering an __array_function__ handler for MultiArray.

    The decorated function becomes the MultiArray implementation of
    `numpy_function` via the MULTIARRAY_HANDLED_FUNCTIONS registry.
    """

    def decorator(handler):
        MULTIARRAY_HANDLED_FUNCTIONS[numpy_function] = handler
        return handler

    return decorator
|
||
|
||
@implements(np.expand_dims)
def expand_dims(multiarray, axis) -> MultiArray:
    """Apply np.expand_dims component-wise."""
    expanded = tuple(np.expand_dims(component, axis) for component in multiarray.arrays)
    return MultiArray(expanded)
|
||
|
||
@implements(np.concatenate)
def concatenate(multiarrays, axis) -> MultiArray:
    """Apply np.concatenate component-wise across several MultiArrays."""
    n_arrays = len(multiarrays[0].arrays)
    assert all(len(ma.arrays) == n_arrays for ma in multiarrays[1:])
    # zip(*) groups the i-th component of every MultiArray together.
    grouped = zip(*(ma.arrays for ma in multiarrays))
    return MultiArray(tuple(np.concatenate(parts, axis) for parts in grouped))
|
||
|
||
@implements(np.transpose)
def transpose(multiarray, axes) -> MultiArray:
    """Apply np.transpose component-wise."""
    transposed = [np.transpose(component, axes) for component in multiarray.arrays]
    return MultiArray(tuple(transposed))
|
||
|
||
@implements(np.full)
def full(shape, fill_values, *args, **kwargs) -> MultiArray:
    """Build a MultiArray of constant arrays.

    Every argument except `fill_values` is shared by all component arrays;
    one component is created per entry of `fill_values`.
    """
    components = tuple(np.full(shape, fv, *args, **kwargs) for fv in fill_values)
    return MultiArray(components)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -146,6 +146,9 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool: | |
|
||
|
||
def notnull(data): | ||
if isinstance(data, tuple) and len(data) == 3 and data == (0, 0, 0): | ||
# boo: another special case for Var | ||
return True | ||
if not is_duck_array(data): | ||
data = np.asarray(data) | ||
|
||
|
@@ -163,6 +166,9 @@ def notnull(data): | |
|
||
|
||
def isnull(data: Any): | ||
if isinstance(data, tuple) and len(data) == 3 and data == (0, 0, 0): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Out of curiosity, what are these lines (and associated lines above) doing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it didn't like the tuple of 0s, so it's a hack |
||
# boo: another special case for Var | ||
return False | ||
if data is None: | ||
return False | ||
if not is_duck_array(data): | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -108,9 +108,8 @@ def insert_nans(draw: st.DrawFn, array: np.ndarray) -> np.ndarray: | |||
"any", | ||||
"all", | ||||
] + list(SCIPY_STATS_FUNCS) | ||||
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"] | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a real win! We now add var and std to our property test suite ( Line 91 in cbcc035
|
||||
|
||||
func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]) | ||||
func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS]) | ||||
|
||||
|
||||
@st.composite | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Am I understanding correctly that this would overwrite whatever is passed through in fill_value when the aggregation is defined? And we're assuming that in no instance would a different value of fill_value be wanted?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the concern is that `None[2]` isn't a thing, wouldn't it make more sense to have (None, None, None) be the default and keep the unpacking?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think the hardcoding is fine here. It's probably fine to just set
fill_value=(np.nan,)