3
3
import copy
4
4
import logging
5
5
import warnings
6
+ from collections .abc import Sequence
6
7
from dataclasses import dataclass
7
8
from functools import cached_property , partial
8
9
from typing import TYPE_CHECKING , Any , Callable , Literal , TypedDict
9
10
10
11
import numpy as np
12
+ import pandas as pd
11
13
from numpy .typing import ArrayLike , DTypeLike
12
14
13
15
from . import aggregate_flox , aggregate_npg , xrutils
19
21
20
22
21
23
logger = logging .getLogger ("flox" )
24
+ T_ScanBinaryOpMode = Literal ["apply_binary_op" , "concat_then_scan" ]
22
25
23
26
24
27
def _is_arg_reduction (func : str | Aggregation ) -> bool :
@@ -63,6 +66,9 @@ def generic_aggregate(
63
66
dtype = None ,
64
67
** kwargs ,
65
68
):
69
+ if func == "identity" :
70
+ return array
71
+
66
72
if engine == "flox" :
67
73
try :
68
74
method = getattr (aggregate_flox , func )
@@ -567,7 +573,171 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
567
573
mode = Aggregation (name = "mode" , fill_value = dtypes .NA , chunk = None , combine = None )
568
574
nanmode = Aggregation (name = "nanmode" , fill_value = dtypes .NA , chunk = None , combine = None )
569
575
570
- aggregations = {
576
+
577
+ @dataclass
578
+ class Scan :
579
+ # This dataclass is separate from Aggregations since there's not much in common
580
+ # between reductions and scans
581
+ name : str
582
+ # binary operation (e.g. np.add)
583
+ # Must be None for mode="concat_then_scan"
584
+ binary_op : Callable | None
585
+ # in-memory grouped scan function (e.g. cumsum)
586
+ scan : str
587
+ # Grouped reduction that yields the last result of the scan (e.g. sum)
588
+ reduction : str
589
+ # Identity element
590
+ identity : Any
591
+ # dtype of result
592
+ dtype : Any = None
593
+ # "Mode" of applying binary op.
594
+ # for np.add we apply the op directly to the `state` array and the `current` array.
595
+ # for ffill, bfill we concat `state` to `current` and then run the scan again.
596
+ mode : T_ScanBinaryOpMode = "apply_binary_op"
597
+ preprocess : Callable | None = None
598
+ finalize : Callable | None = None
599
+
600
+
601
+ def concatenate (arrays : Sequence [AlignedArrays ], axis = - 1 , out = None ) -> AlignedArrays :
602
+ group_idx = np .concatenate ([a .group_idx for a in arrays ], axis = axis )
603
+ array = np .concatenate ([a .array for a in arrays ], axis = axis )
604
+ return AlignedArrays (array = array , group_idx = group_idx )
605
+
606
+
607
+ @dataclass
608
+ class AlignedArrays :
609
+ """Simple Xarray DataArray type data class with two aligned arrays."""
610
+
611
+ array : np .ndarray
612
+ group_idx : np .ndarray
613
+
614
+ def __post_init__ (self ):
615
+ assert self .array .shape [- 1 ] == self .group_idx .size
616
+
617
+ def last (self ) -> AlignedArrays :
618
+ from flox .core import chunk_reduce
619
+
620
+ reduced = chunk_reduce (
621
+ self .array ,
622
+ self .group_idx ,
623
+ func = ("nanlast" ,),
624
+ axis = - 1 ,
625
+ # TODO: automate?
626
+ engine = "flox" ,
627
+ dtype = self .array .dtype ,
628
+ fill_value = _get_fill_value (self .array .dtype , dtypes .NA ),
629
+ expected_groups = None ,
630
+ )
631
+ return AlignedArrays (array = reduced ["intermediates" ][0 ], group_idx = reduced ["groups" ])
632
+
633
+
634
+ @dataclass
635
+ class ScanState :
636
+ """Dataclass representing intermediates for scan."""
637
+
638
+ # last value of each group seen so far
639
+ state : AlignedArrays | None
640
+ # intermediate result
641
+ result : AlignedArrays | None
642
+
643
+ def __post_init__ (self ):
644
+ assert (self .state is not None ) or (self .result is not None )
645
+
646
+
647
+ def reverse (a : AlignedArrays ) -> AlignedArrays :
648
+ a .group_idx = a .group_idx [::- 1 ]
649
+ a .array = a .array [::- 1 ]
650
+ return a
651
+
652
+
653
+ def scan_binary_op (left_state : ScanState , right_state : ScanState , * , agg : Scan ) -> ScanState :
654
+ from .core import reindex_
655
+
656
+ assert left_state .state is not None
657
+ left = left_state .state
658
+ right = right_state .result if right_state .result is not None else right_state .state
659
+ assert right is not None
660
+
661
+ if agg .mode == "apply_binary_op" :
662
+ assert agg .binary_op is not None
663
+ # Implements groupby binary operation.
664
+ reindexed = reindex_ (
665
+ left .array ,
666
+ from_ = pd .Index (left .group_idx ),
667
+ # can't use right.group_idx since we need to do the indexing later
668
+ to = pd .RangeIndex (right .group_idx .max () + 1 ),
669
+ fill_value = agg .identity ,
670
+ axis = - 1 ,
671
+ )
672
+ result = AlignedArrays (
673
+ array = agg .binary_op (reindexed [..., right .group_idx ], right .array ),
674
+ group_idx = right .group_idx ,
675
+ )
676
+
677
+ elif agg .mode == "concat_then_scan" :
678
+ # Implements the binary op portion of the scan as a concatenate-then-scan.
679
+ # This is useful for `ffill`, and presumably more generalized scans.
680
+ assert agg .binary_op is None
681
+ concat = concatenate ([left , right ], axis = - 1 )
682
+ final_value = generic_aggregate (
683
+ concat .group_idx ,
684
+ concat .array ,
685
+ func = agg .scan ,
686
+ axis = concat .array .ndim - 1 ,
687
+ engine = "flox" ,
688
+ fill_value = agg .identity ,
689
+ )
690
+ result = AlignedArrays (
691
+ array = final_value [..., left .group_idx .size :], group_idx = right .group_idx
692
+ )
693
+ else :
694
+ raise ValueError (f"Unknown binary op application mode: { agg .mode !r} " )
695
+
696
+ # This is quite important. We need to update the state seen so far and propagate that.
697
+ # So we must account for what we know when entering this function: i.e. `left`
698
+ # TODO: this is a bit wasteful since it will sort again, but for now let's focus on
699
+ # correctness and DRY
700
+ lasts = concatenate ([left , result ]).last ()
701
+
702
+ return ScanState (
703
+ state = lasts ,
704
+ # The binary op is called on the results of the reduction too when building up the tree.
705
+ # We need to be careful and assign those results only to `state` and not the final result.
706
+ # Up above, `result` is privileged when it exists.
707
+ result = None if right_state .result is None else result ,
708
+ )
709
+
710
+
711
+ # TODO: numpy_groupies cumsum is a broken when NaNs are present.
712
+ # cumsum = Scan("cumsum", binary_op=np.add, reduction="sum", scan="cumsum", identity=0)
713
+ nancumsum = Scan ("nancumsum" , binary_op = np .add , reduction = "nansum" , scan = "nancumsum" , identity = 0 )
714
+ # ffill uses the identity for scan, and then at the binary-op state,
715
+ # we concatenate the blockwise-reduced values with the original block,
716
+ # and then execute the scan
717
+ # TODO: consider adding chunk="identity" here, like with reductions as an optimization
718
+ ffill = Scan (
719
+ "ffill" ,
720
+ binary_op = None ,
721
+ reduction = "nanlast" ,
722
+ scan = "ffill" ,
723
+ identity = np .nan ,
724
+ mode = "concat_then_scan" ,
725
+ )
726
+ bfill = Scan (
727
+ "bfill" ,
728
+ binary_op = None ,
729
+ reduction = "nanlast" ,
730
+ scan = "ffill" ,
731
+ identity = np .nan ,
732
+ mode = "concat_then_scan" ,
733
+ preprocess = reverse ,
734
+ finalize = reverse ,
735
+ )
736
+ # TODO: not implemented in numpy_groupies
737
+ # cumprod = Scan("cumprod", binary_op=np.multiply, preop="prod", scan="cumprod")
738
+
739
+
740
+ AGGREGATIONS : dict [str , Aggregation | Scan ] = {
571
741
"any" : any_ ,
572
742
"all" : all_ ,
573
743
"count" : count ,
@@ -599,6 +769,10 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
599
769
"nanquantile" : nanquantile ,
600
770
"mode" : mode ,
601
771
"nanmode" : nanmode ,
772
+ # "cumsum": cumsum,
773
+ "nancumsum" : nancumsum ,
774
+ "ffill" : ffill ,
775
+ "bfill" : bfill ,
602
776
}
603
777
604
778
@@ -610,11 +784,14 @@ def _initialize_aggregation(
610
784
min_count : int ,
611
785
finalize_kwargs : dict [Any , Any ] | None ,
612
786
) -> Aggregation :
787
+ agg : Aggregation
613
788
if not isinstance (func , Aggregation ):
614
789
try :
615
790
# TODO: need better interface
616
791
# we set dtype, fillvalue on reduction later. so deepcopy now
617
- agg = copy .deepcopy (aggregations [func ])
792
+ agg_ = copy .deepcopy (AGGREGATIONS [func ])
793
+ assert isinstance (agg_ , Aggregation )
794
+ agg = agg_
618
795
except KeyError :
619
796
raise NotImplementedError (f"Reduction { func !r} not implemented yet" )
620
797
elif isinstance (func , Aggregation ):
0 commit comments