Skip to content

Commit 6380966

Browse files
committed
Optimize WindowAggregatingExtractor with caching and dict lookup
- Cache aggregator function directly (sc.sum, sc.nansum, etc.) to avoid repeated unit checks and dict lookups on every extract() call - Use scipp free functions (sc.sum, sc.nansum, etc.) instead of lambdas for cleaner code - Simplify auto mode to check for 'counts' unit instead of dimensionless - Update get_required_timespan() to always return float (0.0 for latest, float('inf') for full history) instead of Optional[float] Original prompt: Can we use a dict in WindowAggregatingExtractor to speed up the aggregator lookup? Or maybe it could/should be cached (unit cannot change during stream). Follow-up: I suggest you change to Callable[[sc.DataArray, str], sc.DataArray] and remove the lambdas?
1 parent 9eedaaa commit 6380966

File tree

1 file changed

+38
-44
lines changed

1 file changed

+38
-44
lines changed

src/ess/livedata/dashboard/extractors.py

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6+
from collections.abc import Callable
67
from typing import TYPE_CHECKING, Any
78

89
import scipp as sc
@@ -36,14 +37,15 @@ def extract(self, data: sc.DataArray) -> Any:
3637
"""
3738

3839
@abstractmethod
39-
def get_required_timespan(self) -> float | None:
40+
def get_required_timespan(self) -> float:
4041
"""
4142
Get the required timespan for this extractor.
4243
4344
Returns
4445
-------
4546
:
46-
Required timespan in seconds, or None if no specific requirement.
47+
Required timespan in seconds. Return 0.0 for extractors that only
48+
need the latest value.
4749
"""
4850

4951

@@ -61,9 +63,9 @@ def __init__(self, concat_dim: str = 'time') -> None:
6163
"""
6264
self._concat_dim = concat_dim
6365

64-
def get_required_timespan(self) -> float | None:
65-
"""Latest value has no specific timespan requirement."""
66-
return None
66+
def get_required_timespan(self) -> float:
67+
"""Latest value requires zero history."""
68+
return 0.0
6769

6870
def extract(self, data: sc.DataArray) -> Any:
6971
"""Extract the latest value from the data, unwrapped."""
@@ -79,7 +81,7 @@ def extract(self, data: sc.DataArray) -> Any:
7981
class FullHistoryExtractor(UpdateExtractor):
8082
"""Extracts the complete buffer history."""
8183

82-
def get_required_timespan(self) -> float | None:
84+
def get_required_timespan(self) -> float:
8385
"""Return infinite timespan to indicate wanting all history."""
8486
return float('inf')
8587

@@ -113,51 +115,43 @@ def __init__(
113115
self._window_duration_seconds = window_duration_seconds
114116
self._aggregation = aggregation
115117
self._concat_dim = concat_dim
118+
self._aggregator: Callable[[sc.DataArray, str], sc.DataArray] | None = None
116119

117-
def get_required_timespan(self) -> float | None:
120+
def get_required_timespan(self) -> float:
118121
"""Return the required window duration."""
119122
return self._window_duration_seconds
120123

121124
def extract(self, data: sc.DataArray) -> Any:
122125
"""Extract a window of data and aggregate over the time dimension."""
123-
# Check if concat dimension exists in the data
124-
if not hasattr(data, 'dims') or self._concat_dim not in data.dims:
125-
# Data doesn't have the expected dimension structure, return as-is
126-
return data
127-
128-
# Extract time window
129-
if not hasattr(data, 'coords') or self._concat_dim not in data.coords:
130-
# No time coordinate - can't do time-based windowing, return all data
131-
windowed_data = data
132-
else:
133-
# Calculate cutoff time using scipp's unit handling
134-
time_coord = data.coords[self._concat_dim]
135-
latest_time = time_coord[-1]
136-
duration = sc.scalar(self._window_duration_seconds, unit='s').to(
137-
unit=time_coord.unit
138-
)
139-
windowed_data = data[self._concat_dim, latest_time - duration :]
140-
141-
# Determine aggregation method
142-
agg_method = self._aggregation
143-
if agg_method == WindowAggregation.auto:
144-
# Use nansum if data is dimensionless (counts), else nanmean
145-
if hasattr(windowed_data, 'unit') and windowed_data.unit == '1':
146-
agg_method = WindowAggregation.nansum
126+
# Calculate cutoff time using scipp's unit handling
127+
time_coord = data.coords[self._concat_dim]
128+
latest_time = time_coord[-1]
129+
duration = sc.scalar(self._window_duration_seconds, unit='s').to(
130+
unit=time_coord.unit
131+
)
132+
windowed_data = data[self._concat_dim, latest_time - duration :]
133+
134+
# Resolve and cache aggregator function on first call
135+
if self._aggregator is None:
136+
if self._aggregation == WindowAggregation.auto:
137+
aggregation = (
138+
WindowAggregation.nansum
139+
if windowed_data.unit == 'counts'
140+
else WindowAggregation.nanmean
141+
)
147142
else:
148-
agg_method = WindowAggregation.nanmean
149-
150-
# Aggregate over the concat dimension
151-
if agg_method == WindowAggregation.sum:
152-
return windowed_data.sum(self._concat_dim)
153-
elif agg_method == WindowAggregation.nansum:
154-
return windowed_data.nansum(self._concat_dim)
155-
elif agg_method == WindowAggregation.mean:
156-
return windowed_data.mean(self._concat_dim)
157-
elif agg_method == WindowAggregation.nanmean:
158-
return windowed_data.nanmean(self._concat_dim)
159-
else:
160-
raise ValueError(f"Unknown aggregation method: {agg_method}")
143+
aggregation = self._aggregation
144+
aggregators = {
145+
WindowAggregation.sum: sc.sum,
146+
WindowAggregation.nansum: sc.nansum,
147+
WindowAggregation.mean: sc.mean,
148+
WindowAggregation.nanmean: sc.nanmean,
149+
}
150+
self._aggregator = aggregators.get(aggregation)
151+
if self._aggregator is None:
152+
raise ValueError(f"Unknown aggregation method: {self._aggregation}")
153+
154+
return self._aggregator(windowed_data, self._concat_dim)
161155

162156

163157
def create_extractors_from_params(

0 commit comments

Comments (0)