@@ -343,7 +343,9 @@ def _mean_finalize(sum_, count):
343
343
)
344
344
345
345
346
- def var_chunk (group_idx , array , * , engine : str , axis = - 1 , size = None , fill_value = None , dtype = None ):
346
+ def var_chunk (
347
+ group_idx , array , * , skipna : bool , engine : str , axis = - 1 , size = None , fill_value = None , dtype = None
348
+ ):
347
349
from .aggregate_flox import MultiArray
348
350
349
351
# Calculate length and sum - important for the adjustment terms to sum squared deviations
@@ -361,7 +363,7 @@ def var_chunk(group_idx, array, *, engine: str, axis=-1, size=None, fill_value=N
361
363
array_sums = generic_aggregate (
362
364
group_idx ,
363
365
array ,
364
- func = "nansum" ,
366
+ func = "nansum" if skipna else "sum" ,
365
367
engine = engine ,
366
368
axis = axis ,
367
369
size = size ,
@@ -375,7 +377,7 @@ def var_chunk(group_idx, array, *, engine: str, axis=-1, size=None, fill_value=N
375
377
sum_squared_deviations = generic_aggregate (
376
378
group_idx ,
377
379
(array - array_means [..., group_idx ]) ** 2 ,
378
- func = "nansum" ,
380
+ func = "nansum" if skipna else "sum" ,
379
381
engine = engine ,
380
382
axis = axis ,
381
383
size = size ,
@@ -448,73 +450,73 @@ def clip_first(array, n=1):
448
450
# return result
449
451
450
452
453
def is_var_chunk_reduction(agg: Callable) -> bool:
    """Report whether ``agg`` is one of the variance reductions.

    A ``functools.partial`` is unwrapped to the function it wraps; the result
    is then compared by identity against ``blockwise_or_numpy_var`` and
    ``var_chunk``.
    """
    target = agg.func if isinstance(agg, partial) else agg
    if target is blockwise_or_numpy_var:
        return True
    return target is var_chunk
457
+
458
+
451
459
def _var_finalize(multiarray, ddof=0):
    """Divide the accumulated numerator by ``count - ddof``.

    Reads ``multiarray.arrays[0]`` as the numerator and
    ``multiarray.arrays[2]`` as the per-group count (presumably the sum of
    squared deviations and the observation count produced by ``var_chunk`` --
    confirm against its return value).
    """
    numerator = multiarray.arrays[0]
    denominator = multiarray.arrays[2] - ddof
    # Groups with zero observations land at -ddof here. Clamping negatives to
    # zero keeps the division producing NaN for them instead of a finite value.
    negative = denominator < 0
    denominator[negative] = 0
    return numerator / denominator
456
464
457
465
458
def _std_finalize(multiarray, ddof=0):
    """Return the square root of what ``_var_finalize`` computes.

    ``multiarray`` and ``ddof`` are forwarded to ``_var_finalize`` unchanged.
    """
    variance = _var_finalize(multiarray, ddof)
    return np.sqrt(variance)
468
+
469
+
470
def blockwise_or_numpy_var(*args, skipna: bool, ddof=0, std=False, **kwargs):
    """Run ``var_chunk`` and finalize its result in one call.

    Positional arguments and extra keyword arguments are forwarded to
    ``var_chunk`` together with ``skipna``; the intermediates it returns are
    passed to ``_var_finalize`` with ``ddof``. When ``std`` is true the square
    root of that result is returned instead.
    """
    intermediates = var_chunk(*args, skipna=skipna, **kwargs)
    variance = _var_finalize(intermediates, ddof)
    if std:
        return np.sqrt(variance)
    return variance
460
473
461
474
462
475
# var, std always promote to float, so we set nan
463
476
var = Aggregation(
    "var",
    # Per-chunk and single-pass reductions; skipna=False makes var_chunk
    # aggregate with "sum" rather than "nansum".
    chunk=partial(var_chunk, skipna=False),
    numpy=partial(blockwise_or_numpy_var, skipna=False),
    # One intermediate, hence one-element tuples below. The (0, 0, 0) fill is
    # presumably one zero per component of that intermediate -- confirm.
    combine=(_var_combine,),
    fill_value=((0, 0, 0),),
    dtypes=(None,),
    # Final result.
    finalize=_var_finalize,
    final_fill_value=np.nan,
    final_dtype=np.floating,
)
473
- # nanvar = Aggregation(
474
- # "nanvar",
475
- # chunk=("nansum_of_squares", "nansum", "nanlen"),
476
- # combine=("sum", "sum", "sum"),
477
- # finalize=_var_finalize,
478
- # fill_value=0,
479
- # final_fill_value=np.nan,
480
- # dtypes=(None, None, np.intp),
481
- # final_dtype=np.floating,
482
- # )
483
-
484
-
485
- def blockwise_or_numpy_var (* args , ddof = 0 , ** kwargs ):
486
- return _var_finalize (var_chunk (* args , ** kwargs ), ddof )
487
-
488
487
489
488
nanvar = Aggregation(
    "nanvar",
    # Per-chunk and single-pass reductions; skipna=True makes var_chunk
    # aggregate with "nansum".
    chunk=partial(var_chunk, skipna=True),
    numpy=partial(blockwise_or_numpy_var, skipna=True),
    # One intermediate, hence one-element tuples below. The (0, 0, 0) fill is
    # presumably one zero per component of that intermediate -- confirm.
    combine=(_var_combine,),
    fill_value=((0, 0, 0),),
    dtypes=(None,),
    # Final result.
    finalize=_var_finalize,
    final_fill_value=np.nan,
    final_dtype=np.floating,
)
499
+
500
500
std = Aggregation(
    "std",
    # The chunk stage is the same reduction as for "var" (skipna=False, so
    # "sum" rather than "nansum"). The single-pass path passes std=True so
    # blockwise_or_numpy_var returns the square root directly.
    chunk=partial(var_chunk, skipna=False),
    numpy=partial(blockwise_or_numpy_var, skipna=False, std=True),
    # One intermediate, hence one-element tuples below. The (0, 0, 0) fill is
    # presumably one zero per component of that intermediate -- confirm.
    combine=(_var_combine,),
    fill_value=((0, 0, 0),),
    dtypes=(None,),
    # _std_finalize takes the square root of the finalized variance.
    finalize=_std_finalize,
    final_fill_value=np.nan,
    final_dtype=np.floating,
)
510
511
nanstd = Aggregation(
    "nanstd",
    # The chunk stage is the same reduction as for "nanvar" (skipna=True, so
    # "nansum"). The single-pass path passes std=True so
    # blockwise_or_numpy_var returns the square root directly.
    chunk=partial(var_chunk, skipna=True),
    numpy=partial(blockwise_or_numpy_var, skipna=True, std=True),
    # One intermediate, hence one-element tuples below. The (0, 0, 0) fill is
    # presumably one zero per component of that intermediate -- confirm.
    combine=(_var_combine,),
    fill_value=((0, 0, 0),),
    dtypes=(None,),
    # _std_finalize takes the square root of the finalized variance.
    finalize=_std_finalize,
    final_fill_value=np.nan,
    final_dtype=np.floating,
)
520
522
0 commit comments