33from __future__ import annotations
44
55from abc import ABC , abstractmethod
6+ from collections .abc import Callable
67from typing import TYPE_CHECKING , Any
78
89import scipp as sc
@@ -36,14 +37,15 @@ def extract(self, data: sc.DataArray) -> Any:
3637 """
3738
3839 @abstractmethod
39- def get_required_timespan (self ) -> float | None :
40+ def get_required_timespan (self ) -> float :
4041 """
4142 Get the required timespan for this extractor.
4243
4344 Returns
4445 -------
4546 :
46- Required timespan in seconds, or None if no specific requirement.
47+ Required timespan in seconds. Return 0.0 for extractors that only
48+ need the latest value.
4749 """
4850
4951
@@ -61,9 +63,9 @@ def __init__(self, concat_dim: str = 'time') -> None:
6163 """
6264 self ._concat_dim = concat_dim
6365
64- def get_required_timespan (self ) -> float | None :
65- """Latest value has no specific timespan requirement ."""
66- return None
66+ def get_required_timespan (self ) -> float :
67+ """Latest value requires zero history ."""
68+ return 0.0
6769
6870 def extract (self , data : sc .DataArray ) -> Any :
6971 """Extract the latest value from the data, unwrapped."""
@@ -79,7 +81,7 @@ def extract(self, data: sc.DataArray) -> Any:
7981class FullHistoryExtractor (UpdateExtractor ):
8082 """Extracts the complete buffer history."""
8183
82- def get_required_timespan (self ) -> float | None :
84+ def get_required_timespan (self ) -> float :
8385 """Return infinite timespan to indicate wanting all history."""
8486 return float ('inf' )
8587
@@ -113,51 +115,43 @@ def __init__(
113115 self ._window_duration_seconds = window_duration_seconds
114116 self ._aggregation = aggregation
115117 self ._concat_dim = concat_dim
118+ self ._aggregator : Callable [[sc .DataArray , str ], sc .DataArray ] | None = None
116119
117- def get_required_timespan (self ) -> float | None :
120+ def get_required_timespan (self ) -> float :
118121 """Return the required window duration."""
119122 return self ._window_duration_seconds
120123
121124 def extract (self , data : sc .DataArray ) -> Any :
122125 """Extract a window of data and aggregate over the time dimension."""
123- # Check if concat dimension exists in the data
124- if not hasattr (data , 'dims' ) or self ._concat_dim not in data .dims :
125- # Data doesn't have the expected dimension structure, return as-is
126- return data
127-
128- # Extract time window
129- if not hasattr (data , 'coords' ) or self ._concat_dim not in data .coords :
130- # No time coordinate - can't do time-based windowing, return all data
131- windowed_data = data
132- else :
133- # Calculate cutoff time using scipp's unit handling
134- time_coord = data .coords [self ._concat_dim ]
135- latest_time = time_coord [- 1 ]
136- duration = sc .scalar (self ._window_duration_seconds , unit = 's' ).to (
137- unit = time_coord .unit
138- )
139- windowed_data = data [self ._concat_dim , latest_time - duration :]
140-
141- # Determine aggregation method
142- agg_method = self ._aggregation
143- if agg_method == WindowAggregation .auto :
144- # Use nansum if data is dimensionless (counts), else nanmean
145- if hasattr (windowed_data , 'unit' ) and windowed_data .unit == '1' :
146- agg_method = WindowAggregation .nansum
126+ # Calculate cutoff time using scipp's unit handling
127+ time_coord = data .coords [self ._concat_dim ]
128+ latest_time = time_coord [- 1 ]
129+ duration = sc .scalar (self ._window_duration_seconds , unit = 's' ).to (
130+ unit = time_coord .unit
131+ )
132+ windowed_data = data [self ._concat_dim , latest_time - duration :]
133+
134+ # Resolve and cache aggregator function on first call
135+ if self ._aggregator is None :
136+ if self ._aggregation == WindowAggregation .auto :
137+ aggregation = (
138+ WindowAggregation .nansum
139+ if windowed_data .unit == 'counts'
140+ else WindowAggregation .nanmean
141+ )
147142 else :
148- agg_method = WindowAggregation .nanmean
149-
150- # Aggregate over the concat dimension
151- if agg_method == WindowAggregation .sum :
152- return windowed_data .sum (self ._concat_dim )
153- elif agg_method == WindowAggregation .nansum :
154- return windowed_data .nansum (self ._concat_dim )
155- elif agg_method == WindowAggregation .mean :
156- return windowed_data .mean (self ._concat_dim )
157- elif agg_method == WindowAggregation .nanmean :
158- return windowed_data .nanmean (self ._concat_dim )
159- else :
160- raise ValueError (f"Unknown aggregation method: { agg_method } " )
143+ aggregation = self ._aggregation
144+ aggregators = {
145+ WindowAggregation .sum : sc .sum ,
146+ WindowAggregation .nansum : sc .nansum ,
147+ WindowAggregation .mean : sc .mean ,
148+ WindowAggregation .nanmean : sc .nanmean ,
149+ }
150+ self ._aggregator = aggregators .get (aggregation )
151+ if self ._aggregator is None :
152+ raise ValueError (f"Unknown aggregation method: { self ._aggregation } " )
153+
154+ return self ._aggregator (windowed_data , self ._concat_dim )
161155
162156
163157def create_extractors_from_params (
0 commit comments