Skip to content

Commit 24fb532

Browse files
committed
support var, std
1 parent 3b3369f commit 24fb532

File tree

3 files changed

+40
-39
lines changed

3 files changed

+40
-39
lines changed

flox/aggregations.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,9 @@ def _mean_finalize(sum_, count):
343343
)
344344

345345

346-
def var_chunk(group_idx, array, *, engine: str, axis=-1, size=None, fill_value=None, dtype=None):
346+
def var_chunk(
347+
group_idx, array, *, skipna: bool, engine: str, axis=-1, size=None, fill_value=None, dtype=None
348+
):
347349
from .aggregate_flox import MultiArray
348350

349351
# Calculate length and sum - important for the adjustment terms to sum squared deviations
@@ -361,7 +363,7 @@ def var_chunk(group_idx, array, *, engine: str, axis=-1, size=None, fill_value=N
361363
array_sums = generic_aggregate(
362364
group_idx,
363365
array,
364-
func="nansum",
366+
func="nansum" if skipna else "sum",
365367
engine=engine,
366368
axis=axis,
367369
size=size,
@@ -375,7 +377,7 @@ def var_chunk(group_idx, array, *, engine: str, axis=-1, size=None, fill_value=N
375377
sum_squared_deviations = generic_aggregate(
376378
group_idx,
377379
(array - array_means[..., group_idx]) ** 2,
378-
func="nansum",
380+
func="nansum" if skipna else "sum",
379381
engine=engine,
380382
axis=axis,
381383
size=size,
@@ -448,73 +450,73 @@ def clip_first(array, n=1):
448450
# return result
449451

450452

453+
def is_var_chunk_reduction(agg: Callable) -> bool:
454+
if isinstance(agg, partial):
455+
agg = agg.func
456+
return agg is blockwise_or_numpy_var or agg is var_chunk
457+
458+
451459
def _var_finalize(multiarray, ddof=0):
452460
den = multiarray.arrays[2] - ddof
453461
# preserve nans for groups with 0 obs; so these values are -ddof
454462
den[den < 0] = 0
455463
return multiarray.arrays[0] / den
456464

457465

458-
def _std_finalize(sumsq, sum_, count, ddof=0):
459-
return np.sqrt(_var_finalize(sumsq, sum_, count, ddof))
466+
def _std_finalize(multiarray, ddof=0):
467+
return np.sqrt(_var_finalize(multiarray, ddof))
468+
469+
470+
def blockwise_or_numpy_var(*args, skipna: bool, ddof=0, std=False, **kwargs):
471+
res = _var_finalize(var_chunk(*args, skipna=skipna, **kwargs), ddof)
472+
return np.sqrt(res) if std else res
460473

461474

462475
# var, std always promote to float, so we set nan
463476
var = Aggregation(
464477
"var",
465-
chunk=("sum_of_squares", "sum", "nanlen"),
466-
combine=("sum", "sum", "sum"),
478+
chunk=partial(var_chunk, skipna=False),
479+
numpy=partial(blockwise_or_numpy_var, skipna=False),
480+
combine=(_var_combine,),
467481
finalize=_var_finalize,
468-
fill_value=0,
482+
fill_value=((0, 0, 0),),
469483
final_fill_value=np.nan,
470-
dtypes=(None, None, np.intp),
484+
dtypes=(None,),
471485
final_dtype=np.floating,
472486
)
473-
# nanvar = Aggregation(
474-
# "nanvar",
475-
# chunk=("nansum_of_squares", "nansum", "nanlen"),
476-
# combine=("sum", "sum", "sum"),
477-
# finalize=_var_finalize,
478-
# fill_value=0,
479-
# final_fill_value=np.nan,
480-
# dtypes=(None, None, np.intp),
481-
# final_dtype=np.floating,
482-
# )
483-
484-
485-
def blockwise_or_numpy_var(*args, ddof=0, **kwargs):
486-
return _var_finalize(var_chunk(*args, **kwargs), ddof)
487-
488487

489488
nanvar = Aggregation(
490489
"nanvar",
491-
chunk=var_chunk,
492-
numpy=blockwise_or_numpy_var,
490+
chunk=partial(var_chunk, skipna=True),
491+
numpy=partial(blockwise_or_numpy_var, skipna=True),
493492
combine=(_var_combine,),
494493
finalize=_var_finalize,
495494
fill_value=((0, 0, 0),),
496495
final_fill_value=np.nan,
497496
dtypes=(None,),
498497
final_dtype=np.floating,
499498
)
499+
500500
std = Aggregation(
501501
"std",
502-
chunk=("sum_of_squares", "sum", "nanlen"),
503-
combine=("sum", "sum", "sum"),
502+
chunk=partial(var_chunk, skipna=False),
503+
numpy=partial(blockwise_or_numpy_var, skipna=False, std=True),
504+
combine=(_var_combine,),
504505
finalize=_std_finalize,
505-
fill_value=0,
506+
fill_value=((0, 0, 0),),
506507
final_fill_value=np.nan,
507-
dtypes=(None, None, np.intp),
508+
dtypes=(None,),
508509
final_dtype=np.floating,
509510
)
510511
nanstd = Aggregation(
511512
"nanstd",
512-
chunk=("nansum_of_squares", "nansum", "nanlen"),
513-
combine=("sum", "sum", "sum"),
513+
chunk=partial(var_chunk, skipna=True),
514+
numpy=partial(blockwise_or_numpy_var, skipna=True, std=True),
515+
combine=(_var_combine,),
514516
finalize=_std_finalize,
515-
fill_value=0,
517+
fill_value=((0, 0, 0),),
516518
final_fill_value=np.nan,
517-
dtypes=(None, None, np.intp),
519+
dtypes=(None,),
518520
final_dtype=np.floating,
519521
)
520522

flox/core.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,9 @@
4343
ScanState,
4444
_atleast_1d,
4545
_initialize_aggregation,
46-
blockwise_or_numpy_var,
4746
generic_aggregate,
47+
is_var_chunk_reduction,
4848
quantile_new_dims_func,
49-
var_chunk,
5049
)
5150
from .cache import memoize
5251
from .lib import ArrayLayer, dask_array_type, sparse_array_type
@@ -1291,7 +1290,7 @@ def chunk_reduce(
12911290
previous_reduction: T_Func = ""
12921291
for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes):
12931292
# UGLY! but this is because the `var` breaks our design assumptions
1294-
if empty and reduction is not var_chunk:
1293+
if empty and not is_var_chunk_reduction(reduction):
12951294
result = np.full(shape=final_array_shape, fill_value=fv, like=array)
12961295
elif is_nanlen(reduction) and is_nanlen(previous_reduction):
12971296
result = results["intermediates"][-1]
@@ -1301,7 +1300,7 @@ def chunk_reduce(
13011300
kw_func.update(kw)
13021301

13031302
# UGLY! but this is because the `var` breaks our design assumptions
1304-
if reduction is var_chunk or blockwise_or_numpy_var:
1303+
if is_var_chunk_reduction(reduction):
13051304
kw_func.update(engine=engine)
13061305

13071306
if callable(reduction):

tests/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def gen_array_by(size, func):
236236
@pytest.mark.parametrize("size", [(1, 12), (12,), (12, 9)])
237237
@pytest.mark.parametrize("nby", [1, 2, 3])
238238
@pytest.mark.parametrize("add_nan_by", [True, False])
239-
@pytest.mark.parametrize("func", ["nanvar"])
239+
@pytest.mark.parametrize("func", ["var", "nanvar", "std", "nanstd"])
240240
def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engine):
241241
if ("arg" in func and engine in ["flox", "numbagg"]) or (func in BLOCKWISE_FUNCS and chunks != -1):
242242
pytest.skip()

0 commit comments

Comments
 (0)