diff --git a/src/ess/livedata/config/workflows.py b/src/ess/livedata/config/workflows.py index 29c8b51bb..4c24dd310 100644 --- a/src/ess/livedata/config/workflows.py +++ b/src/ess/livedata/config/workflows.py @@ -76,7 +76,10 @@ def is_empty(self) -> bool: def _get_value(self) -> sc.DataArray: if self._to_nxlog is None: raise ValueError("No data accumulated") - return self._to_nxlog.get() + # Return latest value. Will be aggregated into a timeseries in frontend (if a + # plot requests it). This accumulator may be fully replaced once it is clear how + # we want to handle obtaining the full history (e.g., after frontend restarts). + return self._to_nxlog.get()[-1] def _do_push(self, value: sc.DataArray) -> None: if self._to_nxlog is None: diff --git a/src/ess/livedata/core/job.py b/src/ess/livedata/core/job.py index ca12ec940..03356405b 100644 --- a/src/ess/livedata/core/job.py +++ b/src/ess/livedata/core/job.py @@ -184,7 +184,11 @@ def add(self, data: JobData) -> JobReply: remapped_aux_data[field_name] = value # Pass data to workflow with field names (not stream names) - self._processor.accumulate({**data.primary_data, **remapped_aux_data}) + self._processor.accumulate( + {**data.primary_data, **remapped_aux_data}, + start_time=data.start_time, + end_time=data.end_time, + ) if data.is_active(): if self._start_time is None: self._start_time = data.start_time diff --git a/src/ess/livedata/dashboard/correlation_histogram.py b/src/ess/livedata/dashboard/correlation_histogram.py index b20411138..58701cc30 100644 --- a/src/ess/livedata/dashboard/correlation_histogram.py +++ b/src/ess/livedata/dashboard/correlation_histogram.py @@ -392,11 +392,25 @@ def add_correlation_processor( items: dict[ResultKey, sc.DataArray], ) -> None: """Add a correlation histogram processor with DataService subscription.""" + from .extractors import FullHistoryExtractor + self._processors.append(processor) - # Create subscriber that merges data and sends to processor + # Create subscriber that merges data and sends to processor. + # Use FullHistoryExtractor to get complete timeseries history needed for + # correlation histogram computation. + # TODO We should update the plotter to operate more efficiently by simply + # subscribing to the changes. This will likely require a new extractor type as + # well as changes in the plotter, so we defer this for now. assembler = MergingStreamAssembler(set(items)) - subscriber = DataSubscriber(assembler, processor) + extractors = {key: FullHistoryExtractor() for key in items} + + # Create factory that sends initial data to processor and returns it + def processor_pipe_factory(data: dict[ResultKey, sc.DataArray]): + processor.send(data) + return processor + + subscriber = DataSubscriber(assembler, processor_pipe_factory, extractors) self._data_service.register_subscriber(subscriber) def get_timeseries(self) -> list[ResultKey]: @@ -412,7 +426,23 @@ def create_2d_config(self) -> CorrelationHistogramConfigurationAdapter: def _is_timeseries(da: sc.DataArray) -> bool: - return da.dims == ('time',) and 'time' in da.coords + """Check if data represents a timeseries. + + When DataService uses LatestValueExtractor (default), it returns the latest value + from a timeseries buffer as a 0D scalar with a time coordinate. This function + identifies such values as originating from a timeseries. + + Parameters + ---------- + da: + DataArray to check. + + Returns + ------- + : + True if the data is a 0D scalar with a time coordinate. 
+ """ + return da.ndim == 0 and 'time' in da.coords class CorrelationHistogramProcessor: diff --git a/src/ess/livedata/dashboard/data_service.py b/src/ess/livedata/dashboard/data_service.py index 2d8098cd6..06a9f8649 100644 --- a/src/ess/livedata/dashboard/data_service.py +++ b/src/ess/livedata/dashboard/data_service.py @@ -2,32 +2,75 @@ # Copyright (c) 2025 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from collections import UserDict -from collections.abc import Callable, Hashable +from abc import ABC, abstractmethod +from collections.abc import Callable, Hashable, Iterator, Mapping, MutableMapping from contextlib import contextmanager -from typing import TypeVar +from typing import Any, Generic, TypeVar -from .data_subscriber import DataSubscriber +from .extractors import LatestValueExtractor, UpdateExtractor +from .temporal_buffer_manager import TemporalBufferManager K = TypeVar('K', bound=Hashable) V = TypeVar('V') -class DataService(UserDict[K, V]): +class DataServiceSubscriber(ABC, Generic[K]): + """Base class for data service subscribers with cached keys and extractors.""" + + def __init__(self) -> None: + """Initialize subscriber and cache keys from extractors.""" + # Cache keys from extractors to avoid repeated computation + self._keys = set(self.extractors.keys()) + + @property + def keys(self) -> set[K]: + """Return the set of data keys this subscriber depends on.""" + return self._keys + + @property + @abstractmethod + def extractors(self) -> Mapping[K, UpdateExtractor]: + """ + Return extractors for obtaining data views. + + Returns a mapping from key to the extractor to use for that key. + """ + + @abstractmethod + def trigger(self, store: dict[K, Any]) -> None: + """Trigger the subscriber with updated data.""" + + +class DataService(MutableMapping[K, V]): """ A service for managing and retrieving data and derived data. New data is set from upstream Kafka topics. Subscribers are typically plots that provide a live view of the data. + + Uses buffers internally for storage, but presents a dict-like interface + that returns the latest value for each key. """ - def __init__(self) -> None: - super().__init__() - self._subscribers: list[DataSubscriber[K]] = [] - self._key_change_subscribers: list[Callable[[set[K], set[K]], None]] = [] + def __init__( + self, + buffer_manager: TemporalBufferManager | None = None, + ) -> None: + """ + Initialize DataService. + + Parameters + ---------- + buffer_manager: + Manager for buffer sizing. If None, creates a new TemporalBufferManager. + """ + if buffer_manager is None: + buffer_manager = TemporalBufferManager() + self._buffer_manager = buffer_manager + self._default_extractor = LatestValueExtractor() + self._subscribers: list[DataServiceSubscriber[K]] = [] + self._update_callbacks: list[Callable[[set[K]], None]] = [] self._pending_updates: set[K] = set() - self._pending_key_additions: set[K] = set() - self._pending_key_removals: set[K] = set() self._transaction_depth = 0 @contextmanager @@ -48,30 +91,94 @@ def transaction(self): def _in_transaction(self) -> bool: return self._transaction_depth > 0 - def register_subscriber(self, subscriber: DataSubscriber[K]) -> None: + def _get_extractors(self, key: K) -> list[UpdateExtractor]: """ - Register a subscriber for updates. + Collect extractors for a key from all subscribers. + + Examines all subscribers that need this key. + + Parameters + ---------- + key: + The key to collect extractors for. 
+ + Returns + ------- + : + List of extractors from all subscribers for this key. + """ + extractors = [] + + for subscriber in self._subscribers: + subscriber_extractors = subscriber.extractors + if key in subscriber_extractors: + extractor = subscriber_extractors[key] + extractors.append(extractor) + + return extractors + + def _build_subscriber_data( + self, subscriber: DataServiceSubscriber[K] + ) -> dict[K, Any]: + """ + Extract data for a subscriber based on its extractors. Parameters ---------- subscriber: - The subscriber to register. Must implement the DataSubscriber interface. + The subscriber to extract data for. + + Returns + ------- + : + Dictionary mapping keys to extracted data (None values filtered out). """ - self._subscribers.append(subscriber) + subscriber_data = {} - def subscribe_to_changed_keys( - self, subscriber: Callable[[set[K], set[K]], None] - ) -> None: + for key, extractor in subscriber.extractors.items(): + buffered_data = self._buffer_manager.get_buffered_data(key) + if buffered_data is not None: + subscriber_data[key] = extractor.extract(buffered_data) + + return subscriber_data + + def register_subscriber(self, subscriber: DataServiceSubscriber[K]) -> None: """ - Register a subscriber for key change updates (additions/removals). + Register a subscriber for updates with extractor-based data access. + + Triggers the subscriber immediately with existing data using its extractors. Parameters ---------- subscriber: - A callable that accepts two sets: added_keys and removed_keys. + The subscriber to register. """ - self._key_change_subscribers.append(subscriber) - subscriber(set(self.data.keys()), set()) + self._subscribers.append(subscriber) + + # Add extractors for keys this subscriber needs + for key in subscriber.keys: + if key in self._buffer_manager: + extractor = subscriber.extractors[key] + self._buffer_manager.add_extractor(key, extractor) + + # Trigger immediately with existing data using subscriber's extractors + existing_data = self._build_subscriber_data(subscriber) + subscriber.trigger(existing_data) + + def register_update_callback(self, callback: Callable[[set[K]], None]) -> None: + """ + Register a callback for key update notifications. + + Callback receives only the set of updated key names, not the data. + Use this for infrastructure that needs to know what changed but will + query data itself. + + Parameters + ---------- + callback: + Callable that accepts a set of updated keys. + """ + self._update_callbacks.append(callback) def _notify_subscribers(self, updated_keys: set[K]) -> None: """ @@ -82,40 +189,46 @@ def _notify_subscribers(self, updated_keys: set[K]) -> None: updated_keys The set of data keys that were updated. 
""" + # Notify extractor-based subscribers for subscriber in self._subscribers: - if not isinstance(subscriber, DataSubscriber): - subscriber(updated_keys) - continue if updated_keys & subscriber.keys: - # Pass only the data that the subscriber is interested in - subscriber_data = { - key: self.data[key] for key in subscriber.keys if key in self.data - } + subscriber_data = self._build_subscriber_data(subscriber) subscriber.trigger(subscriber_data) - def _notify_key_change_subscribers(self) -> None: - """Notify subscribers about key changes (additions/removals).""" - if not self._pending_key_additions and not self._pending_key_removals: - return + # Notify update callbacks with just key names + for callback in self._update_callbacks: + callback(updated_keys) - for subscriber in self._key_change_subscribers: - subscriber( - self._pending_key_additions.copy(), self._pending_key_removals.copy() - ) + def __getitem__(self, key: K) -> V: + """Get the latest value for a key.""" + buffered_data = self._buffer_manager.get_buffered_data(key) + if buffered_data is None: + raise KeyError(key) + return self._default_extractor.extract(buffered_data) def __setitem__(self, key: K, value: V) -> None: - if key not in self.data: - self._pending_key_additions.add(key) - super().__setitem__(key, value) + """Set a value, storing it in a buffer.""" + if key not in self._buffer_manager: + extractors = self._get_extractors(key) + self._buffer_manager.create_buffer(key, extractors) + self._buffer_manager.update_buffer(key, value) self._pending_updates.add(key) self._notify_if_not_in_transaction() def __delitem__(self, key: K) -> None: - self._pending_key_removals.add(key) - super().__delitem__(key) + """Delete a key and its buffer.""" + self._buffer_manager.delete_buffer(key) self._pending_updates.add(key) self._notify_if_not_in_transaction() + def __iter__(self) -> Iterator[K]: + """Iterate over keys.""" + return iter(self._buffer_manager) + + def __len__(self) -> int: + """Return the number of keys.""" + return len(self._buffer_manager) + def _notify_if_not_in_transaction(self) -> None: """Notify subscribers if not in a transaction.""" if not self._in_transaction: @@ -127,6 +240,3 @@ def _notify(self) -> None: pending = set(self._pending_updates) self._pending_updates.clear() self._notify_subscribers(pending) - self._notify_key_change_subscribers() - self._pending_key_additions.clear() - self._pending_key_removals.clear() diff --git a/src/ess/livedata/dashboard/data_subscriber.py b/src/ess/livedata/dashboard/data_subscriber.py index 29993914b..8202db534 100644 --- a/src/ess/livedata/dashboard/data_subscriber.py +++ b/src/ess/livedata/dashboard/data_subscriber.py @@ -3,10 +3,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Hashable +from collections.abc import Callable, Hashable, Mapping from typing import Any, Generic, Protocol, TypeVar from ess.livedata.config.workflow_spec import ResultKey +from ess.livedata.dashboard.data_service import DataServiceSubscriber +from ess.livedata.dashboard.extractors import UpdateExtractor class PipeBase(Protocol): @@ -40,6 +42,7 @@ def __init__(self, data: Any) -> None: Key = TypeVar('Key', bound=Hashable) +P = TypeVar('P', bound=PipeBase) class StreamAssembler(ABC, Generic[Key]): @@ -85,27 +88,45 @@ def assemble(self, data: dict[Key, Any]) -> Any: """ -class DataSubscriber(Generic[Key]): +class DataSubscriber(DataServiceSubscriber[Key], Generic[Key, P]): """Unified subscriber that uses a StreamAssembler to process 
data.""" - def __init__(self, assembler: StreamAssembler[Key], pipe: PipeBase) -> None: + def __init__( + self, + assembler: StreamAssembler[Key], + pipe_factory: Callable[[dict[Key, Any]], P], + extractors: Mapping[Key, UpdateExtractor], + ) -> None: """ - Initialize the subscriber with an assembler and pipe. + Initialize the subscriber with an assembler and pipe factory. Parameters ---------- assembler: The assembler responsible for processing the data. - pipe: - The pipe to send assembled data to. + pipe_factory: + Factory function to create the pipe on first trigger. + extractors: + Mapping from keys to their UpdateExtractor instances. """ self._assembler = assembler - self._pipe = pipe + self._pipe_factory = pipe_factory + self._pipe: P | None = None + self._extractors = extractors + # Initialize parent class to cache keys + super().__init__() @property - def keys(self) -> set[Key]: - """Return the set of data keys this subscriber depends on.""" - return self._assembler.keys + def extractors(self) -> Mapping[Key, UpdateExtractor]: + """Return extractors for obtaining data views.""" + return self._extractors + + @property + def pipe(self) -> P: + """Return the pipe (must be created by first trigger).""" + if self._pipe is None: + raise RuntimeError("Pipe not yet initialized - subscriber not triggered") + return self._pipe def trigger(self, store: dict[Key, Any]) -> None: """ @@ -118,7 +139,13 @@ def trigger(self, store: dict[Key, Any]) -> None: """ data = {key: store[key] for key in self.keys if key in store} assembled_data = self._assembler.assemble(data) - self._pipe.send(assembled_data) + + if self._pipe is None: + # First trigger - create pipe with correctly extracted data + self._pipe = self._pipe_factory(assembled_data) + else: + # Subsequent triggers - send to existing pipe + self._pipe.send(assembled_data) class MergingStreamAssembler(StreamAssembler): diff --git a/src/ess/livedata/dashboard/extractors.py b/src/ess/livedata/dashboard/extractors.py new file mode 100644 index 000000000..3cde0424d --- /dev/null +++ b/src/ess/livedata/dashboard/extractors.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +import scipp as sc + +from .plot_params import WindowAggregation + + +class UpdateExtractor(ABC): + """Extracts a specific view of buffered data.""" + + @abstractmethod + def extract(self, data: sc.DataArray) -> Any: + """ + Extract data from buffered data. + + Parameters + ---------- + data: + The buffered data to extract from. + + Returns + ------- + : + The extracted data. + """ + + @abstractmethod + def get_required_timespan(self) -> float: + """ + Get the required timespan for this extractor. + + Returns + ------- + : + Required timespan in seconds. Return 0.0 for extractors that only + need the latest value. + """ + + +class LatestValueExtractor(UpdateExtractor): + """Extracts the latest single value, unwrapping the concat dimension.""" + + def __init__(self, concat_dim: str = 'time') -> None: + """ + Initialize latest value extractor. + + Parameters + ---------- + concat_dim: + The dimension along which data is concatenated. 
+ """ + self._concat_dim = concat_dim + + def get_required_timespan(self) -> float: + """Latest value requires zero history.""" + return 0.0 + + def extract(self, data: sc.DataArray) -> Any: + """Extract the latest value from the data, unwrapped.""" + return data[self._concat_dim, -1] if self._concat_dim in data.dims else data + + +class FullHistoryExtractor(UpdateExtractor): + """Extracts the complete buffer history.""" + + def get_required_timespan(self) -> float: + """Return infinite timespan to indicate wanting all history.""" + return float('inf') + + def extract(self, data: sc.DataArray) -> Any: + """Extract all data from the buffer.""" + return data + + +class WindowAggregatingExtractor(UpdateExtractor): + """Extracts a window from the buffer and aggregates over the time dimension.""" + + def __init__( + self, + window_duration_seconds: float, + aggregation: WindowAggregation = WindowAggregation.auto, + concat_dim: str = 'time', + ) -> None: + """ + Initialize window aggregating extractor. + + Parameters + ---------- + window_duration_seconds: + Time duration to extract from the end of the buffer (seconds). + aggregation: + Aggregation method. WindowAggregation.auto uses 'nansum' if data unit + is counts, else 'nanmean'. + concat_dim: + Name of the dimension to aggregate over. + """ + self._window_duration_seconds = window_duration_seconds + self._aggregation = aggregation + self._concat_dim = concat_dim + self._aggregator: Callable[[sc.DataArray, str], sc.DataArray] | None = None + self._duration: sc.Variable | None = None + + def get_required_timespan(self) -> float: + """Return the required window duration.""" + return self._window_duration_seconds + + def extract(self, data: sc.DataArray) -> Any: + """Extract a window of data and aggregate over the time dimension.""" + # Calculate cutoff time + time_coord = data.coords[self._concat_dim] + if self._duration is None: + duration_scalar = sc.scalar(self._window_duration_seconds, unit='s') + if time_coord.dtype == sc.DType.datetime64: + self._duration = duration_scalar.to(unit=time_coord.unit, dtype='int64') + else: + self._duration = duration_scalar.to(unit=time_coord.unit) + + # Estimate frame period from median interval to handle timing noise. + # Shift cutoff by half period to place boundary between frame slots, + # avoiding inclusion of extra frames due to timing jitter. + latest_time = time_coord[-1] + if len(time_coord) > 1: + intervals = time_coord[1:] - time_coord[:-1] + median_interval = sc.median(intervals) + half_median = 0.5 * median_interval + # datetime64 arithmetic requires int64, not float64 + if time_coord.dtype == sc.DType.datetime64: + half_median = half_median.astype('int64') + cutoff_time = latest_time - self._duration + half_median + # Clamp to ensure at least latest frame included + # (handles narrow windows where duration < median_interval) + if cutoff_time > latest_time: + cutoff_time = latest_time + else: + # Single frame: use duration-based cutoff + cutoff_time = latest_time - self._duration + + # Use label-based slicing with inclusive lower bound. If timestamps were precise + # we would actually want exclusive lower bound, but since there is jitter anyway + # the cutoff shift above should handle that well enough. 
+ windowed_data = data[self._concat_dim, cutoff_time:] + + # Resolve and cache aggregator function on first call + if self._aggregator is None: + if self._aggregation == WindowAggregation.auto: + aggregation = ( + WindowAggregation.nansum + if windowed_data.unit == 'counts' + else WindowAggregation.nanmean + ) + else: + aggregation = self._aggregation + aggregators = { + WindowAggregation.sum: sc.sum, + WindowAggregation.nansum: sc.nansum, + WindowAggregation.mean: sc.mean, + WindowAggregation.nanmean: sc.nanmean, + } + self._aggregator = aggregators.get(aggregation) + if self._aggregator is None: + raise ValueError(f"Unknown aggregation method: {self._aggregation}") + + return self._aggregator(windowed_data, self._concat_dim) diff --git a/src/ess/livedata/dashboard/job_service.py b/src/ess/livedata/dashboard/job_service.py index 7b6fae843..1bce2062e 100644 --- a/src/ess/livedata/dashboard/job_service.py +++ b/src/ess/livedata/dashboard/job_service.py @@ -37,7 +37,7 @@ def __init__( self._removed_jobs: set[JobId] = set() self._job_data_update_subscribers: list[Callable[[], None]] = [] self._job_status_update_subscribers: list[Callable[[], None]] = [] - self._data_service.register_subscriber(self.data_updated) + self._data_service.register_update_callback(self.data_updated) @property def job_data(self) -> dict[JobNumber, dict[SourceName, SourceData]]: diff --git a/src/ess/livedata/dashboard/plot_params.py b/src/ess/livedata/dashboard/plot_params.py index 7ffb2b5de..c7d34284a 100644 --- a/src/ess/livedata/dashboard/plot_params.py +++ b/src/ess/livedata/dashboard/plot_params.py @@ -2,18 +2,45 @@ # Copyright (c) 2025 Scipp contributors (https://github.com/scipp) """Param models for configuring plotters via widgets.""" +from __future__ import annotations + import enum +from enum import StrEnum +from typing import TYPE_CHECKING import pydantic from ..config.roi_names import get_roi_mapper +if TYPE_CHECKING: + from ess.livedata.config.workflow_spec import ResultKey + + from .extractors import UpdateExtractor + from .plotting import PlotterSpec + def _get_default_max_roi_count() -> int: """Get the default maximum ROI count from the mapper configuration.""" return get_roi_mapper().total_rois +class WindowMode(str, enum.Enum): + """Enumeration of extraction modes.""" + + latest = 'latest' + window = 'window' + + +class WindowAggregation(StrEnum): + """Enumeration of aggregation methods for window mode.""" + + auto = 'auto' + nansum = 'nansum' + nanmean = 'nanmean' + sum = 'sum' + mean = 'mean' + + class PlotScale(str, enum.Enum): """Enumeration of plot scales.""" @@ -111,6 +138,32 @@ class LayoutParams(pydantic.BaseModel): ) +class WindowParams(pydantic.BaseModel): + """Parameters for windowing and aggregation.""" + + mode: WindowMode = pydantic.Field( + default=WindowMode.latest, + description="Extraction mode: 'latest' for single frame (typically accumulated " + "for 1 second), 'window' for aggregation over multiple frames.", + title="Mode", + ) + window_duration_seconds: float = pydantic.Field( + default=1.0, + description="Time duration to aggregate in window mode (seconds).", + title="Window Duration (s)", + ge=0.1, + le=60.0, + ) + aggregation: WindowAggregation = pydantic.Field( + default=WindowAggregation.auto, + description=( + "Aggregation method for window mode. 'auto' uses 'nansum' for " + "counts (unit='counts') and 'nanmean' otherwise." 
+ ), + title="Aggregation", + ) + + class PlotParamsBase(pydantic.BaseModel): """Base class for plot parameters.""" @@ -127,6 +180,10 @@ class PlotParamsBase(pydantic.BaseModel): class PlotParams1d(PlotParamsBase): """Common parameters for 1d plots.""" + window: WindowParams = pydantic.Field( + default_factory=WindowParams, + description="Windowing and aggregation options.", + ) plot_scale: PlotScaleParams = pydantic.Field( default_factory=PlotScaleParams, description="Scaling options for the plot axes.", @@ -136,6 +193,10 @@ class PlotParams1d(PlotParamsBase): class PlotParams2d(PlotParamsBase): """Common parameters for 2d plots.""" + window: WindowParams = pydantic.Field( + default_factory=WindowParams, + description="Windowing and aggregation options.", + ) plot_scale: PlotScaleParams2d = pydantic.Field( default_factory=PlotScaleParams2d, description="Scaling options for the plot and color axes.", @@ -145,6 +206,10 @@ class PlotParams2d(PlotParamsBase): class PlotParams3d(PlotParamsBase): """Parameters for 3D slicer plots.""" + window: WindowParams = pydantic.Field( + default_factory=WindowParams, + description="Windowing and aggregation options.", + ) plot_scale: PlotScaleParams2d = pydantic.Field( default_factory=PlotScaleParams2d, description="Scaling options for the plot axes and color.", @@ -170,3 +235,55 @@ class PlotParamsROIDetector(PlotParams2d): default_factory=ROIOptions, description="Options for ROI selection and display.", ) + + +def create_extractors_from_params( + keys: list[ResultKey], + window: WindowParams | None, + spec: PlotterSpec | None = None, +) -> dict[ResultKey, UpdateExtractor]: + """ + Create extractors based on plotter spec and window configuration. + + Parameters + ---------- + keys: + Result keys to create extractors for. + window: + Window parameters for extraction mode and aggregation. + If None, falls back to LatestValueExtractor. + spec: + Optional plotter specification. If provided and contains a required + extractor, that extractor type is used. + + Returns + ------- + : + Dictionary mapping result keys to extractor instances. 
+ """ + # Import here to avoid circular imports at module level + from .extractors import ( + LatestValueExtractor, + WindowAggregatingExtractor, + ) + + if spec is not None and spec.data_requirements.required_extractor is not None: + # Plotter requires specific extractor (e.g., TimeSeriesPlotter) + extractor_type = spec.data_requirements.required_extractor + return {key: extractor_type() for key in keys} + + # No fixed requirement - check if window params provided + if window is not None: + if window.mode == WindowMode.latest: + return {key: LatestValueExtractor() for key in keys} + else: # mode == WindowMode.window + return { + key: WindowAggregatingExtractor( + window_duration_seconds=window.window_duration_seconds, + aggregation=window.aggregation, + ) + for key in keys + } + + # Fallback to latest value extractor + return {key: LatestValueExtractor() for key in keys} diff --git a/src/ess/livedata/dashboard/plotting.py b/src/ess/livedata/dashboard/plotting.py index bcd145963..f2b1fabc8 100644 --- a/src/ess/livedata/dashboard/plotting.py +++ b/src/ess/livedata/dashboard/plotting.py @@ -11,8 +11,14 @@ import pydantic import scipp as sc +from .extractors import FullHistoryExtractor, UpdateExtractor from .plot_params import PlotParamsROIDetector -from .plots import ImagePlotter, LinePlotter, Plotter, SlicerPlotter +from .plots import ( + ImagePlotter, + LinePlotter, + Plotter, + SlicerPlotter, +) from .scipp_to_holoviews import _all_coords_evenly_spaced @@ -22,6 +28,7 @@ class DataRequirements: min_dims: int max_dims: int + required_extractor: type[UpdateExtractor] | None = None required_coords: list[str] = field(default_factory=list) multiple_datasets: bool = True custom_validators: list[Callable[[sc.DataArray], bool]] = field( @@ -164,6 +171,20 @@ def create_plotter(self, name: str, params: pydantic.BaseModel) -> Plotter: ) +plotter_registry.register_plotter( + name='timeseries', + title='Timeseries', + description='Plot the temporal evolution of scalar values as line plots.', + data_requirements=DataRequirements( + min_dims=0, + max_dims=0, + multiple_datasets=True, + required_extractor=FullHistoryExtractor, + ), + factory=LinePlotter.from_params, +) + + plotter_registry.register_plotter( name='slicer', title='3D Slicer', @@ -208,10 +229,6 @@ def _roi_detector_plotter_factory(params: PlotParamsROIDetector) -> Plotter: 'Backspace while the mouse is within the plot area.' 
'' ), - data_requirements=DataRequirements( - min_dims=2, - max_dims=2, - multiple_datasets=True, - ), + data_requirements=DataRequirements(min_dims=2, max_dims=2, multiple_datasets=True), factory=_roi_detector_plotter_factory, ) diff --git a/src/ess/livedata/dashboard/plotting_controller.py b/src/ess/livedata/dashboard/plotting_controller.py index d8949e194..099748acd 100644 --- a/src/ess/livedata/dashboard/plotting_controller.py +++ b/src/ess/livedata/dashboard/plotting_controller.py @@ -19,6 +19,7 @@ from .config_store import ConfigStore from .configuration_adapter import ConfigurationState from .job_service import JobService +from .plot_params import create_extractors_from_params from .plotting import PlotterSpec, plotter_registry from .roi_detector_plot_factory import ROIDetectorPlotFactory from .roi_publisher import ROIPublisher @@ -270,20 +271,21 @@ def create_plot( plot_name=plot_name, params=params, ) - items = { + # Build result keys for all sources + keys = [ self.get_result_key( job_number=job_number, source_name=source_name, output_name=output_name - ): self._job_service.job_data[job_number][source_name][output_name] + ) for source_name in source_names - } + ] # Special case for roi_detector: call factory once per detector if plot_name == 'roi_detector': plot_components = [ self._roi_detector_plot_factory.create_roi_detector_plot_components( - detector_key=key, detector_data=data, params=params + detector_key=key, params=params ) - for key, data in items.items() + for key in keys ] # Each component returns (detector_with_boxes, roi_spectrum, plot_state) # Flatten detector and spectrum plots into a layout with 2 columns @@ -292,11 +294,16 @@ def create_plot( plots.extend([detector_with_boxes, roi_spectrum]) return hv.Layout(plots).cols(2).opts(shared_axes=False) - pipe = self._stream_manager.make_merging_stream(items) + # Create extractors based on plotter requirements and params + spec = plotter_registry.get_spec(plot_name) + window = getattr(params, 'window', None) + extractors = create_extractors_from_params(keys, window, spec) + + pipe = self._stream_manager.make_merging_stream(extractors) plotter = plotter_registry.create_plotter(plot_name, params=params) - # Initialize plotter with initial data to determine kdims - plotter.initialize_from_data(items) + # Initialize plotter with extracted data from pipe to determine kdims + plotter.initialize_from_data(pipe.data) # Create DynamicMap with kdims (None if plotter doesn't use them) dmap = hv.DynamicMap(plotter, streams=[pipe], kdims=plotter.kdims, cache_size=1) diff --git a/src/ess/livedata/dashboard/roi_detector_plot_factory.py b/src/ess/livedata/dashboard/roi_detector_plot_factory.py index c33a11c73..a6e7e12ea 100644 --- a/src/ess/livedata/dashboard/roi_detector_plot_factory.py +++ b/src/ess/livedata/dashboard/roi_detector_plot_factory.py @@ -19,7 +19,12 @@ DataSubscriber, MergingStreamAssembler, ) -from .plot_params import LayoutParams, PlotParamsROIDetector +from .extractors import LatestValueExtractor +from .plot_params import ( + LayoutParams, + PlotParamsROIDetector, + create_extractors_from_params, +) from .plots import ImagePlotter, LinePlotter, PlotAspect, PlotAspectType from .roi_publisher import ROIPublisher from .stream_manager import StreamManager @@ -530,16 +535,18 @@ def __init__(self, callback): def send(self, data): self.callback(data) - roi_pipe = ROIReadbackPipe(on_roi_data_update) + def roi_pipe_factory(data): + """Factory function to create ROIReadbackPipe with callback.""" + return 
ROIReadbackPipe(on_roi_data_update) assembler = MergingStreamAssembler({roi_readback_key}) - subscriber = DataSubscriber(assembler, roi_pipe) + extractors = {roi_readback_key: LatestValueExtractor()} + subscriber = DataSubscriber(assembler, roi_pipe_factory, extractors) self._stream_manager.data_service.register_subscriber(subscriber) def create_roi_detector_plot_components( self, detector_key: ResultKey, - detector_data: sc.DataArray, params: PlotParamsROIDetector, ) -> tuple[hv.DynamicMap, hv.DynamicMap, ROIPlotState]: """ @@ -556,12 +563,14 @@ def create_roi_detector_plot_components( Initial ROI configurations are automatically loaded from DataService via the ROI readback subscription if available. + The detector data must be present in DataService at detector_key before + calling this method, as it will be accessed via extractors. + Parameters ---------- detector_key: - ResultKey identifying the detector output. - detector_data: - Initial data for the detector plot. + ResultKey identifying the detector output. The detector data must + already be present in DataService at this key. params: The plotter parameters (PlotParamsROIDetector). @@ -573,11 +582,11 @@ def create_roi_detector_plot_components( if not isinstance(params, PlotParamsROIDetector): raise TypeError("roi_detector requires PlotParamsROIDetector") - detector_items = {detector_key: detector_data} # FIXME: Memory leak - subscribers registered via stream_manager are never # unregistered. When this plot is closed, the subscriber remains in # DataService._subscribers, preventing garbage collection of plot components. - merged_detector_pipe = self._stream_manager.make_merging_stream(detector_items) + extractors = {detector_key: LatestValueExtractor()} + merged_detector_pipe = self._stream_manager.make_merging_stream(extractors) detector_plotter = ImagePlotter( value_margin_factor=0.1, @@ -585,7 +594,8 @@ def create_roi_detector_plot_components( aspect_params=params.plot_aspect, scale_opts=params.plot_scale, ) - detector_plotter.initialize_from_data(detector_items) + # Use extracted data from pipe for plotter initialization + detector_plotter.initialize_from_data(merged_detector_pipe.data) detector_dmap = hv.DynamicMap( detector_plotter, streams=[merged_detector_pipe], cache_size=1 @@ -643,7 +653,8 @@ def make_request_boxes(data: list): source=request_dmap, num_objects=max_roi_count, data=initial_box_data ) - # Extract coordinate units + # Extract coordinate units from the extracted detector data in pipe + detector_data = merged_detector_pipe.data[detector_key] x_dim, y_dim = detector_data.dims[1], detector_data.dims[0] x_unit = self._extract_unit_for_dim(detector_data, x_dim) y_unit = self._extract_unit_for_dim(detector_data, y_dim) @@ -744,9 +755,8 @@ def _create_roi_spectrum_plot( # FIXME: Memory leak - subscribers registered via stream_manager are never # unregistered. When this plot is closed, the subscriber remains in # DataService._subscribers, preventing garbage collection of plot components. 
- spectrum_pipe = self._stream_manager.make_merging_stream_from_keys( - spectrum_keys - ) + extractors = create_extractors_from_params(spectrum_keys, params.window) + spectrum_pipe = self._stream_manager.make_merging_stream(extractors) spectrum_plotter = LinePlotter( value_margin_factor=0.1, diff --git a/src/ess/livedata/dashboard/stream_manager.py b/src/ess/livedata/dashboard/stream_manager.py index d614c987b..269624718 100644 --- a/src/ess/livedata/dashboard/stream_manager.py +++ b/src/ess/livedata/dashboard/stream_manager.py @@ -4,13 +4,14 @@ Utilities for connecting subscribers to :py:class:`DataService` """ -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Any, Generic, TypeVar from ess.livedata.config.workflow_spec import ResultKey from .data_service import DataService from .data_subscriber import DataSubscriber, MergingStreamAssembler, Pipe +from .extractors import UpdateExtractor P = TypeVar('P', bound=Pipe) @@ -27,31 +28,21 @@ def __init__( self.data_service = data_service self._pipe_factory = pipe_factory - def make_merging_stream(self, items: dict[ResultKey, Any]) -> P: - """Create a merging stream for the given set of data keys.""" - assembler = MergingStreamAssembler(set(items)) - pipe = self._pipe_factory(items) - subscriber = DataSubscriber(assembler, pipe) - self.data_service.register_subscriber(subscriber) - return pipe - - def make_merging_stream_from_keys( + def make_merging_stream( self, - keys: list[ResultKey], + keys: Sequence[ResultKey] | dict[ResultKey, UpdateExtractor], assembler_factory: Callable[[set[ResultKey]], Any] = MergingStreamAssembler, ) -> P: """ - Create a merging stream for the given result keys, starting with no data. + Create a merging stream for the given result keys. - This is useful when you want to subscribe to keys that may not have data yet. - The pipe is initialized only with available data (which may result in an empty - dictionary), and will receive updates as data becomes available for the - subscribed keys. + The pipe is created lazily on first trigger with correctly extracted data. Parameters ---------- keys: - List of result keys to subscribe to. + Either a sequence of result keys (uses LatestValueExtractor for all) + or a dict mapping keys to their specific UpdateExtractor instances. assembler_factory: Optional callable that creates an assembler from a set of keys. Use functools.partial to bind additional arguments (e.g., filter_fn). @@ -61,10 +52,18 @@ def make_merging_stream_from_keys( : A pipe that will receive merged data updates for the given keys. 
""" - assembler = assembler_factory(set(keys)) - pipe = self._pipe_factory( - {key: self.data_service[key] for key in keys if key in self.data_service} - ) - subscriber = DataSubscriber(assembler, pipe) + from .extractors import LatestValueExtractor + + if isinstance(keys, dict): + # Dict provided: keys are dict keys, extractors are dict values + keys_set = set(keys.keys()) + extractors = keys + else: + # Sequence provided: use default LatestValueExtractor for all keys + keys_set = set(keys) + extractors = {key: LatestValueExtractor() for key in keys_set} + + assembler = assembler_factory(keys_set) + subscriber = DataSubscriber(assembler, self._pipe_factory, extractors) self.data_service.register_subscriber(subscriber) - return pipe + return subscriber.pipe diff --git a/src/ess/livedata/dashboard/temporal_buffer_manager.py b/src/ess/livedata/dashboard/temporal_buffer_manager.py new file mode 100644 index 000000000..f2519d912 --- /dev/null +++ b/src/ess/livedata/dashboard/temporal_buffer_manager.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +"""Buffer manager with temporal buffer support.""" + +from __future__ import annotations + +import logging +from collections.abc import Hashable, Iterator, Mapping, Sequence +from dataclasses import dataclass, field +from typing import Generic, TypeVar + +import scipp as sc + +from .extractors import LatestValueExtractor, UpdateExtractor +from .temporal_buffers import BufferProtocol, SingleValueBuffer, TemporalBuffer + +logger = logging.getLogger(__name__) + +K = TypeVar('K', bound=Hashable) + + +@dataclass +class _BufferState: + """Internal state for a managed buffer.""" + + buffer: BufferProtocol[sc.DataArray] + extractors: list[UpdateExtractor] = field( + default_factory=list + ) # Stored as list internally + + +class TemporalBufferManager(Mapping[K, BufferProtocol[sc.DataArray]], Generic[K]): + """ + Manages buffers, switching between SingleValueBuffer and TemporalBuffer. + + Decides buffer type based on extractors: + - All LatestValueExtractor → SingleValueBuffer (efficient) + - Otherwise → TemporalBuffer (temporal data with time dimension) + + Implements Mapping interface for read-only dictionary-like access to buffers. + Use get_buffered_data() for convenient access to buffered data. + """ + + def __init__(self) -> None: + """Initialize TemporalBufferManager.""" + self._states: dict[K, _BufferState] = {} + + def __getitem__(self, key: K) -> BufferProtocol[sc.DataArray]: + """Return the buffer for a key.""" + return self._states[key].buffer + + def __iter__(self) -> Iterator[K]: + """Iterate over keys.""" + return iter(self._states) + + def __len__(self) -> int: + """Return number of buffers.""" + return len(self._states) + + def get_buffered_data(self, key: K) -> sc.DataArray | None: + """ + Get data from buffer if available. + + Returns None if buffer doesn't exist or if buffer is empty. + Never raises KeyError - treats "buffer not found" and "buffer empty" + the same way for convenience. + + Parameters + ---------- + key: + Key identifying the buffer. + + Returns + ------- + : + Buffered data, or None if unavailable. + """ + if key not in self._states: + return None + return self._states[key].buffer.get() + + def create_buffer(self, key: K, extractors: Sequence[UpdateExtractor]) -> None: + """ + Create a buffer with appropriate type based on extractors. + + Parameters + ---------- + key: + Key to identify this buffer. 
+ extractors: + Sequence of extractors that will use this buffer. + """ + if key in self._states: + raise ValueError(f"Buffer with key {key} already exists") + + buffer = self._create_buffer_for_extractors(extractors) + self._update_buffer_requirements(buffer, extractors) + state = _BufferState(buffer=buffer, extractors=list(extractors)) + self._states[key] = state + + def update_buffer(self, key: K, data: sc.DataArray) -> None: + """ + Update buffer with new data. + + Parameters + ---------- + key: + Key identifying the buffer to update. + data: + New data to add. + """ + if key not in self._states: + raise KeyError(f"No buffer found for key {key}") + + state = self._states[key] + state.buffer.add(data) + + def add_extractor(self, key: K, extractor: UpdateExtractor) -> None: + """ + Register additional extractor for an existing buffer. + + May trigger buffer type switch with data migration: + - Single→Temporal: Existing data is copied to the new buffer + - Temporal→Single: Last time slice is copied to the new buffer + - Other transitions: Data is discarded + + Parameters + ---------- + key: + Key identifying the buffer to add extractor to. + extractor: + New extractor that will use this buffer. + """ + state = self._states[key] + state.extractors.append(extractor) + self._reconfigure_buffer_if_needed(key, state) + + def set_extractors(self, key: K, extractors: Sequence[UpdateExtractor]) -> None: + """ + Replace all extractors for an existing buffer. + + May trigger buffer type switch with data migration: + - Single→Temporal: Existing data is copied to the new buffer + - Temporal→Single: Last time slice is copied to the new buffer + - Other transitions: Data is discarded + + Useful for reconfiguring buffers when subscribers change, e.g., when + a subscriber is removed and we need to downgrade from Temporal to + SingleValue buffer. + + Parameters + ---------- + key: + Key identifying the buffer to update. + extractors: + New sequence of extractors that will use this buffer. + """ + state = self._states[key] + state.extractors = list(extractors) + self._reconfigure_buffer_if_needed(key, state) + + def _reconfigure_buffer_if_needed(self, key: K, state: _BufferState) -> None: + """ + Check if buffer type needs to change and handle migration. + + Parameters + ---------- + key: + Key identifying the buffer. + state: + Buffer state to reconfigure. 
+ """ + # Check if we need to switch buffer type + new_buffer = self._create_buffer_for_extractors(state.extractors) + if not isinstance(new_buffer, type(state.buffer)): + # Handle data migration for Single->Temporal transition + if isinstance(state.buffer, SingleValueBuffer) and isinstance( + new_buffer, TemporalBuffer + ): + logger.info( + "Switching buffer type from %s to %s for key %s (copying data)", + type(state.buffer).__name__, + type(new_buffer).__name__, + key, + ) + # Copy existing data to new buffer + old_data = state.buffer.get() + if old_data is not None: + new_buffer.add(old_data) + # Handle data migration for Temporal->Single transition + elif isinstance(state.buffer, TemporalBuffer) and isinstance( + new_buffer, SingleValueBuffer + ): + logger.info( + "Switching buffer type from %s to %s for key %s" + " (copying last slice)", + type(state.buffer).__name__, + type(new_buffer).__name__, + key, + ) + # Copy last slice to new buffer + old_data = state.buffer.get() + if old_data is not None and 'time' in old_data.dims: + last_slice = old_data['time', -1] + new_buffer.add(last_slice) + else: + logger.info( + "Switching buffer type from %s to %s for key %s" + " (discarding old data)", + type(state.buffer).__name__, + type(new_buffer).__name__, + key, + ) + state.buffer = new_buffer + + # Update buffer requirements + self._update_buffer_requirements(state.buffer, state.extractors) + + def delete_buffer(self, key: K) -> None: + """ + Delete a buffer and its associated state. + + Parameters + ---------- + key: + Key identifying the buffer to delete. + """ + if key in self._states: + del self._states[key] + + def _create_buffer_for_extractors( + self, extractors: Sequence[UpdateExtractor] + ) -> BufferProtocol[sc.DataArray]: + """ + Create appropriate buffer type based on extractors. + + If all extractors are LatestValueExtractor, use SingleValueBuffer. + Otherwise, use TemporalBuffer. + + Parameters + ---------- + extractors: + Sequence of extractors that will use the buffer. + + Returns + ------- + : + New buffer instance of appropriate type. + """ + if not extractors: + # No extractors - default to SingleValueBuffer + return SingleValueBuffer() + + # Check if all extractors are LatestValueExtractor + all_latest = all(isinstance(e, LatestValueExtractor) for e in extractors) + + if all_latest: + return SingleValueBuffer() + else: + return TemporalBuffer() + + def _update_buffer_requirements( + self, + buffer: BufferProtocol[sc.DataArray], + extractors: Sequence[UpdateExtractor], + ) -> None: + """ + Update buffer requirements based on extractors. + + Computes the maximum required timespan from all extractors and sets it + on the buffer. + + Parameters + ---------- + buffer: + The buffer to update. + extractors: + Sequence of extractors to gather requirements from. 
+ """ + # Compute maximum required timespan + if extractors: + max_timespan = max(e.get_required_timespan() for e in extractors) + buffer.set_required_timespan(max_timespan) + logger.debug( + "Set buffer required timespan to %.2f seconds (from %d extractors)", + max_timespan, + len(extractors), + ) diff --git a/src/ess/livedata/dashboard/temporal_buffers.py b/src/ess/livedata/dashboard/temporal_buffers.py new file mode 100644 index 000000000..3da5967c6 --- /dev/null +++ b/src/ess/livedata/dashboard/temporal_buffers.py @@ -0,0 +1,443 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +"""Temporal buffer implementations for BufferManager.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +import scipp as sc + +T = TypeVar('T') + + +class BufferProtocol(ABC, Generic[T]): + """Common interface for all buffer types.""" + + @abstractmethod + def add(self, data: T) -> None: + """ + Add new data to the buffer. + + Parameters + ---------- + data: + New data to add to the buffer. + """ + + @abstractmethod + def get(self) -> T | None: + """ + Retrieve current buffer contents. + + Returns + ------- + : + Current buffer contents, or None if empty. + """ + + @abstractmethod + def clear(self) -> None: + """Clear all data from the buffer.""" + + @abstractmethod + def set_required_timespan(self, seconds: float) -> None: + """ + Set the required timespan for the buffer. + + Parameters + ---------- + seconds: + Required timespan in seconds. + """ + + @abstractmethod + def get_required_timespan(self) -> float: + """ + Get the required timespan for the buffer. + + Returns + ------- + : + Required timespan in seconds. + """ + + @abstractmethod + def set_max_memory(self, max_bytes: int) -> None: + """ + Set the maximum memory usage for the buffer. + + Parameters + ---------- + max_bytes: + Maximum memory usage in bytes (approximate). + """ + + +class SingleValueBuffer(BufferProtocol[T]): + """ + Buffer that stores only the latest value. + + Used when only LatestValueExtractor is present for efficiency. + """ + + def __init__(self) -> None: + """Initialize empty single value buffer.""" + self._data: T | None = None + self._max_memory: int | None = None + self._required_timespan: float = 0.0 + + def add(self, data: T) -> None: + """Store the latest value, replacing any previous value.""" + self._data = data + + def get(self) -> T | None: + """Return the stored value.""" + return self._data + + def clear(self) -> None: + """Clear the stored value.""" + self._data = None + + def set_required_timespan(self, seconds: float) -> None: + """Set required timespan (unused for SingleValueBuffer).""" + self._required_timespan = seconds + + def get_required_timespan(self) -> float: + """Get the required timespan.""" + return self._required_timespan + + def set_max_memory(self, max_bytes: int) -> None: + """Set max memory (unused for SingleValueBuffer).""" + self._max_memory = max_bytes + + +class VariableBuffer: + """ + Buffer managing sc.Variable data with dynamic sizing along a concat dimension. + + Handles appending data slices, capacity management with lazy expansion, + and dropping old data from the start. + """ + + def __init__( + self, + *, + data: sc.Variable, + max_capacity: int, + concat_dim: str = 'time', + ) -> None: + """ + Initialize variable buffer with initial data. + + Parameters + ---------- + data: + Initial data to store. Defines buffer structure (dims, dtype, unit). 
+ max_capacity: + Maximum allowed size along concat_dim. + concat_dim: + Dimension along which to concatenate data. + """ + self._concat_dim = concat_dim + self._max_capacity = max_capacity + + # Allocate minimal initial buffer + initial_size = min(16, max_capacity) + self._buffer = self._allocate_buffer(data, initial_size) + self._size = 0 + + # Delegate to append + if not self.append(data): + raise ValueError(f"Initial data exceeds max_capacity {max_capacity}") + + @property + def size(self) -> int: + """Get number of valid elements.""" + return self._size + + @property + def capacity(self) -> int: + """Get currently allocated buffer size.""" + return self._buffer.sizes[self._concat_dim] + + @property + def max_capacity(self) -> int: + """Get maximum capacity limit.""" + return self._max_capacity + + def append(self, data: sc.Variable) -> bool: + """ + Append data to the buffer. + + Parameters + ---------- + data: + Data to append. May or may not have concat_dim dimension. + If concat_dim is present, appends all slices along that dimension. + Otherwise, treats as single slice. + + Returns + ------- + : + True if successful, False if would exceed max_capacity. + """ + # Determine how many elements we're adding + if self._concat_dim in data.dims: + n_incoming = data.sizes[self._concat_dim] + else: + n_incoming = 1 + + # Check max_capacity + if self._size + n_incoming > self._max_capacity: + return False + + # Expand to fit all incoming data + if self._size + n_incoming > self.capacity: + self._expand_to_fit(self._size + n_incoming) + + # Write data + if self._concat_dim in data.dims: + # Thick slice + end = self._size + n_incoming + self._buffer[self._concat_dim, self._size : end] = data + else: + # Single slice + self._buffer[self._concat_dim, self._size] = data + + self._size += n_incoming + return True + + def get(self) -> sc.Variable: + """ + Get buffer contents up to current size. + + Returns + ------- + : + Valid buffer data (0:size). + """ + return self._buffer[self._concat_dim, : self._size] + + def drop(self, index: int) -> None: + """ + Drop data from start up to (but not including) index. + + Remaining valid data is moved to the start of the buffer. + + Parameters + ---------- + index: + Index from start until which to drop (exclusive). + """ + if index <= 0: + return + + if index >= self._size: + # Dropping everything + self._size = 0 + return + + # Move remaining data to start + n_remaining = self._size - index + self._buffer[self._concat_dim, :n_remaining] = self._buffer[ + self._concat_dim, index : self._size + ] + self._size = n_remaining + + def _allocate_buffer(self, template: sc.Variable, size: int) -> sc.Variable: + """ + Allocate new buffer based on template variable structure. + + Makes concat_dim the outer (first) dimension for efficient contiguous writes. + If template already has concat_dim, preserves its position. 
+ """ + if self._concat_dim in template.dims: + # Template has concat_dim - preserve dimension order + sizes = template.sizes + sizes[self._concat_dim] = size + else: + # Template doesn't have concat_dim - make it the outer dimension + sizes = {self._concat_dim: size} + sizes.update(template.sizes) + + return sc.empty( + sizes=sizes, + dtype=template.dtype, + unit=template.unit, + with_variances=template.variances is not None, + ) + + def _expand_to_fit(self, min_size: int) -> None: + """Expand buffer to accommodate at least min_size elements.""" + current_allocated = self._buffer.sizes[self._concat_dim] + while current_allocated < min_size: + current_allocated = min(self._max_capacity, current_allocated * 2) + + if current_allocated > self._buffer.sizes[self._concat_dim]: + new_buffer = self._allocate_buffer(self._buffer, current_allocated) + new_buffer[self._concat_dim, : self._size] = self._buffer[ + self._concat_dim, : self._size + ] + self._buffer = new_buffer + + +class TemporalBuffer(BufferProtocol[sc.DataArray]): + """ + Buffer that maintains temporal data with a time dimension. + + Uses VariableBuffer for efficient appending without expensive concat operations. + Validates that non-time coords and masks remain constant across all added data. + """ + + def __init__(self) -> None: + """Initialize empty temporal buffer.""" + self._data_buffer: VariableBuffer | None = None + self._time_buffer: VariableBuffer | None = None + self._reference: sc.DataArray | None = None + self._max_memory: int | None = None + self._required_timespan: float = 0.0 + + def add(self, data: sc.DataArray) -> None: + """ + Add data to the buffer, appending along time dimension. + + Parameters + ---------- + data: + New data to add. Must have a 'time' coordinate. + + Raises + ------ + ValueError + If data does not have a 'time' coordinate or exceeds buffer capacity. 
+ """ + if 'time' not in data.coords: + raise ValueError("TemporalBuffer requires data with 'time' coordinate") + + # First data or metadata mismatch - initialize/reset buffers + if self._data_buffer is None or not self._metadata_matches(data): + self._initialize_buffers(data) + return + + # Try to append to existing buffers + if not self._data_buffer.append(data.data): + # Failed - trim old data and retry + self._trim_to_timespan(data) + if not self._data_buffer.append(data.data): + raise ValueError("Data exceeds buffer capacity even after trimming") + + # Time buffer should succeed (buffers kept in sync by trimming) + if not self._time_buffer.append(data.coords['time']): + raise RuntimeError("Time buffer append failed unexpectedly") + + def get(self) -> sc.DataArray | None: + """Return the complete buffer.""" + if self._data_buffer is None: + return None + + # Reconstruct DataArray from buffers and reference metadata + data_var = self._data_buffer.get() + time_coord = self._time_buffer.get() + + coords = {'time': time_coord} + coords.update(self._reference.coords) + + masks = dict(self._reference.masks) + + return sc.DataArray(data=data_var, coords=coords, masks=masks) + + def clear(self) -> None: + """Clear all buffered data.""" + self._data_buffer = None + self._time_buffer = None + self._reference = None + + def set_required_timespan(self, seconds: float) -> None: + """Set the required timespan for the buffer.""" + self._required_timespan = seconds + + def get_required_timespan(self) -> float: + """Get the required timespan for the buffer.""" + return self._required_timespan + + def set_max_memory(self, max_bytes: int) -> None: + """Set the maximum memory usage for the buffer.""" + self._max_memory = max_bytes + + def _initialize_buffers(self, data: sc.DataArray) -> None: + """Initialize buffers with first data, storing reference metadata.""" + # Store reference as slice at time=0 without time coord + if 'time' in data.dims: + self._reference = data['time', 0].drop_coords('time') + else: + self._reference = data.drop_coords('time') + + # Calculate max_capacity from memory limit + if 'time' in data.dims: + bytes_per_element = data.data.values.nbytes / data.sizes['time'] + else: + bytes_per_element = data.data.values.nbytes + + if self._max_memory is not None: + max_capacity = int(self._max_memory / bytes_per_element) + else: + max_capacity = 10000 # Default large capacity + + # Create buffers + self._data_buffer = VariableBuffer( + data=data.data, max_capacity=max_capacity, concat_dim='time' + ) + self._time_buffer = VariableBuffer( + data=data.coords['time'], max_capacity=max_capacity, concat_dim='time' + ) + + def _trim_to_timespan(self, new_data: sc.DataArray) -> None: + """Trim buffer to keep only data within required timespan.""" + if self._required_timespan < 0: + return + + if self._required_timespan == 0.0: + # Keep only the latest value - drop all existing data + drop_count = self._data_buffer.size + self._data_buffer.drop(drop_count) + self._time_buffer.drop(drop_count) + return + + # Get latest time from new data + if 'time' in new_data.dims: + latest_time = new_data.coords['time'][-1] + else: + latest_time = new_data.coords['time'] + + # Calculate cutoff time + cutoff = latest_time - sc.scalar(self._required_timespan, unit='s') + + # Find first index to keep + time_coord = self._time_buffer.get() + keep_mask = time_coord >= cutoff + + if not keep_mask.values.any(): + # All data is old, drop everything + drop_count = self._data_buffer.size + else: + # Find first True index + 
drop_count = int(keep_mask.values.argmax()) + + # Trim both buffers by same amount to keep them in sync + self._data_buffer.drop(drop_count) + self._time_buffer.drop(drop_count) + + def _metadata_matches(self, data: sc.DataArray) -> bool: + """Check if incoming data's metadata matches stored reference metadata.""" + # Extract comparable slice from incoming data + if 'time' in data.dims: + new = data['time', 0] + else: + new = data + + # Create template with reference data but incoming metadata + template = new.assign(self._reference.data).drop_coords('time') + + return sc.identical(self._reference, template) diff --git a/src/ess/livedata/handlers/detector_view.py b/src/ess/livedata/handlers/detector_view.py index b126ee60f..460bc94db 100644 --- a/src/ess/livedata/handlers/detector_view.py +++ b/src/ess/livedata/handlers/detector_view.py @@ -54,6 +54,7 @@ def __init__( self._toa_edges = params.toa_edges.get_edges() self._rois_updated = False # Track ROI updates at workflow level self._roi_mapper = get_roi_mapper() + self._current_start_time: int | None = None def apply_toa_range(self, data: sc.DataArray) -> sc.DataArray: if not self._use_toa_range: @@ -64,7 +65,9 @@ def apply_toa_range(self, data: sc.DataArray) -> sc.DataArray: # into a coordinate, since scipp does not support filtering on data variables. return data.bins.assign_coords(toa=data.bins.data).bins['toa', low:high] - def accumulate(self, data: dict[Hashable, Any]) -> None: + def accumulate( + self, data: dict[Hashable, Any], *, start_time: int, end_time: int + ) -> None: """ Add data to the accumulator. @@ -74,6 +77,10 @@ def accumulate(self, data: dict[Hashable, Any]) -> None: Data to be added. Expected to contain detector event data and optionally ROI configuration. Detector data is assumed to be ev44 data that was passed through :py:class:`GroupIntoPixels`. + start_time: + Start time of the data window in nanoseconds since epoch. + end_time: + End time of the data window in nanoseconds since epoch. """ # Check for ROI configuration update (auxiliary data) # Stream name is 'roi' (from 'roi_rectangle' after job_id prefix stripped) @@ -92,6 +99,11 @@ def accumulate(self, data: dict[Hashable, Any]) -> None: raise ValueError( "DetectorViewProcessor expects exactly one detector data item." ) + + # Track start time of first detector data since last finalize + if self._current_start_time is None: + self._current_start_time = start_time + raw = next(iter(detector_data.values())) filtered = self.apply_toa_range(raw) self._view.add_events(filtered) @@ -99,6 +111,11 @@ def accumulate(self, data: dict[Hashable, Any]) -> None: roi_state.add_data(raw) def finalize(self) -> dict[str, sc.DataArray]: + if self._current_start_time is None: + raise RuntimeError( + "finalize called without any detector data accumulated via accumulate" + ) + cumulative = self._view.cumulative.copy() # This is a hack to get the current counts. Should be updated once # ess.reduce.live.raw.RollingDetectorView has been modified to support this. 
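The timespan trimming above reduces to a small piece of arithmetic: take the newest incoming timestamp, subtract the required timespan to get a cutoff, and drop every buffered sample older than that cutoff. A minimal standalone sketch of that calculation with plain scipp values (the helper name and the separation from the VariableBuffer pair are illustrative only, not part of this change):

import scipp as sc

def count_samples_to_drop(
    buffered_time: sc.Variable, latest_time: sc.Variable, timespan_s: float
) -> int:
    # Mirrors _trim_to_timespan: a negative timespan keeps everything,
    # zero keeps only the incoming value, otherwise keep the last timespan_s seconds.
    if timespan_s < 0:
        return 0
    if timespan_s == 0.0:
        return buffered_time.sizes['time']
    cutoff = latest_time - sc.scalar(timespan_s, unit='s')
    keep = buffered_time >= cutoff
    if not keep.values.any():
        return buffered_time.sizes['time']
    return int(keep.values.argmax())  # first index that is still inside the window

# Keep a 1 s window over samples at 0..3 s when a new sample arrives at t=3 s:
times = sc.array(dims=['time'], values=[0.0, 1.0, 2.0, 3.0], unit='s')
assert count_samples_to_drop(times, sc.scalar(3.0, unit='s'), 1.0) == 2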
@@ -106,6 +123,12 @@ def finalize(self) -> dict[str, sc.DataArray]: if self._previous is not None: current = current - self._previous self._previous = cumulative + + # Add time coord to current result + time_coord = sc.scalar(self._current_start_time, unit='ns') + current = current.assign_coords(time=time_coord) + self._current_start_time = None + result = sc.DataGroup(cumulative=cumulative, current=current) view_result = dict(result * self._inv_weights if self._use_weights else result) @@ -113,7 +136,10 @@ def finalize(self) -> dict[str, sc.DataArray]: for idx, roi_state in self._rois.items(): roi_delta = roi_state.get_delta() - roi_result[self._roi_mapper.current_key(idx)] = roi_delta + # Add time coord to ROI current result + roi_result[self._roi_mapper.current_key(idx)] = roi_delta.assign_coords( + time=time_coord + ) roi_result[self._roi_mapper.cumulative_key(idx)] = ( roi_state.cumulative.copy() ) @@ -139,6 +165,7 @@ def finalize(self) -> dict[str, sc.DataArray]: def clear(self) -> None: self._view.clear_counts() self._previous = None + self._current_start_time = None for roi_state in self._rois.values(): roi_state.clear() diff --git a/src/ess/livedata/handlers/monitor_data_handler.py b/src/ess/livedata/handlers/monitor_data_handler.py index beefac8e9..6e2980980 100644 --- a/src/ess/livedata/handlers/monitor_data_handler.py +++ b/src/ess/livedata/handlers/monitor_data_handler.py @@ -20,15 +20,27 @@ def __init__(self, edges: sc.Variable) -> None: self._event_edges = edges.to(unit='ns').values self._cumulative: sc.DataArray | None = None self._current: sc.DataArray | None = None + self._current_start_time: int | None = None @staticmethod def create_workflow(params: MonitorDataParams) -> Workflow: """Factory method for creating MonitorStreamProcessor from params.""" return MonitorStreamProcessor(edges=params.toa_edges.get_edges()) - def accumulate(self, data: dict[Hashable, sc.DataArray | np.ndarray]) -> None: + def accumulate( + self, + data: dict[Hashable, sc.DataArray | np.ndarray], + *, + start_time: int, + end_time: int, + ) -> None: if len(data) != 1: raise ValueError("MonitorStreamProcessor expects exactly one data item.") + + # Track start time of first data since last finalize + if self._current_start_time is None: + self._current_start_time = start_time + raw = next(iter(data.values())) # Note: In theory we should consider rebinning/histogramming only in finalize(), # but the current plan is to accumulate before/during preprocessing, i.e., @@ -58,17 +70,29 @@ def accumulate(self, data: dict[Hashable, sc.DataArray | np.ndarray]) -> None: def finalize(self) -> dict[Hashable, sc.DataArray]: if self._current is None: raise ValueError("No data has been added") + if self._current_start_time is None: + raise RuntimeError( + "finalize called without any data accumulated via accumulate" + ) + current = self._current if self._cumulative is None: self._cumulative = current else: self._cumulative += current self._current = sc.zeros_like(current) + + # Add time coord to current result + time_coord = sc.scalar(self._current_start_time, unit='ns') + current = current.assign_coords(time=time_coord) + self._current_start_time = None + return {'cumulative': self._cumulative, 'current': current} def clear(self) -> None: self._cumulative = None self._current = None + self._current_start_time = None class MonitorHandlerFactory( diff --git a/src/ess/livedata/handlers/stream_processor_workflow.py b/src/ess/livedata/handlers/stream_processor_workflow.py index b3f218b28..7427d26ae 100644 --- 
a/src/ess/livedata/handlers/stream_processor_workflow.py +++ b/src/ess/livedata/handlers/stream_processor_workflow.py @@ -43,7 +43,9 @@ def __init__( **kwargs, ) - def accumulate(self, data: dict[str, Any]) -> None: + def accumulate( + self, data: dict[str, Any], *, start_time: int, end_time: int + ) -> None: context = { sciline_key: data[key] for key, sciline_key in self._context_keys.items() diff --git a/src/ess/livedata/handlers/timeseries_handler.py b/src/ess/livedata/handlers/timeseries_handler.py index 5da20d520..5ba857e28 100644 --- a/src/ess/livedata/handlers/timeseries_handler.py +++ b/src/ess/livedata/handlers/timeseries_handler.py @@ -21,25 +21,38 @@ class TimeseriesStreamProcessor(Workflow): def __init__(self) -> None: self._data: sc.DataArray | None = None + self._last_returned_index = 0 @staticmethod def create_workflow() -> Workflow: """Factory method for creating TimeseriesStreamProcessor.""" return TimeseriesStreamProcessor() - def accumulate(self, data: dict[Hashable, sc.DataArray]) -> None: + def accumulate( + self, data: dict[Hashable, sc.DataArray], *, start_time: int, end_time: int + ) -> None: if len(data) != 1: raise ValueError("Timeseries processor expects exactly one data item.") - # Just store the data for forwarding + # Store the full cumulative data (including history from preprocessor) self._data = next(iter(data.values())) def finalize(self) -> dict[str, sc.DataArray]: if self._data is None: raise ValueError("No data has been added") - return {'cumulative': self._data} + + # Return only new data since last finalize to avoid republishing full history + current_size = self._data.sizes['time'] + if self._last_returned_index >= current_size: + raise ValueError("No new data since last finalize") + + result = self._data['time', self._last_returned_index :] + self._last_returned_index = current_size + + return {'delta': result} def clear(self) -> None: self._data = None + self._last_returned_index = 0 class LogdataHandlerFactory(JobBasedPreprocessorFactoryBase[LogData, sc.DataArray]): diff --git a/src/ess/livedata/handlers/workflow_factory.py b/src/ess/livedata/handlers/workflow_factory.py index 45daf23de..8efc70c17 100644 --- a/src/ess/livedata/handlers/workflow_factory.py +++ b/src/ess/livedata/handlers/workflow_factory.py @@ -17,7 +17,9 @@ class Workflow(Protocol): implementations, in particular for non-data-reduction jobs. """ - def accumulate(self, data: dict[str, Any]) -> None: ... + def accumulate( + self, data: dict[str, Any], *, start_time: int, end_time: int + ) -> None: ... def finalize(self) -> dict[str, Any]: ... def clear(self) -> None: ... 
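With the Workflow protocol updated above, every implementation now receives the data window boundaries as keyword-only start_time/end_time arguments and is expected to stamp its per-window output with a time coordinate, as the handlers above do. A minimal conforming sketch in that style (the class and its internals are illustrative only, not code from the repository):

from collections.abc import Hashable
from typing import Any

import scipp as sc


class CountingProcessor:
    """Toy Workflow: sums one numeric stream per window and time-stamps the result."""

    def __init__(self) -> None:
        self._total = 0.0
        self._current_start_time: int | None = None

    def accumulate(
        self, data: dict[Hashable, Any], *, start_time: int, end_time: int
    ) -> None:
        # Remember the start of the first chunk since the last finalize.
        if self._current_start_time is None:
            self._current_start_time = start_time
        for value in data.values():
            self._total += float(value)

    def finalize(self) -> dict[str, sc.DataArray]:
        if self._current_start_time is None:
            raise RuntimeError("finalize called without accumulated data")
        time_coord = sc.scalar(self._current_start_time, unit='ns')
        self._current_start_time = None
        return {
            'current': sc.DataArray(
                sc.scalar(self._total), coords={'time': time_coord}
            )
        }

    def clear(self) -> None:
        self._total = 0.0
        self._current_start_time = None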
diff --git a/tests/config/instrument_test.py b/tests/config/instrument_test.py index 795718ef8..a649f1d51 100644 --- a/tests/config/instrument_test.py +++ b/tests/config/instrument_test.py @@ -181,9 +181,15 @@ def test_register_spec_and_attach_factory(self): def simple_processor_factory(source_name: str) -> Workflow: # Return a mock processor for testing class MockProcessor(Workflow): - def __call__(self, *args, **kwargs): + def accumulate(self, data, *, start_time: int, end_time: int) -> None: + pass + + def finalize(self): return {"source": source_name} + def clear(self) -> None: + pass + return MockProcessor() # Attach factory using decorator @@ -354,9 +360,15 @@ class MyParams(pydantic.BaseModel): @handle.attach_factory() def factory(*, params: MyParams) -> Workflow: class MockProcessor(Workflow): - def __call__(self, *args, **kwargs): + def accumulate(self, data, *, start_time: int, end_time: int) -> None: + pass + + def finalize(self): return {"value": params.value} + def clear(self) -> None: + pass + return MockProcessor() # Verify factory was attached diff --git a/tests/core/job_test.py b/tests/core/job_test.py index db62eca60..e53b5ae0e 100644 --- a/tests/core/job_test.py +++ b/tests/core/job_test.py @@ -53,7 +53,9 @@ def __init__(self): self.should_fail_accumulate = False self.should_fail_finalize = False - def accumulate(self, data: dict[str, Any]) -> None: + def accumulate( + self, data: dict[str, Any], *, start_time: int, end_time: int + ) -> None: if self.should_fail_accumulate: raise RuntimeError("Accumulate failure") self.accumulate_calls.append(data.copy()) diff --git a/tests/dashboard/correlation_histogram_test.py b/tests/dashboard/correlation_histogram_test.py new file mode 100644 index 000000000..7a2bf46e6 --- /dev/null +++ b/tests/dashboard/correlation_histogram_test.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +"""Tests for CorrelationHistogramController.""" + +import uuid + +import scipp as sc + +from ess.livedata.config.workflow_spec import JobId, ResultKey, WorkflowId +from ess.livedata.dashboard.correlation_histogram import ( + CorrelationHistogramController, + CorrelationHistogramProcessor, + EdgesWithUnit, +) +from ess.livedata.dashboard.data_service import DataService + + +def make_timeseries_data(values: list[float], times: list[float]) -> sc.DataArray: + """Create a 1D timeseries DataArray with time dimension. + + Note: This is a convenience for tests, but in practice individual timeseries + points arrive as 0D scalars (see make_timeseries_point). + """ + return sc.DataArray( + sc.array(dims=['time'], values=values, unit='counts'), + coords={'time': sc.array(dims=['time'], values=times, unit='s')}, + ) + + +def make_timeseries_point(value: float, time: float) -> sc.DataArray: + """Create a single 0D timeseries point with time coordinate. + + This mimics how real timeseries data arrives - as individual 0D scalars + with a time coord (no time dimension yet). 
+ """ + return sc.DataArray( + sc.scalar(value, unit='counts'), + coords={'time': sc.scalar(time, unit='s')}, + ) + + +def make_non_timeseries_data(value: float) -> sc.DataArray: + """Create a 0D scalar DataArray without time dimension.""" + return sc.DataArray( + sc.scalar(value, unit='counts'), + coords={'x': sc.scalar(10.0, unit='m')}, + ) + + +def make_result_key(source_name: str) -> ResultKey: + """Create a ResultKey for testing.""" + return ResultKey( + workflow_id=WorkflowId( + instrument='test', namespace='test', name='workflow', version=1 + ), + job_id=JobId(source_name=source_name, job_number=uuid.uuid4()), + output_name=None, + ) + + +class TestCorrelationHistogramController: + """Tests for CorrelationHistogramController with buffered DataService.""" + + def test_get_timeseries_with_individual_0d_points(self): + """Test identifying timeseries from individual 0D points. + + In practice, timeseries data arrives as individual 0D scalars with time + coords (no time dimension). The time dimension is only added when the + history extractor concatenates them. This test verifies get_timeseries() + correctly identifies such data. + """ + data_service = DataService[ResultKey, sc.DataArray]() + controller = CorrelationHistogramController(data_service) + + # Add individual 0D timeseries points (realistic scenario) + ts_key = make_result_key('timeseries_stream') + data_service[ts_key] = make_timeseries_point(1.0, 0.0) + data_service[ts_key] = make_timeseries_point(2.0, 1.0) + data_service[ts_key] = make_timeseries_point(3.0, 2.0) + + # Add non-timeseries data + non_ts_key = make_result_key('scalar_data') + data_service[non_ts_key] = make_non_timeseries_data(42.0) + + # get_timeseries should identify the timeseries + timeseries_keys = controller.get_timeseries() + + assert ts_key in timeseries_keys, "Failed to identify 0D point timeseries" + assert ( + non_ts_key not in timeseries_keys + ), "Incorrectly identified scalar as timeseries" + assert len(timeseries_keys) == 1 + + def test_get_timeseries_identifies_buffered_timeseries(self): + """Test identifying timeseries from buffered DataService. + + This test verifies that get_timeseries correctly identifies timeseries + when DataService uses buffers and LatestValueExtractor returns 0D + scalars. + """ + # Create DataService (uses buffers internally) + data_service = DataService[ResultKey, sc.DataArray]() + controller = CorrelationHistogramController(data_service) + + # Add timeseries data - stored in buffers + ts_key1 = make_result_key('timeseries_1') + ts_key2 = make_result_key('timeseries_2') + data_service[ts_key1] = make_timeseries_data([1.0, 2.0, 3.0], [0.0, 1.0, 2.0]) + data_service[ts_key2] = make_timeseries_data([10.0, 20.0], [0.0, 1.0]) + + # Add non-timeseries data + non_ts_key = make_result_key('scalar_data') + data_service[non_ts_key] = make_non_timeseries_data(42.0) + + # Get timeseries - this should find the timeseries keys + timeseries_keys = controller.get_timeseries() + + # Should identify both timeseries, but NOT the scalar + assert ts_key1 in timeseries_keys, "Failed to identify timeseries_1" + assert ts_key2 in timeseries_keys, "Failed to identify timeseries_2" + assert ( + non_ts_key not in timeseries_keys + ), "Incorrectly identified scalar as timeseries" + assert len(timeseries_keys) == 2 + + def test_processor_receives_full_history_from_0d_points(self): + """Test processor receives concatenated history from individual 0D points. 
+ + Realistic scenario: Individual 0D timeseries points arrive after subscriber + registration. The buffer accumulates history only after a FullHistoryExtractor + subscriber is registered, then concatenates subsequent 0D points into 1D. + """ + data_service = DataService[ResultKey, sc.DataArray]() + controller = CorrelationHistogramController(data_service) + + # Add initial 0D points (before subscriber registration) + data_key = make_result_key('data_stream') + coord_key = make_result_key('coord_stream') + + data_service[data_key] = make_timeseries_point(1.0, 0.0) + data_service[coord_key] = make_timeseries_point(10.0, 0.0) + + # Track what data the processor receives + received_data = [] + + def result_callback(_: sc.DataArray) -> None: + """Callback to capture processor results.""" + + # Create processor with edges for binning + edges = EdgesWithUnit(start=5.0, stop=35.0, num_bins=3, unit='counts') + + processor = CorrelationHistogramProcessor( + data_key=data_key, + coord_keys=[coord_key], + edges_params=[edges], + normalize=False, + result_callback=result_callback, + ) + + # Monkey-patch processor.send to capture received data + original_send = processor.send + + def capturing_send(data: dict[ResultKey, sc.DataArray]) -> None: + received_data.append(data) + original_send(data) + + processor.send = capturing_send + + # Register subscriber with FullHistoryExtractor - buffer starts accumulating now + items = { + data_key: data_service[data_key], + coord_key: data_service[coord_key], + } + controller.add_correlation_processor(processor, items) + + # Processor triggered immediately with existing data (1 point each) + assert len(received_data) == 1, "Processor should be triggered on registration" + received = received_data[0] + assert received[data_key].dims == ('time',) + assert received[data_key].sizes['time'] == 1, "Initial data: 1 point" + + # Now add more 0D points - use transaction to batch updates + with data_service.transaction(): + data_service[data_key] = make_timeseries_point(2.0, 1.0) + data_service[coord_key] = make_timeseries_point(20.0, 1.0) + + # Processor should receive accumulated history (2 points) + assert len(received_data) == 2, "Processor triggered on new data" + received = received_data[1] + assert received[data_key].sizes['time'] == 2, "Should have 2 time points" + assert received[coord_key].sizes['time'] == 2 + + # Add another point + with data_service.transaction(): + data_service[data_key] = make_timeseries_point(3.0, 2.0) + data_service[coord_key] = make_timeseries_point(30.0, 2.0) + + # Processor should receive full accumulated history (3 points) + assert len(received_data) == 3, "Processor triggered again" + latest_received = received_data[-1] + assert latest_received[data_key].sizes['time'] == 3, "Should have 3 time points" + assert latest_received[coord_key].sizes['time'] == 3 + + def test_processor_receives_full_history_not_latest(self): + """Test processor receives full timeseries history via extractors. + + The processor needs complete history to compute correlation histograms. + This test verifies that FullHistoryExtractor is used for subscribers. 
+ """ + # Create DataService and controller + data_service = DataService[ResultKey, sc.DataArray]() + controller = CorrelationHistogramController(data_service) + + # Add initial timeseries data + data_key = make_result_key('data_stream') + coord_key = make_result_key('coord_stream') + + data_service[data_key] = make_timeseries_data([1.0, 2.0, 3.0], [0.0, 1.0, 2.0]) + data_service[coord_key] = make_timeseries_data( + [10.0, 20.0, 30.0], [0.0, 1.0, 2.0] + ) + + # Track what data the processor receives + received_data = [] + + def result_callback(_: sc.DataArray) -> None: + """Callback to capture processor results.""" + + # Create processor with edges for binning + edges = EdgesWithUnit(start=5.0, stop=35.0, num_bins=3, unit='counts') + + processor = CorrelationHistogramProcessor( + data_key=data_key, + coord_keys=[coord_key], + edges_params=[edges], + normalize=False, + result_callback=result_callback, + ) + + # Monkey-patch processor.send to capture received data + original_send = processor.send + + def capturing_send(data: dict[ResultKey, sc.DataArray]) -> None: + received_data.append(data) + original_send(data) + + processor.send = capturing_send + + # Add processor - this should register subscriber with FullHistoryExtractor + items = { + data_key: data_service[data_key], + coord_key: data_service[coord_key], + } + controller.add_correlation_processor(processor, items) + + # Processor should have been triggered immediately with existing data + assert len(received_data) == 1, "Processor should be triggered on registration" + + # Verify that processor received FULL history, not just latest value + received = received_data[0] + assert data_key in received + assert coord_key in received + + # Check data dimensions - should be 1D timeseries, not 0D scalar + assert received[data_key].dims == ( + 'time', + ), f"Expected 1D timeseries with time dim, got {received[data_key].dims}" + assert received[data_key].sizes['time'] == 3, "Should have all 3 time points" + + assert received[coord_key].dims == ('time',) + assert received[coord_key].sizes['time'] == 3 + + # Add more data and verify processor gets updated history + data_service[data_key] = make_timeseries_data([4.0], [3.0]) + data_service[coord_key] = make_timeseries_data([40.0], [3.0]) + + # Should have received second update + assert len(received_data) >= 2, "Processor should be triggered on data update" + latest_received = received_data[-1] + + # Latest update should also include full history (now 4 points) + assert latest_received[data_key].sizes['time'] == 4 + assert latest_received[coord_key].sizes['time'] == 4 diff --git a/tests/dashboard/data_service_benchmark.py b/tests/dashboard/data_service_benchmark.py new file mode 100644 index 000000000..46104cc8d --- /dev/null +++ b/tests/dashboard/data_service_benchmark.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +"""Benchmarks for DataService LatestFrame extraction with subscriber notifications.""" + +from __future__ import annotations + +from typing import Any + +import pytest +import scipp as sc + +from ess.livedata.dashboard.data_service import DataService, DataServiceSubscriber +from ess.livedata.dashboard.extractors import LatestValueExtractor + + +class SimpleSubscriber(DataServiceSubscriber[str]): + """Simple subscriber that tracks trigger calls.""" + + def __init__(self, keys: set[str]) -> None: + """Initialize subscriber with given keys.""" + self._keys_set = keys + self.trigger_count = 0 + 
self.received_updates: list[dict[str, Any]] = [] + super().__init__() + + @property + def keys(self) -> set[str]: + """Return the keys this subscriber depends on.""" + return self._keys_set + + @property + def extractors(self) -> dict[str, LatestValueExtractor]: + """Return extractors for all keys.""" + return {key: LatestValueExtractor() for key in self._keys_set} + + def trigger(self, store: dict[str, Any]) -> None: + """Track trigger calls and received updates.""" + self.trigger_count += 1 + self.received_updates.append(store.copy()) + + +class TestDataServiceBenchmark: + """Benchmarks for DataService with LatestFrame extraction.""" + + @pytest.fixture + def service(self) -> DataService[str, sc.Variable]: + """Create a fresh DataService for each benchmark.""" + return DataService() + + @pytest.fixture + def sample_data(self) -> dict[str, sc.Variable]: + """Create sample data with scipp Variables.""" + return { + 'detector_counts': sc.scalar(100, unit='counts'), + 'monitor_counts': sc.scalar(50, unit='counts'), + 'temperature': sc.scalar(298.15, unit='K'), + } + + def test_update_multiple_keys_with_subscriber( + self, benchmark, service: DataService[str, sc.Variable], sample_data + ): + """Benchmark updating multiple keys with subscriber watching all.""" + subscriber = SimpleSubscriber(set(sample_data.keys())) + service.register_subscriber(subscriber) + + def update_multiple_keys_with_subscriber(): + with service.transaction(): + for key, value in sample_data.items(): + service[key] = value + + benchmark(update_multiple_keys_with_subscriber) + assert len(subscriber.received_updates) > 0 + + def test_update_with_multiple_subscribers_same_key( + self, benchmark, service: DataService[str, sc.Variable] + ): + """Benchmark update with multiple subscribers watching the same key.""" + subscribers = [SimpleSubscriber({'data'}) for _ in range(5)] + for sub in subscribers: + service.register_subscriber(sub) + + data = sc.scalar(42, unit='counts') + + def update_with_subscribers(): + service['data'] = data + + benchmark(update_with_subscribers) + # Each subscriber should have been triggered + for sub in subscribers: + assert sub.trigger_count >= 1 + + def test_update_many_keys_many_subscribers( + self, benchmark, service: DataService[str, sc.Variable] + ): + """Benchmark updating many keys with many subscribers (one per key).""" + keys = [f'key_{i}' for i in range(100)] + # Create one subscriber per key - typical real-world scenario + subscribers = [SimpleSubscriber({key}) for key in keys] + for sub in subscribers: + service.register_subscriber(sub) + + def update_many(): + with service.transaction(): + for i, key in enumerate(keys): + service[key] = sc.scalar(i, unit='counts') + return len(service) + + result = benchmark(update_many) + assert result == 100 + # Each subscriber should have been triggered once + for sub in subscribers: + assert sub.trigger_count >= 1 + + def test_extract_from_large_service( + self, benchmark, service: DataService[str, sc.Variable] + ): + """Benchmark extracting values from large service via subscriber.""" + keys = [f'key_{i}' for i in range(1000)] + + # Populate service + for i, key in enumerate(keys): + service[key] = sc.scalar(i, unit='counts') + + # Subscribe to first 10 keys + watched_keys = set(keys[:10]) + subscriber = SimpleSubscriber(watched_keys) + service.register_subscriber(subscriber) + + # Reset trigger count to avoid counting initialization + subscriber.trigger_count = 0 + + def update_and_extract(): + service[keys[0]] = sc.scalar(999, unit='counts') + 
+ benchmark(update_and_extract) + # Subscriber should have received the update + assert subscriber.trigger_count >= 1 diff --git a/tests/dashboard/data_service_test.py b/tests/dashboard/data_service_test.py index 093d19850..00de47f69 100644 --- a/tests/dashboard/data_service_test.py +++ b/tests/dashboard/data_service_test.py @@ -2,12 +2,23 @@ # Copyright (c) 2025 Scipp contributors (https://github.com/scipp) from __future__ import annotations +from collections.abc import Callable from typing import Any import pytest +import scipp as sc -from ess.livedata.dashboard.data_service import DataService +from ess.livedata.dashboard.data_service import DataService, DataServiceSubscriber from ess.livedata.dashboard.data_subscriber import DataSubscriber, Pipe, StreamAssembler +from ess.livedata.dashboard.extractors import LatestValueExtractor + + +def make_test_data(value: int, time: float = 0.0) -> sc.DataArray: + """Create a test DataArray with the given value and time coordinate.""" + return sc.DataArray( + sc.scalar(value, unit='counts'), + coords={'time': sc.scalar(time, unit='s')}, + ) class FakeDataAssembler(StreamAssembler[str]): @@ -20,19 +31,33 @@ def assemble(self, data: dict[str, Any]) -> dict[str, Any]: class FakePipe(Pipe): """Fake pipe for testing.""" - def __init__(self) -> None: + def __init__(self, data: Any = None) -> None: + self.init_data = data self.sent_data: list[dict[str, Any]] = [] def send(self, data: Any) -> None: self.sent_data.append(data) -def create_test_subscriber(keys: set[str]) -> tuple[DataSubscriber[str], FakePipe]: - """Create a test subscriber with the given keys.""" +def create_test_subscriber(keys: set[str]) -> tuple[DataSubscriber[str], Callable]: + """ + Create a test subscriber with the given keys. + + Returns the subscriber and a callable to get the pipe after it's created. 
+ """ assembler = FakeDataAssembler(keys) - pipe = FakePipe() - subscriber = DataSubscriber(assembler, pipe) - return subscriber, pipe + extractors = {key: LatestValueExtractor() for key in keys} + + def pipe_factory(data: Any) -> FakePipe: + return FakePipe(data) + + subscriber = DataSubscriber(assembler, pipe_factory, extractors) + + def get_pipe() -> FakePipe: + """Get the pipe (created on first trigger).""" + return subscriber.pipe + + return subscriber, get_pipe @pytest.fixture @@ -51,53 +76,70 @@ def test_init_creates_empty_service(): def test_setitem_stores_value(data_service: DataService[str, int]): - data_service["key1"] = 42 - assert data_service["key1"] == 42 + import scipp as sc + + data = sc.DataArray( + sc.scalar(42, unit='counts'), coords={'time': sc.scalar(0.0, unit='s')} + ) + data_service["key1"] = data + retrieved = data_service["key1"] + assert retrieved.value == 42 assert "key1" in data_service def test_setitem_without_subscribers_no_error(data_service: DataService[str, int]): - data_service["key1"] = 42 - assert data_service["key1"] == 42 + import scipp as sc + + data = sc.DataArray( + sc.scalar(42, unit='counts'), coords={'time': sc.scalar(0.0, unit='s')} + ) + data_service["key1"] = data + retrieved = data_service["key1"] + assert retrieved.value == 42 def test_register_subscriber_adds_to_list(data_service: DataService[str, int]): - subscriber, _ = create_test_subscriber({"key1"}) + subscriber, get_pipe = create_test_subscriber({"key1"}) data_service.register_subscriber(subscriber) + # Verify pipe was created and subscriber was added + _ = get_pipe() # Ensure pipe exists def test_setitem_notifies_matching_subscriber(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1", "key2"}) + subscriber, get_pipe = create_test_subscriber({"key1", "key2"}) data_service.register_subscriber(subscriber) - data_service["key1"] = 42 + data_service["key1"] = make_test_data(42) + pipe = get_pipe() assert len(pipe.sent_data) == 1 - assert pipe.sent_data[0] == {"key1": 42} + assert pipe.sent_data[0]["key1"].value == 42 def test_setitem_ignores_non_matching_subscriber(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"other_key"}) + subscriber, get_pipe = create_test_subscriber({"other_key"}) data_service.register_subscriber(subscriber) - data_service["key1"] = 42 + data_service["key1"] = make_test_data(42) + pipe = get_pipe() assert len(pipe.sent_data) == 0 def test_setitem_notifies_multiple_matching_subscribers( data_service: DataService[str, int], ): - subscriber1, pipe1 = create_test_subscriber({"key1"}) - subscriber2, pipe2 = create_test_subscriber({"key1", "key2"}) - subscriber3, pipe3 = create_test_subscriber({"key2"}) + subscriber1, get_pipe1 = create_test_subscriber({"key1"}) + subscriber2, get_pipe2 = create_test_subscriber({"key1", "key2"}) + subscriber3, get_pipe3 = create_test_subscriber({"key2"}) data_service.register_subscriber(subscriber1) data_service.register_subscriber(subscriber2) data_service.register_subscriber(subscriber3) - data_service["key1"] = 42 + data_service["key1"] = make_test_data(42) + pipe1, pipe2, pipe3 = get_pipe1(), get_pipe2(), get_pipe3() assert len(pipe1.sent_data) == 1 assert len(pipe2.sent_data) == 1 assert len(pipe3.sent_data) == 0 @@ -106,98 +148,120 @@ def test_setitem_notifies_multiple_matching_subscribers( def test_setitem_multiple_updates_notify_separately( data_service: DataService[str, int], ): - subscriber, pipe = create_test_subscriber({"key1", "key2"}) + subscriber, 
get_pipe = create_test_subscriber({"key1", "key2"}) data_service.register_subscriber(subscriber) - data_service["key1"] = 42 - data_service["key2"] = 84 + data_service["key1"] = make_test_data(42) + data_service["key2"] = make_test_data(84) + pipe = get_pipe() assert len(pipe.sent_data) == 2 - assert pipe.sent_data[0] == {"key1": 42} - assert pipe.sent_data[1] == {"key1": 42, "key2": 84} + assert pipe.sent_data[0]["key1"].value == 42 + assert pipe.sent_data[1]["key1"].value == 42 + assert pipe.sent_data[1]["key2"].value == 84 def test_transaction_batches_notifications(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1", "key2"}) + subscriber, get_pipe = create_test_subscriber({"key1", "key2"}) data_service.register_subscriber(subscriber) + pipe = get_pipe() with data_service.transaction(): - data_service["key1"] = 42 - data_service["key2"] = 84 + data_service["key1"] = make_test_data(42) + data_service["key2"] = make_test_data(84) # No notifications yet assert len(pipe.sent_data) == 0 # Single notification after transaction assert len(pipe.sent_data) == 1 - assert pipe.sent_data[0] == {"key1": 42, "key2": 84} + assert pipe.sent_data[0]["key1"].value == 42 + assert pipe.sent_data[0]["key2"].value == 84 def test_transaction_nested_batches_correctly(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1", "key2", "key3"}) + subscriber, get_pipe = create_test_subscriber({"key1", "key2", "key3"}) data_service.register_subscriber(subscriber) + pipe = get_pipe() with data_service.transaction(): - data_service["key1"] = 42 + data_service["key1"] = make_test_data(42) with data_service.transaction(): - data_service["key2"] = 84 + data_service["key2"] = make_test_data(84) assert len(pipe.sent_data) == 0 # Still in outer transaction assert len(pipe.sent_data) == 0 - data_service["key3"] = 126 + data_service["key3"] = make_test_data(126) # Single notification after all transactions assert len(pipe.sent_data) == 1 - assert pipe.sent_data[0] == {"key1": 42, "key2": 84, "key3": 126} + assert pipe.sent_data[0]["key1"].value == 42 + assert pipe.sent_data[0]["key2"].value == 84 + assert pipe.sent_data[0]["key3"].value == 126 def test_transaction_exception_still_notifies(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1"}) + subscriber, get_pipe = create_test_subscriber({"key1"}) data_service.register_subscriber(subscriber) try: with data_service.transaction(): - data_service["key1"] = 42 + data_service["key1"] = make_test_data(42) raise ValueError("test error") except ValueError: # Exception should not prevent notification pass # Notification should still happen + pipe = get_pipe() assert len(pipe.sent_data) == 1 - assert pipe.sent_data[0] == {"key1": 42} + assert pipe.sent_data[0]["key1"].value == 42 def test_dictionary_operations_work( - data_service: DataService[str, int], sample_data: dict[str, int] + data_service: DataService[str, int], ): - # Test basic dict operations - for key, value in sample_data.items(): + # Test basic dict operations with proper scipp data + test_data = { + "key1": make_test_data(100), + "key2": make_test_data(200), + "key3": make_test_data(300), + } + for key, value in test_data.items(): data_service[key] = value assert len(data_service) == 3 - assert data_service["key1"] == 100 - assert data_service.get("key1") == 100 - assert data_service.get("nonexistent", 999) == 999 + assert data_service["key1"].value == 100 + assert data_service.get("key1").value == 100 + assert 
data_service.get("nonexistent") is None assert "key1" in data_service assert "nonexistent" not in data_service assert list(data_service.keys()) == ["key1", "key2", "key3"] - assert list(data_service.values()) == [100, 200, 300] + assert all( + v.value == v_expected + for v, v_expected in zip(data_service.values(), [100, 200, 300], strict=True) + ) def test_update_method_triggers_notifications(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1", "key2"}) + subscriber, get_pipe = create_test_subscriber({"key1", "key2"}) data_service.register_subscriber(subscriber) - data_service.update({"key1": 42, "key2": 84}) + data_service.update({"key1": make_test_data(42), "key2": make_test_data(84)}) # Should trigger notifications for each key + pipe = get_pipe() assert len(pipe.sent_data) == 2 def test_clear_removes_all_data( - data_service: DataService[str, int], sample_data: dict[str, int] + data_service: DataService[str, int], ): - data_service.update(sample_data) + test_data = { + "key1": make_test_data(100), + "key2": make_test_data(200), + "key3": make_test_data(300), + } + data_service.update(test_data) assert len(data_service) == 3 data_service.clear() @@ -205,59 +269,64 @@ def test_clear_removes_all_data( def test_pop_removes_and_returns_value(data_service: DataService[str, int]): - data_service["key1"] = 42 + data_service["key1"] = make_test_data(42) value = data_service.pop("key1") - assert value == 42 + assert value.value == 42 assert "key1" not in data_service def test_setdefault_behavior(data_service: DataService[str, int]): - value = data_service.setdefault("key1", 42) - assert value == 42 - assert data_service["key1"] == 42 + test_val = make_test_data(42) + value = data_service.setdefault("key1", test_val) + assert value.value == 42 + assert data_service["key1"].value == 42 # Second call should return existing value - value = data_service.setdefault("key1", 999) - assert value == 42 - assert data_service["key1"] == 42 + value = data_service.setdefault("key1", make_test_data(999)) + assert value.value == 42 + assert data_service["key1"].value == 42 def test_subscriber_gets_full_data_dict(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1"}) + subscriber, get_pipe = create_test_subscriber({"key1"}) data_service.register_subscriber(subscriber) # Add some initial data - data_service["existing"] = 999 - data_service["key1"] = 42 + data_service["existing"] = make_test_data(999) + data_service["key1"] = make_test_data(42) # Subscriber should get the full data dict - assert pipe.sent_data[-1] == {"key1": 42} + pipe = get_pipe() + assert pipe.sent_data[-1]["key1"].value == 42 def test_subscriber_only_gets_subscribed_keys(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1", "key3"}) + subscriber, get_pipe = create_test_subscriber({"key1", "key3"}) data_service.register_subscriber(subscriber) # Add data for subscribed and unsubscribed keys - data_service["key1"] = 42 - data_service["key2"] = 84 # Not subscribed to this key - data_service["key3"] = 126 - data_service["unrelated"] = 999 # Not subscribed to this key + data_service["key1"] = make_test_data(42) + data_service["key2"] = make_test_data(84) # Not subscribed to this key + data_service["key3"] = make_test_data(126) + data_service["unrelated"] = make_test_data(999) # Not subscribed to this key # Subscriber should only receive data for keys it's interested in - expected_data = {"key1": 42, "key3": 126} - assert pipe.sent_data[-1] 
== expected_data + pipe = get_pipe() + last_data = pipe.sent_data[-1] + assert last_data["key1"].value == 42 + assert last_data["key3"].value == 126 # Verify unrelated keys are not included - assert "key2" not in pipe.sent_data[-1] - assert "unrelated" not in pipe.sent_data[-1] + assert "key2" not in last_data + assert "unrelated" not in last_data def test_empty_transaction_no_notifications(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1"}) + subscriber, get_pipe = create_test_subscriber({"key1"}) data_service.register_subscriber(subscriber) + pipe = get_pipe() with data_service.transaction(): pass # No changes @@ -265,12 +334,13 @@ def test_empty_transaction_no_notifications(data_service: DataService[str, int]) def test_delitem_notifies_subscribers(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1", "key2"}) + subscriber, get_pipe = create_test_subscriber({"key1", "key2"}) data_service.register_subscriber(subscriber) # Add some data first - data_service["key1"] = 42 - data_service["key2"] = 84 + data_service["key1"] = make_test_data(42) + data_service["key2"] = make_test_data(84) + pipe = get_pipe() pipe.sent_data.clear() # Clear previous notifications # Delete a key @@ -278,93 +348,98 @@ def test_delitem_notifies_subscribers(data_service: DataService[str, int]): # Should notify with remaining data assert len(pipe.sent_data) == 1 - assert pipe.sent_data[0] == {"key2": 84} + assert pipe.sent_data[0]["key2"].value == 84 assert "key1" not in data_service def test_delitem_in_transaction_batches_notifications( data_service: DataService[str, int], ): - subscriber, pipe = create_test_subscriber({"key1", "key2"}) + subscriber, get_pipe = create_test_subscriber({"key1", "key2"}) data_service.register_subscriber(subscriber) # Add some data first - data_service["key1"] = 42 - data_service["key2"] = 84 + data_service["key1"] = make_test_data(42) + data_service["key2"] = make_test_data(84) + pipe = get_pipe() pipe.sent_data.clear() # Clear previous notifications with data_service.transaction(): del data_service["key1"] - data_service["key2"] = 99 + data_service["key2"] = make_test_data(99) # No notifications yet assert len(pipe.sent_data) == 0 # Single notification after transaction assert len(pipe.sent_data) == 1 - assert pipe.sent_data[0] == {"key2": 99} + assert pipe.sent_data[0]["key2"].value == 99 def test_transaction_set_then_del_same_key(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1", "key2"}) + subscriber, get_pipe = create_test_subscriber({"key1", "key2"}) data_service.register_subscriber(subscriber) # Add some initial data - data_service["key2"] = 84 + data_service["key2"] = make_test_data(84) + pipe = get_pipe() pipe.sent_data.clear() with data_service.transaction(): - data_service["key1"] = 42 # Set key1 + data_service["key1"] = make_test_data(42) # Set key1 del data_service["key1"] # Then delete key1 # No notifications yet assert len(pipe.sent_data) == 0 # After transaction: key1 should not exist, only key2 should be in notification assert len(pipe.sent_data) == 1 - assert pipe.sent_data[0] == {"key2": 84} + assert pipe.sent_data[0]["key2"].value == 84 assert "key1" not in data_service def test_transaction_del_then_set_same_key(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1", "key2"}) + subscriber, get_pipe = create_test_subscriber({"key1", "key2"}) data_service.register_subscriber(subscriber) # Add some initial data - 
data_service["key1"] = 42 - data_service["key2"] = 84 + data_service["key1"] = make_test_data(42) + data_service["key2"] = make_test_data(84) + pipe = get_pipe() pipe.sent_data.clear() with data_service.transaction(): del data_service["key1"] # Delete key1 - data_service["key1"] = 99 # Then set key1 to new value + data_service["key1"] = make_test_data(99) # Then set key1 to new value # No notifications yet assert len(pipe.sent_data) == 0 # After transaction: key1 should have the new value assert len(pipe.sent_data) == 1 - assert pipe.sent_data[0] == {"key1": 99, "key2": 84} - assert data_service["key1"] == 99 + assert pipe.sent_data[0]["key1"].value == 99 + assert pipe.sent_data[0]["key2"].value == 84 + assert data_service["key1"].value == 99 def test_transaction_multiple_operations_same_key(data_service: DataService[str, int]): - subscriber, pipe = create_test_subscriber({"key1"}) + subscriber, get_pipe = create_test_subscriber({"key1"}) data_service.register_subscriber(subscriber) # Add initial data - data_service["key1"] = 10 + data_service["key1"] = make_test_data(10) + pipe = get_pipe() pipe.sent_data.clear() with data_service.transaction(): - data_service["key1"] = 20 # Update - data_service["key1"] = 30 # Update again + data_service["key1"] = make_test_data(20) # Update + data_service["key1"] = make_test_data(30) # Update again del data_service["key1"] # Delete - data_service["key1"] = 40 # Set again + data_service["key1"] = make_test_data(40) # Set again # No notifications yet assert len(pipe.sent_data) == 0 # After transaction: key1 should have final value assert len(pipe.sent_data) == 1 - assert pipe.sent_data[0] == {"key1": 40} - assert data_service["key1"] == 40 + assert pipe.sent_data[0]["key1"].value == 40 + assert data_service["key1"].value == 40 class TestDataServiceUpdatingSubscribers: @@ -374,142 +449,191 @@ def test_subscriber_updates_service_immediately(self): """Test subscriber updating service outside of transaction.""" service = DataService[str, int]() - class UpdatingSubscriber(DataSubscriber[str]): + class UpdatingSubscriber(DataServiceSubscriber[str]): def __init__(self, keys: set[str], service: DataService[str, int]): - super().__init__(FakeDataAssembler(keys), FakePipe()) + self._keys_set = keys + self._extractors = {key: LatestValueExtractor() for key in keys} self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) - # Update derived data based on received data if "input" in store: - self._service["derived"] = store["input"] * 2 + derived_value = store["input"].value * 2 + self._service["derived"] = make_test_data(derived_value) subscriber = UpdatingSubscriber({"input"}, service) service.register_subscriber(subscriber) # This should trigger the subscriber, which updates "derived" - service["input"] = 10 + service["input"] = make_test_data(10) - assert service["input"] == 10 - assert service["derived"] == 20 + assert service["input"].value == 10 + assert service["derived"].value == 20 def test_subscriber_updates_service_in_transaction(self): """Test subscriber updating service at end of transaction.""" service = DataService[str, int]() - class UpdatingSubscriber(DataSubscriber[str]): + class UpdatingSubscriber(DataServiceSubscriber[str]): def __init__(self, keys: set[str], service: DataService[str, int]): - super().__init__(FakeDataAssembler(keys), FakePipe()) + self._keys_set = keys + self._extractors = {key: LatestValueExtractor() for 
key in keys} self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if "input" in store: - self._service["derived"] = store["input"] * 2 + derived_value = store["input"].value * 2 + self._service["derived"] = make_test_data(derived_value) subscriber = UpdatingSubscriber({"input"}, service) service.register_subscriber(subscriber) with service.transaction(): - service["input"] = 10 + service["input"] = make_test_data(10) # "derived" should not exist yet during transaction assert "derived" not in service # After transaction, both keys should exist - assert service["input"] == 10 - assert service["derived"] == 20 + assert service["input"].value == 10 + assert service["derived"].value == 20 def test_multiple_subscribers_update_service(self): """Test multiple subscribers updating different derived data.""" service = DataService[str, int]() - class MultiplierSubscriber(DataSubscriber[str]): + class MultiplierSubscriber(DataServiceSubscriber[str]): def __init__( self, keys: set[str], service: DataService[str, int], multiplier: int, ): - super().__init__(FakeDataAssembler(keys), FakePipe()) + self._keys_set = keys + self._extractors = {key: LatestValueExtractor() for key in keys} self._service = service self._multiplier = multiplier + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if "input" in store: key = f"derived_{self._multiplier}x" - self._service[key] = store["input"] * self._multiplier + derived_value = store["input"].value * self._multiplier + self._service[key] = make_test_data(derived_value) sub1 = MultiplierSubscriber({"input"}, service, 2) sub2 = MultiplierSubscriber({"input"}, service, 3) service.register_subscriber(sub1) service.register_subscriber(sub2) - service["input"] = 10 + service["input"] = make_test_data(10) - assert service["input"] == 10 - assert service["derived_2x"] == 20 - assert service["derived_3x"] == 30 + assert service["input"].value == 10 + assert service["derived_2x"].value == 20 + assert service["derived_3x"].value == 30 def test_cascading_subscriber_updates(self): """Test subscribers that depend on derived data from other subscribers.""" service = DataService[str, int]() - class FirstLevelSubscriber(DataSubscriber[str]): + class FirstLevelSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"input"}), FakePipe()) + self._keys_set = {"input"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if "input" in store: - self._service["level1"] = store["input"] * 2 + derived_value = store["input"].value * 2 + self._service["level1"] = make_test_data(derived_value) - class SecondLevelSubscriber(DataSubscriber[str]): + class SecondLevelSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"level1"}), FakePipe()) + self._keys_set = {"level1"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> 
None: - super().trigger(store) if "level1" in store: - self._service["level2"] = store["level1"] * 3 + derived_value = store["level1"].value * 3 + self._service["level2"] = make_test_data(derived_value) sub1 = FirstLevelSubscriber(service) sub2 = SecondLevelSubscriber(service) service.register_subscriber(sub1) service.register_subscriber(sub2) - service["input"] = 5 + service["input"] = make_test_data(5) - assert service["input"] == 5 - assert service["level1"] == 10 - assert service["level2"] == 30 + assert service["input"].value == 5 + assert service["level1"].value == 10 + assert service["level2"].value == 30 def test_cascading_updates_in_transaction(self): """Test cascading updates within a transaction.""" service = DataService[str, int]() - class FirstLevelSubscriber(DataSubscriber[str]): + class FirstLevelSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"input"}), FakePipe()) + self._keys_set = {"input"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if "input" in store: - self._service["level1"] = store["input"] * 2 + derived_value = store["input"].value * 2 + self._service["level1"] = make_test_data(derived_value) - class SecondLevelSubscriber(DataSubscriber[str]): + class SecondLevelSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"level1"}), FakePipe()) + self._keys_set = {"level1"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if "level1" in store: - self._service["level2"] = store["level1"] * 3 + derived_value = store["level1"].value * 3 + self._service["level2"] = make_test_data(derived_value) sub1 = FirstLevelSubscriber(service) sub2 = SecondLevelSubscriber(service) @@ -517,175 +641,221 @@ def trigger(self, store: dict[str, int]) -> None: service.register_subscriber(sub2) with service.transaction(): - service["input"] = 5 - service["other"] = 100 + service["input"] = make_test_data(5) + service["other"] = make_test_data(100) # No derived data should exist during transaction assert "level1" not in service assert "level2" not in service # All data should exist after transaction - assert service["input"] == 5 - assert service["other"] == 100 - assert service["level1"] == 10 - assert service["level2"] == 30 + assert service["input"].value == 5 + assert service["other"].value == 100 + assert service["level1"].value == 10 + assert service["level2"].value == 30 def test_subscriber_updates_multiple_keys(self): """Test subscriber that updates multiple derived keys at once.""" service = DataService[str, int]() - class MultiUpdateSubscriber(DataSubscriber[str]): + class MultiUpdateSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"input"}), FakePipe()) + self._keys_set = {"input"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> 
None: - super().trigger(store) if "input" in store: - # Update multiple derived values + input_value = store["input"].value with self._service.transaction(): - self._service["double"] = store["input"] * 2 - self._service["triple"] = store["input"] * 3 - self._service["square"] = store["input"] ** 2 + self._service["double"] = make_test_data(input_value * 2) + self._service["triple"] = make_test_data(input_value * 3) + self._service["square"] = make_test_data(input_value**2) subscriber = MultiUpdateSubscriber(service) service.register_subscriber(subscriber) - service["input"] = 4 + service["input"] = make_test_data(4) - assert service["input"] == 4 - assert service["double"] == 8 - assert service["triple"] == 12 - assert service["square"] == 16 + assert service["input"].value == 4 + assert service["double"].value == 8 + assert service["triple"].value == 12 + assert service["square"].value == 16 def test_subscriber_updates_existing_keys(self): """Test subscriber updating keys that already exist.""" service = DataService[str, int]() - service["existing"] = 100 + service["existing"] = make_test_data(100) - class OverwriteSubscriber(DataSubscriber[str]): + class OverwriteSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"input"}), FakePipe()) + self._keys_set = {"input"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if "input" in store: - self._service["existing"] = store["input"] * 10 + derived_value = store["input"].value * 10 + self._service["existing"] = make_test_data(derived_value) subscriber = OverwriteSubscriber(service) service.register_subscriber(subscriber) - service["input"] = 5 + service["input"] = make_test_data(5) - assert service["input"] == 5 - assert service["existing"] == 50 # Overwritten, not 100 + assert service["input"].value == 5 + assert service["existing"].value == 50 # Overwritten, not 100 def test_circular_dependency_protection(self): """Test handling of potential circular dependencies.""" service = DataService[str, int]() update_count = {"count": 0} - class CircularSubscriber(DataSubscriber[str]): + class CircularSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"input", "output"}), FakePipe()) + self._keys_set = {"input", "output"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) update_count["count"] += 1 if update_count["count"] < 5: # Prevent infinite recursion in test if "input" in store and "output" not in store: - self._service["output"] = store["input"] + 1 - elif "output" in store and store["output"] < 10: - self._service["output"] = store["output"] + 1 + derived_value = store["input"].value + 1 + self._service["output"] = make_test_data(derived_value) + elif "output" in store and store["output"].value < 10: + derived_value = store["output"].value + 1 + self._service["output"] = make_test_data(derived_value) subscriber = CircularSubscriber(service) service.register_subscriber(subscriber) - service["input"] = 1 + service["input"] = make_test_data(1) # Should handle the 
circular updates gracefully - assert service["input"] == 1 + assert service["input"].value == 1 assert "output" in service assert update_count["count"] > 1 # Multiple updates occurred def test_subscriber_deletes_keys_during_update(self): """Test subscriber that deletes keys during notification.""" service = DataService[str, int]() - service["to_delete"] = 999 + service["to_delete"] = make_test_data(999) - class DeletingSubscriber(DataSubscriber[str]): + class DeletingSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"trigger"}), FakePipe()) + self._keys_set = {"trigger"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if "trigger" in store and "to_delete" in self._service: del self._service["to_delete"] - self._service["deleted_flag"] = 1 + self._service["deleted_flag"] = make_test_data(1) subscriber = DeletingSubscriber(service) service.register_subscriber(subscriber) - service["trigger"] = 1 + service["trigger"] = make_test_data(1) - assert service["trigger"] == 1 + assert service["trigger"].value == 1 assert "to_delete" not in service - assert service["deleted_flag"] == 1 + assert service["deleted_flag"].value == 1 def test_subscriber_complex_transaction_updates(self): """Test complex scenario with nested transactions and subscriber updates.""" service = DataService[str, int]() - class ComplexSubscriber(DataSubscriber[str]): + class ComplexSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"input"}), FakePipe()) + self._keys_set = {"input"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if "input" in store: - # Subscriber uses its own transaction + input_value = store["input"].value with self._service.transaction(): - self._service["derived1"] = store["input"] * 2 + self._service["derived1"] = make_test_data(input_value * 2) with self._service.transaction(): - self._service["derived2"] = store["input"] * 3 - self._service["derived3"] = store["input"] * 4 + self._service["derived2"] = make_test_data(input_value * 3) + self._service["derived3"] = make_test_data(input_value * 4) subscriber = ComplexSubscriber(service) service.register_subscriber(subscriber) with service.transaction(): - service["input"] = 5 - service["other"] = 100 + service["input"] = make_test_data(5) + service["other"] = make_test_data(100) # No derived data during transaction assert "derived1" not in service # All data should exist after transaction - assert service["input"] == 5 - assert service["other"] == 100 - assert service["derived1"] == 10 - assert service["derived2"] == 15 - assert service["derived3"] == 20 + assert service["input"].value == 5 + assert service["other"].value == 100 + assert service["derived1"].value == 10 + assert service["derived2"].value == 15 + assert service["derived3"].value == 20 def test_multiple_update_rounds(self): """Test scenario requiring multiple notification rounds.""" service = DataService[str, int]() - class ChainSubscriber(DataSubscriber[str]): + class 
ChainSubscriber(DataServiceSubscriber[str]): def __init__( self, input_key: str, output_key: str, service: DataService[str, int] ): - super().__init__(FakeDataAssembler({input_key}), FakePipe()) + self._keys_set = {input_key} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._input_key = input_key self._output_key = output_key self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if self._input_key in store: - self._service[self._output_key] = store[self._input_key] + 1 + derived_value = store[self._input_key].value + 1 + self._service[self._output_key] = make_test_data(derived_value) # Create a chain: input -> step1 -> step2 -> step3 sub1 = ChainSubscriber("input", "step1", service) @@ -696,38 +866,176 @@ def trigger(self, store: dict[str, int]) -> None: service.register_subscriber(sub2) service.register_subscriber(sub3) - service["input"] = 10 + service["input"] = make_test_data(10) - assert service["input"] == 10 - assert service["step1"] == 11 - assert service["step2"] == 12 - assert service["step3"] == 13 + assert service["input"].value == 10 + assert service["step1"].value == 11 + assert service["step2"].value == 12 + assert service["step3"].value == 13 def test_subscriber_updates_with_mixed_immediate_and_transaction(self): """Test mixing immediate updates with transactional updates from subscribers.""" service = DataService[str, int]() - class MixedSubscriber(DataSubscriber[str]): + class MixedSubscriber(DataServiceSubscriber[str]): def __init__(self, service: DataService[str, int]): - super().__init__(FakeDataAssembler({"input"}), FakePipe()) + self._keys_set = {"input"} + self._extractors = { + key: LatestValueExtractor() for key in self._keys_set + } self._service = service + super().__init__() + + @property + def extractors(self): + return self._extractors def trigger(self, store: dict[str, int]) -> None: - super().trigger(store) if "input" in store: + input_value = store["input"].value # Immediate update - self._service["immediate"] = store["input"] * 2 + self._service["immediate"] = make_test_data(input_value * 2) # Transaction update with self._service.transaction(): - self._service["transactional1"] = store["input"] * 3 - self._service["transactional2"] = store["input"] * 4 + self._service["transactional1"] = make_test_data( + input_value * 3 + ) + self._service["transactional2"] = make_test_data( + input_value * 4 + ) subscriber = MixedSubscriber(service) service.register_subscriber(subscriber) - service["input"] = 5 + service["input"] = make_test_data(5) + + assert service["input"].value == 5 + assert service["immediate"].value == 10 + assert service["transactional1"].value == 15 + assert service["transactional2"].value == 20 + + +# Tests for extractor-based subscription +class TestExtractorBasedSubscription: + """Tests for extractor-based subscription with dynamic buffer sizing.""" + + def test_buffer_size_determined_by_max_extractor_requirement(self): + """Test that buffer size is set to max requirement among subscribers.""" + import scipp as sc + + from ess.livedata.dashboard.data_service import DataService + from ess.livedata.dashboard.extractors import ( + FullHistoryExtractor, + LatestValueExtractor, + ) + + class TestSubscriber(DataServiceSubscriber[str]): + def __init__(self, keys: set[str], extractor): + self._keys_set = keys + self._extractor = extractor + self.received_data: list[dict] = [] + 
super().__init__() + + @property + def extractors(self) -> dict: + return {key: self._extractor for key in self._keys_set} + + def trigger(self, data: dict) -> None: + self.received_data.append(data) + + # Create service + service = DataService() + + # Register subscriber with LatestValueExtractor (size 1) + sub1 = TestSubscriber({"data"}, LatestValueExtractor()) + service.register_subscriber(sub1) + + # Add first data point - buffer should be size 1 + service["data"] = sc.DataArray( + sc.scalar(1, unit='counts'), coords={'time': sc.scalar(0.0, unit='s')} + ) + + # Register subscriber with FullHistoryExtractor (size 10000) + sub2 = TestSubscriber({"data"}, FullHistoryExtractor()) + service.register_subscriber(sub2) + + # Buffer should now grow to size 10000 + # Add more data to verify buffering works + for i in range(2, 12): + service["data"] = sc.DataArray( + sc.scalar(i, unit='counts'), + coords={'time': sc.scalar(float(i - 1), unit='s')}, + ) + + # Both subscribers should have received all updates + # sub1: 1 initial trigger + 1 update before sub2 registration + 10 after = 12 + assert len(sub1.received_data) == 12 + # sub2: 1 initial trigger on registration (with copied data) + 10 updates = 11 + assert len(sub2.received_data) == 11 + + # sub1 should get latest value only (unwrapped) + last_from_sub1 = sub1.received_data[-1]["data"] + assert last_from_sub1.ndim == 0 # Scalar (unwrapped) + assert last_from_sub1.value == 11 + + # sub2 should get all history after it was registered + # When sub2 registered, buffer switched from SingleValueBuffer to + # TemporalBuffer, copying the first data point, then receiving 10 more + # updates = 11 total + last_from_sub2 = sub2.received_data[-1]["data"] + assert last_from_sub2.sizes == {'time': 11} + + def test_multiple_keys_with_different_extractors(self): + """Test subscriber with different extractors per key.""" + import scipp as sc + + from ess.livedata.dashboard.data_service import DataService + from ess.livedata.dashboard.extractors import ( + FullHistoryExtractor, + LatestValueExtractor, + ) + + class MultiKeySubscriber(DataServiceSubscriber[str]): + def __init__(self): + self.received_data: list[dict] = [] + super().__init__() + + @property + def extractors(self) -> dict: + return { + "latest": LatestValueExtractor(), + "history": FullHistoryExtractor(), + } + + def trigger(self, data: dict) -> None: + self.received_data.append(data) + + service = DataService() + subscriber = MultiKeySubscriber() + service.register_subscriber(subscriber) - assert service["input"] == 5 - assert service["immediate"] == 10 - assert service["transactional1"] == 15 - assert service["transactional2"] == 20 + # Add data to both keys + for i in range(5): + service["latest"] = sc.DataArray( + sc.scalar(i * 10, unit='counts'), + coords={'time': sc.scalar(float(i), unit='s')}, + ) + service["history"] = sc.DataArray( + sc.scalar(i * 100, unit='counts'), + coords={'time': sc.scalar(float(i), unit='s')}, + ) + + # Should have received updates (batched in transaction would be less, + # but here each setitem triggers separately) + assert len(subscriber.received_data) > 0 + + # Check last received data + last_data = subscriber.received_data[-1] + + # "latest" should be unwrapped scalar + if "latest" in last_data: + assert last_data["latest"].ndim == 0 + + # "history" should return all accumulated values with time dimension + if "history" in last_data: + assert "time" in last_data["history"].dims diff --git a/tests/dashboard/data_subscriber_test.py 
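# The subscriber boilerplate repeated throughout the tests above, condensed into
# one standalone sketch for reference. DataServiceSubscriber is assumed to be
# importable next to DataService, as these tests use it; values are scipp
# DataArrays carrying a time coord.
import scipp as sc

from ess.livedata.dashboard.data_service import DataService, DataServiceSubscriber
from ess.livedata.dashboard.extractors import LatestValueExtractor


class DoublingSubscriber(DataServiceSubscriber[str]):
    """Writes 2 * 'input' back into the service whenever 'input' changes."""

    def __init__(self, service: DataService) -> None:
        self._service = service
        self._extractors = {'input': LatestValueExtractor()}
        super().__init__()  # must come last: it caches keys from `extractors`

    @property
    def extractors(self):
        return self._extractors

    def trigger(self, store: dict) -> None:
        if 'input' in store:
            # LatestValueExtractor delivers the latest (0-D) value for 'input'.
            self._service['double'] = store['input'] * 2


service = DataService()
service.register_subscriber(DoublingSubscriber(service))
service['input'] = sc.DataArray(
    sc.scalar(4, unit='counts'), coords={'time': sc.scalar(0.0, unit='s')}
)
assert service['double'].value == 8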
b/tests/dashboard/data_subscriber_test.py index 675b3f32a..22b4c0284 100644 --- a/tests/dashboard/data_subscriber_test.py +++ b/tests/dashboard/data_subscriber_test.py @@ -14,6 +14,7 @@ Pipe, StreamAssembler, ) +from ess.livedata.dashboard.extractors import LatestValueExtractor class FakeStreamAssembler(StreamAssembler[str]): @@ -32,7 +33,8 @@ def assemble(self, data: dict[str, Any]) -> Any: class FakePipe(Pipe): """Fake implementation of Pipe for testing.""" - def __init__(self) -> None: + def __init__(self, data: Any = None) -> None: + self.init_data = data self.send_calls: list[Any] = [] def send(self, data: Any) -> None: @@ -57,25 +59,51 @@ def fake_pipe() -> FakePipe: return FakePipe() +@pytest.fixture +def fake_pipe_factory(): + """Fake pipe factory for testing.""" + + def factory(data: Any) -> FakePipe: + """Factory that creates a new FakePipe with the given data.""" + return FakePipe(data) + + return factory + + +@pytest.fixture +def sample_extractors(sample_keys: set[str]) -> dict[str, LatestValueExtractor]: + """Sample extractors for testing.""" + return {key: LatestValueExtractor() for key in sample_keys} + + @pytest.fixture def subscriber( - fake_assembler: FakeStreamAssembler, fake_pipe: FakePipe + fake_assembler: FakeStreamAssembler, + fake_pipe_factory, + sample_extractors: dict[str, LatestValueExtractor], ) -> DataSubscriber[str]: """DataSubscriber instance for testing.""" - return DataSubscriber(fake_assembler, fake_pipe) + return DataSubscriber(fake_assembler, fake_pipe_factory, sample_extractors) class TestDataSubscriber: """Test cases for DataSubscriber class.""" - def test_init_stores_assembler_and_pipe( - self, fake_assembler: FakeStreamAssembler, fake_pipe: FakePipe + def test_init_stores_assembler_and_pipe_factory( + self, + fake_assembler: FakeStreamAssembler, + fake_pipe_factory, + sample_extractors: dict[str, LatestValueExtractor], ) -> None: - """Test that initialization stores the assembler and pipe correctly.""" - subscriber = DataSubscriber(fake_assembler, fake_pipe) + """Test that initialization stores the assembler and pipe factory correctly.""" + subscriber = DataSubscriber( + fake_assembler, fake_pipe_factory, sample_extractors + ) assert subscriber._assembler is fake_assembler - assert subscriber._pipe is fake_pipe + assert subscriber._pipe_factory is fake_pipe_factory + assert subscriber._pipe is None # Pipe not yet created + assert subscriber._extractors is sample_extractors def test_keys_returns_assembler_keys( self, subscriber: DataSubscriber, sample_keys: set[str] @@ -83,11 +111,27 @@ def test_keys_returns_assembler_keys( """Test that keys property returns the assembler's keys.""" assert subscriber.keys == sample_keys + def test_pipe_created_on_first_trigger( + self, + subscriber: DataSubscriber, + ) -> None: + """Test that pipe is created on first trigger.""" + # Before trigger, accessing pipe raises error + with pytest.raises(RuntimeError, match="not yet initialized"): + _ = subscriber.pipe + + # Trigger subscriber + subscriber.trigger({'key1': 'value1'}) + + # After trigger, pipe is accessible and has correct data + pipe = subscriber.pipe + assert isinstance(pipe, FakePipe) + assert pipe.init_data == 'assembled_data' + def test_trigger_with_complete_data( self, subscriber: DataSubscriber, fake_assembler: FakeStreamAssembler, - fake_pipe: FakePipe, ) -> None: """Test trigger method when all required keys are present in store.""" store = { @@ -104,15 +148,15 @@ def test_trigger_with_complete_data( expected_data = {'key1': 'value1', 'key2': 
'value2', 'key3': 'value3'} assert fake_assembler.assemble_calls[0] == expected_data - # Verify pipe was called with assembled data - assert len(fake_pipe.send_calls) == 1 - assert fake_pipe.send_calls[0] == 'assembled_data' + # Verify pipe was created with assembled data (first trigger) + pipe = subscriber.pipe + assert pipe.init_data == 'assembled_data' + assert len(pipe.send_calls) == 0 # First trigger creates, doesn't send def test_trigger_with_partial_data( self, subscriber: DataSubscriber, fake_assembler: FakeStreamAssembler, - fake_pipe: FakePipe, ) -> None: """Test trigger method when only some required keys are present in store.""" store = {'key1': 'value1', 'key3': 'value3', 'unrelated_key': 'unrelated_value'} @@ -124,15 +168,15 @@ def test_trigger_with_partial_data( expected_data = {'key1': 'value1', 'key3': 'value3'} assert fake_assembler.assemble_calls[0] == expected_data - # Verify pipe was called - assert len(fake_pipe.send_calls) == 1 - assert fake_pipe.send_calls[0] == 'assembled_data' + # Verify pipe was created with assembled data + pipe = subscriber.pipe + assert pipe.init_data == 'assembled_data' + assert len(pipe.send_calls) == 0 def test_trigger_with_empty_store( self, subscriber: DataSubscriber, fake_assembler: FakeStreamAssembler, - fake_pipe: FakePipe, ) -> None: """Test trigger method with an empty store.""" store: dict[str, Any] = {} @@ -143,15 +187,15 @@ def test_trigger_with_empty_store( assert len(fake_assembler.assemble_calls) == 1 assert fake_assembler.assemble_calls[0] == {} - # Verify pipe was called - assert len(fake_pipe.send_calls) == 1 - assert fake_pipe.send_calls[0] == 'assembled_data' + # Verify pipe was created with assembled data + pipe = subscriber.pipe + assert pipe.init_data == 'assembled_data' + assert len(pipe.send_calls) == 0 def test_trigger_with_no_matching_keys( self, subscriber: DataSubscriber, fake_assembler: FakeStreamAssembler, - fake_pipe: FakePipe, ) -> None: """Test trigger method when store contains no matching keys.""" store = {'other_key1': 'value1', 'other_key2': 'value2'} @@ -162,15 +206,15 @@ def test_trigger_with_no_matching_keys( assert len(fake_assembler.assemble_calls) == 1 assert fake_assembler.assemble_calls[0] == {} - # Verify pipe was called - assert len(fake_pipe.send_calls) == 1 - assert fake_pipe.send_calls[0] == 'assembled_data' + # Verify pipe was created with assembled data + pipe = subscriber.pipe + assert pipe.init_data == 'assembled_data' + assert len(pipe.send_calls) == 0 def test_trigger_multiple_calls( self, subscriber: DataSubscriber, fake_assembler: FakeStreamAssembler, - fake_pipe: FakePipe, ) -> None: """Test multiple calls to trigger method.""" store1 = {'key1': 'value1', 'key2': 'value2'} @@ -187,23 +231,29 @@ def test_trigger_multiple_calls( 'key3': 'value3', } - assert len(fake_pipe.send_calls) == 2 - assert all(call == 'assembled_data' for call in fake_pipe.send_calls) + # First call creates pipe, second call sends + pipe = subscriber.pipe + assert pipe.init_data == 'assembled_data' + assert len(pipe.send_calls) == 1 + assert pipe.send_calls[0] == 'assembled_data' def test_trigger_with_different_assembled_data(self, sample_keys: set[str]) -> None: """Test trigger method with assembler that returns different data types.""" assembled_values = [42, {'result': 'success'}, [1, 2, 3], None] + extractors = {key: LatestValueExtractor() for key in sample_keys} for value in assembled_values: assembler = FakeStreamAssembler(sample_keys, value) - pipe = FakePipe() - subscriber = DataSubscriber(assembler, 
pipe) + pipe_factory = lambda data: FakePipe(data) # noqa: E731 + subscriber = DataSubscriber(assembler, pipe_factory, extractors) store = {'key1': 'test_value'} subscriber.trigger(store) - assert len(pipe.send_calls) == 1 - assert pipe.send_calls[0] == value + # First trigger creates pipe with data + pipe = subscriber.pipe + assert pipe.init_data == value + assert len(pipe.send_calls) == 0 class TestMergingStreamAssembler: diff --git a/tests/dashboard/extractors_test.py b/tests/dashboard/extractors_test.py new file mode 100644 index 000000000..0f3a90b16 --- /dev/null +++ b/tests/dashboard/extractors_test.py @@ -0,0 +1,588 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +import pytest +import scipp as sc + +from ess.livedata.dashboard.extractors import ( + FullHistoryExtractor, + LatestValueExtractor, + UpdateExtractor, + WindowAggregatingExtractor, +) +from ess.livedata.dashboard.plot_params import WindowAggregation + + +class TestLatestValueExtractor: + """Tests for LatestValueExtractor.""" + + def test_get_required_timespan_returns_zero(self): + """Latest value extractor requires zero history.""" + extractor = LatestValueExtractor() + assert extractor.get_required_timespan() == 0.0 + + def test_extract_latest_value_from_concatenated_data(self): + """Extract the latest value from data with concat dimension.""" + extractor = LatestValueExtractor(concat_dim='time') + + # Create data with time dimension + data = sc.DataArray( + sc.array( + dims=['time', 'x'], values=[[1, 2], [3, 4], [5, 6]], unit='counts' + ), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0, 2.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Should extract only the last time slice + assert 'time' not in result.dims + assert sc.identical(result, data['time', -1]) + + def test_extract_from_data_without_concat_dimension(self): + """Extract from data that doesn't have concat dimension (single frame).""" + extractor = LatestValueExtractor(concat_dim='time') + + # Create data without time dimension + data = sc.DataArray( + sc.array(dims=['x'], values=[1, 2, 3], unit='counts'), + coords={'x': sc.arange('x', 3, unit='m')}, + ) + + result = extractor.extract(data) + + # Should return data as-is + assert sc.identical(result, data) + + def test_extract_with_custom_concat_dim(self): + """Test extraction with custom concat dimension name.""" + extractor = LatestValueExtractor(concat_dim='event') + + data = sc.DataArray( + sc.array(dims=['event', 'x'], values=[[1, 2], [3, 4]], unit='counts'), + coords={'event': sc.arange('event', 2)}, + ) + + result = extractor.extract(data) + + assert 'event' not in result.dims + assert sc.identical(result, data['event', -1]) + + def test_extract_scalar_data(self): + """Extract from scalar data.""" + extractor = LatestValueExtractor() + data = sc.scalar(42.0, unit='counts') + + result = extractor.extract(data) + + assert sc.identical(result, data) + + +class TestFullHistoryExtractor: + """Tests for FullHistoryExtractor.""" + + def test_get_required_timespan_returns_infinity(self): + """Full history extractor requires infinite timespan.""" + extractor = FullHistoryExtractor() + assert extractor.get_required_timespan() == float('inf') + + def test_extract_returns_all_data(self): + """Extract returns complete buffer history.""" + extractor = FullHistoryExtractor() + + data = sc.DataArray( + sc.array( + dims=['time', 'x'], values=[[1, 2], 
[3, 4], [5, 6]], unit='counts' + ), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0, 2.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Should return all data unchanged + assert sc.identical(result, data) + + def test_extract_with_multidimensional_data(self): + """Extract with complex multidimensional data.""" + extractor = FullHistoryExtractor() + + data = sc.DataArray( + sc.array( + dims=['time', 'y', 'x'], + values=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + unit='counts', + ), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s'), + 'y': sc.arange('y', 2), + 'x': sc.arange('x', 2), + }, + ) + + result = extractor.extract(data) + + assert sc.identical(result, data) + + +class TestWindowAggregatingExtractor: + """Tests for WindowAggregatingExtractor.""" + + def test_get_required_timespan(self): + """Test that get_required_timespan returns the window duration.""" + extractor = WindowAggregatingExtractor(window_duration_seconds=5.0) + assert extractor.get_required_timespan() == 5.0 + + def test_extract_window_and_aggregate_with_nansum(self): + """Extract a window of data and aggregate using nansum.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=1.5, + aggregation=WindowAggregation.nansum, + concat_dim='time', + ) + + # Create data spanning 3 seconds + data = sc.DataArray( + sc.array( + dims=['time', 'x'], values=[[1, 2], [3, 4], [5, 6]], unit='counts' + ), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0, 2.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window is [2.0 - 1.5, 2.0] = [0.5, 2.0], should include times 1.0 and 2.0 + assert 'time' not in result.dims + # nansum of [3, 4] and [5, 6] = [8, 10] + expected = sc.DataArray( + sc.array(dims=['x'], values=[8, 10], unit='counts'), + coords={'x': sc.arange('x', 2, unit='m')}, + ) + assert sc.allclose(result.data, expected.data) + + def test_extract_window_and_aggregate_with_nanmean(self): + """Extract a window of data and aggregate using nanmean.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=1.5, + aggregation=WindowAggregation.nanmean, + concat_dim='time', + ) + + data = sc.DataArray( + sc.array( + dims=['time', 'x'], values=[[1, 2], [3, 4], [5, 6]], unit='counts' + ), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0, 2.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window includes times 1.0 and 2.0 + # nanmean of [3, 4] and [5, 6] = [4, 5] + assert sc.allclose( + result.data, sc.array(dims=['x'], values=[4.0, 5.0], unit='counts') + ) + + def test_extract_window_and_aggregate_with_sum(self): + """Extract a window of data and aggregate using sum.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=2.0, + aggregation=WindowAggregation.sum, + concat_dim='time', + ) + + data = sc.DataArray( + sc.array(dims=['time', 'x'], values=[[1, 2], [3, 4]], unit='counts'), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window includes all data + # sum of [1, 2] and [3, 4] = [4, 6] + assert sc.allclose( + result.data, sc.array(dims=['x'], values=[4, 6], unit='counts') + ) + + def test_extract_window_and_aggregate_with_mean(self): + """Extract a window of data and aggregate using mean.""" + extractor = WindowAggregatingExtractor( + 
window_duration_seconds=2.0, + aggregation=WindowAggregation.mean, + concat_dim='time', + ) + + data = sc.DataArray( + sc.array(dims=['time', 'x'], values=[[2, 4], [4, 6]], unit='m'), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # mean of [2, 4] and [4, 6] = [3, 5] + assert sc.allclose( + result.data, sc.array(dims=['x'], values=[3.0, 5.0], unit='m') + ) + + def test_auto_aggregation_with_counts_uses_nansum(self): + """Test that auto aggregation uses nansum for counts unit.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=2.0, + aggregation=WindowAggregation.auto, + concat_dim='time', + ) + + data = sc.DataArray( + sc.array(dims=['time', 'x'], values=[[1, 2], [3, 4]], unit='counts'), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Should use nansum for counts: [1, 2] + [3, 4] = [4, 6] + assert sc.allclose( + result.data, sc.array(dims=['x'], values=[4, 6], unit='counts') + ) + + def test_auto_aggregation_with_non_counts_uses_nanmean(self): + """Test that auto aggregation uses nanmean for non-counts unit.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=2.0, + aggregation=WindowAggregation.auto, + concat_dim='time', + ) + + data = sc.DataArray( + sc.array(dims=['time', 'x'], values=[[2, 4], [4, 6]], unit='m'), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Should use nanmean for non-counts: mean([2, 4], [4, 6]) = [3, 5] + assert sc.allclose( + result.data, sc.array(dims=['x'], values=[3.0, 5.0], unit='m') + ) + + def test_extract_is_consistent_across_calls(self): + """Test that extraction produces consistent results across multiple calls.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=1.0, + aggregation=WindowAggregation.nansum, + ) + + data = sc.DataArray( + sc.array(dims=['time', 'x'], values=[[1, 2]], unit='counts'), + coords={'time': sc.array(dims=['time'], values=[0.0], unit='s')}, + ) + + # Extract twice and verify results are identical + result1 = extractor.extract(data) + result2 = extractor.extract(data) + + assert sc.identical(result1, result2) + + def test_extract_with_different_time_units(self): + """Test extraction with time in milliseconds.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=1.5, # 1.5 seconds = 1500 ms + aggregation=WindowAggregation.nansum, + concat_dim='time', + ) + + # Create data with time in milliseconds + data = sc.DataArray( + sc.array( + dims=['time', 'x'], values=[[1, 2], [3, 4], [5, 6]], unit='counts' + ), + coords={ + 'time': sc.array(dims=['time'], values=[0, 1000, 2000], unit='ms'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window is [2000 - 1500, 2000] = [500, 2000] ms + # Should include times 1000 and 2000 + assert sc.allclose( + result.data, sc.array(dims=['x'], values=[8, 10], unit='counts') + ) + + def test_extract_with_custom_concat_dim(self): + """Test extraction with custom concat dimension name.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=0.5, + aggregation=WindowAggregation.nansum, + concat_dim='event', + ) + + data = sc.DataArray( + sc.array(dims=['event', 'x'], values=[[1, 2], [3, 4]], unit='counts'), + coords={ + 'event': 
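# The two `auto` tests above pin down how WindowAggregation.auto picks a concrete
# reducer. A sketch of that rule as inferred from the expected values; the
# extractor's actual implementation is not part of this diff and may differ:
import scipp as sc


def resolve_auto_aggregation(data: sc.DataArray):
    # Counts are additive across frames, so summing the window is meaningful;
    # for any other unit, averaging over the window is the sensible default.
    return sc.nansum if data.unit == 'counts' else sc.nanmean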
sc.array(dims=['event'], values=[0.0, 0.3], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window is [0.3 - 0.5, 0.3] but bounded by data, includes all + assert sc.allclose( + result.data, sc.array(dims=['x'], values=[4, 6], unit='counts') + ) + + def test_all_aggregation_methods_work(self): + """Test that all valid aggregation methods complete without error.""" + data = sc.DataArray( + sc.array(dims=['time', 'x'], values=[[1, 2], [3, 4]], unit='counts'), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + # Test all aggregation methods + for agg in [ + WindowAggregation.sum, + WindowAggregation.nansum, + WindowAggregation.mean, + WindowAggregation.nanmean, + WindowAggregation.auto, + ]: + extractor = WindowAggregatingExtractor( + window_duration_seconds=2.0, aggregation=agg, concat_dim='time' + ) + result = extractor.extract(data) + # Verify extraction succeeds and time dimension is removed + assert 'time' not in result.dims + + def test_extract_narrow_window(self): + """Test extraction with very narrow window (may include only last point).""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=0.1, + aggregation=WindowAggregation.nansum, + concat_dim='time', + ) + + data = sc.DataArray( + sc.array( + dims=['time', 'x'], values=[[1, 2], [3, 4], [5, 6]], unit='counts' + ), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0, 2.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window is [2.0 - 0.1, 2.0] = [1.9, 2.0], should only include time 2.0 + assert sc.allclose( + result.data, sc.array(dims=['x'], values=[5, 6], unit='counts') + ) + + def test_handles_timing_jitter_at_window_start(self): + """Test that timing noise near window boundary doesn't include extra frames.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=5.0, + aggregation=WindowAggregation.nansum, + concat_dim='time', + ) + + # Regular 1 Hz data with timing jitter on first frame + # Conceptually frames at t=[0, 1, 2, 3, 4, 5] but first has noise + data = sc.DataArray( + sc.array( + dims=['time', 'x'], + values=[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], + unit='counts', + ), + coords={ + 'time': sc.array( + dims=['time'], values=[0.0001, 1.0, 2.0, 3.0, 4.0, 5.0], unit='s' + ), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window (5-5, 5] = (0, 5] excludes frame at 0.0001 (using exclusive bound) + # Should include 5 frames [1, 2, 3, 4, 5], not all 6 + expected_sum = sc.array(dims=['x'], values=[35, 40], unit='counts') + assert sc.allclose(result.data, expected_sum) + + def test_handles_timing_jitter_at_window_end(self): + """Test that timing noise on latest frame doesn't affect frame count.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=5.0, + aggregation=WindowAggregation.nansum, + concat_dim='time', + ) + + # Regular 1 Hz data with timing jitter on last frame + data = sc.DataArray( + sc.array( + dims=['time', 'x'], + values=[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], + unit='counts', + ), + coords={ + 'time': sc.array( + dims=['time'], values=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0001], unit='s' + ), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window (5.0001-5, 5.0001] = (0.0001, 5.0001] + # Should include 5 frames [1, 2, 3, 4, 5.0001] + expected_sum = sc.array(dims=['x'], values=[35, 
40], unit='counts') + assert sc.allclose(result.data, expected_sum) + + def test_consistent_frame_count_with_perfect_timing(self): + """Test baseline: perfect timing gives expected frame count.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=5.0, + aggregation=WindowAggregation.nansum, + concat_dim='time', + ) + + # Perfect 1 Hz data at exactly [0, 1, 2, 3, 4, 5] + data = sc.DataArray( + sc.array( + dims=['time', 'x'], + values=[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], + unit='counts', + ), + coords={ + 'time': sc.array( + dims=['time'], values=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], unit='s' + ), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window (0, 5] excludes frame at exactly 0 (exclusive bound) + # Should include 5 frames [1, 2, 3, 4, 5] + expected_sum = sc.array(dims=['x'], values=[35, 40], unit='counts') + assert sc.allclose(result.data, expected_sum) + + def test_extract_with_datetime64_time_coordinate(self): + """Test extraction when time coordinate is datetime64 instead of float.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=2.0, + aggregation=WindowAggregation.nansum, + concat_dim='time', + ) + + # Create data with datetime64 time coordinate + import numpy as np + + base_time = np.datetime64('2025-01-15T10:00:00', 'ns') + times = base_time + np.array([0, 1, 2, 3], dtype='timedelta64[s]') + + data = sc.DataArray( + sc.array( + dims=['time', 'x'], + values=[[1, 2], [3, 4], [5, 6], [7, 8]], + unit='counts', + ), + coords={ + 'time': sc.array(dims=['time'], values=times, unit='ns'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + result = extractor.extract(data) + + # Window cutoff is latest - 2s + 0.5*median_interval = 3s - 2s + 0.5s = 1.5s + # Should include times >= 1.5s: times at 2s and 3s + # nansum of [5, 6] and [7, 8] = [12, 14] + expected_sum = sc.array(dims=['x'], values=[12, 14], unit='counts') + assert sc.allclose(result.data, expected_sum) + + def test_extract_with_datetime64_single_frame(self): + """Test extraction with datetime64 time coordinate and single frame.""" + extractor = WindowAggregatingExtractor( + window_duration_seconds=1.0, + aggregation=WindowAggregation.nansum, + concat_dim='time', + ) + + # Single frame with datetime64 + import numpy as np + + data = sc.DataArray( + sc.array(dims=['time', 'x'], values=[[5, 6]], unit='counts'), + coords={ + 'time': sc.array( + dims=['time'], + values=[np.datetime64('2025-01-15T10:00:00', 'ns')], + unit='ns', + ), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + + # This exercises the single-frame path: + # cutoff_time = latest_time - self._duration + result = extractor.extract(data) + + # Should return the single frame + expected = sc.array(dims=['x'], values=[5, 6], unit='counts') + assert sc.allclose(result.data, expected) + + +class TestUpdateExtractorInterface: + """Tests for UpdateExtractor abstract interface.""" + + def test_update_extractor_is_abstract(self): + """Test that UpdateExtractor cannot be instantiated directly.""" + with pytest.raises(TypeError): + UpdateExtractor() # type: ignore[abstract] + + def test_concrete_extractors_implement_interface(self): + """Test that all concrete extractors implement the UpdateExtractor interface.""" + extractors = [ + LatestValueExtractor(), + FullHistoryExtractor(), + WindowAggregatingExtractor(window_duration_seconds=1.0), + ] + + for extractor in extractors: + assert isinstance(extractor, UpdateExtractor) + # Check that required methods are implemented + assert 
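# The datetime64 tests above compare a nanosecond datetime axis against a window
# given in seconds. A standalone scipp sketch of that conversion; the variable
# names are illustrative and not taken from the extractor code:
import scipp as sc

times = sc.datetimes(
    dims=['time'],
    values=['2025-01-15T10:00:00', '2025-01-15T10:00:03'],
    unit='ns',
)
# Subtracting the first timestamp yields integer nanosecond offsets, which can be
# converted to float seconds for comparison with the window duration.
elapsed = (times - times['time', 0]).to(unit='s', dtype='float64')
assert elapsed['time', -1].value == 3.0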
hasattr(extractor, 'extract') + assert hasattr(extractor, 'get_required_timespan') + assert callable(extractor.extract) + assert callable(extractor.get_required_timespan) diff --git a/tests/dashboard/plot_params_test.py b/tests/dashboard/plot_params_test.py new file mode 100644 index 000000000..091aba2ae --- /dev/null +++ b/tests/dashboard/plot_params_test.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from unittest.mock import Mock + +import scipp as sc + +from ess.livedata.dashboard.extractors import ( + FullHistoryExtractor, + LatestValueExtractor, + WindowAggregatingExtractor, +) +from ess.livedata.dashboard.plot_params import ( + WindowAggregation, + WindowMode, + WindowParams, + create_extractors_from_params, +) + + +class TestCreateExtractorsFromParams: + """Tests for create_extractors_from_params factory function.""" + + def test_fallback_to_latest_value_when_no_params(self): + """Test fallback to LatestValueExtractor when no window params provided.""" + keys = ['key1', 'key2'] + + extractors = create_extractors_from_params(keys=keys, window=None, spec=None) + + assert len(extractors) == 2 + assert all(isinstance(ext, LatestValueExtractor) for ext in extractors.values()) + assert set(extractors.keys()) == {'key1', 'key2'} + + def test_create_latest_value_extractors_with_window_mode_latest(self): + """Test creation of LatestValueExtractor when window mode is 'latest'.""" + keys = ['key1'] + window = WindowParams(mode=WindowMode.latest) + + extractors = create_extractors_from_params(keys=keys, window=window, spec=None) + + assert len(extractors) == 1 + assert isinstance(extractors['key1'], LatestValueExtractor) + + def test_create_window_aggregating_extractors_with_window_mode_window(self): + """Test creation of WindowAggregatingExtractor when window mode is 'window'.""" + keys = ['key1', 'key2'] + window = WindowParams( + mode=WindowMode.window, + window_duration_seconds=5.0, + aggregation=WindowAggregation.nansum, + ) + + extractors = create_extractors_from_params(keys=keys, window=window, spec=None) + + assert len(extractors) == 2 + assert all( + isinstance(ext, WindowAggregatingExtractor) for ext in extractors.values() + ) + + # Verify behavior through public interface + extractor = extractors['key1'] + assert extractor.get_required_timespan() == 5.0 + + def test_spec_required_extractor_overrides_window_params(self): + """Test that plotter spec's required extractor overrides window params.""" + keys = ['key1', 'key2'] + window = WindowParams(mode=WindowMode.latest) + + # Create mock spec with required extractor + spec = Mock() + spec.data_requirements.required_extractor = FullHistoryExtractor + + extractors = create_extractors_from_params(keys=keys, window=window, spec=spec) + + # Should use FullHistoryExtractor despite window params + assert len(extractors) == 2 + assert all(isinstance(ext, FullHistoryExtractor) for ext in extractors.values()) + + def test_spec_with_no_required_extractor_uses_window_params(self): + """Test that window params are used when spec has no required extractor.""" + keys = ['key1'] + window = WindowParams(mode=WindowMode.window, window_duration_seconds=3.0) + + # Create mock spec without required extractor + spec = Mock() + spec.data_requirements.required_extractor = None + + extractors = create_extractors_from_params(keys=keys, window=window, spec=spec) + + assert isinstance(extractors['key1'], WindowAggregatingExtractor) + assert 
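# The interface tests above treat UpdateExtractor as the extension point: an
# extractor declares how much history it needs and how to turn buffered data into
# a view. A hypothetical third-party extractor, sketched assuming extract() and
# get_required_timespan() are the only abstract methods (which is all those tests
# check):
import scipp as sc

from ess.livedata.dashboard.extractors import UpdateExtractor


class MaxOverHistoryExtractor(UpdateExtractor):
    """Elementwise maximum over the full buffered history (illustrative only)."""

    def get_required_timespan(self) -> float:
        # Like FullHistoryExtractor, this needs everything the buffer retains.
        return float('inf')

    def extract(self, data: sc.DataArray) -> sc.DataArray:
        # Buffered data carries a 'time' dimension; a single frame may not.
        return data.max('time') if 'time' in data.dims else data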
extractors['key1'].get_required_timespan() == 3.0 + + def test_creates_extractors_for_all_keys(self): + """Test that extractors are created for all provided keys.""" + keys = ['result1', 'result2', 'result3'] + window = WindowParams(mode=WindowMode.latest) + + extractors = create_extractors_from_params(keys=keys, window=window, spec=None) + + assert len(extractors) == 3 + assert set(extractors.keys()) == {'result1', 'result2', 'result3'} + assert all(isinstance(ext, LatestValueExtractor) for ext in extractors.values()) + + def test_empty_keys_returns_empty_dict(self): + """Test that empty keys list returns empty extractors dict.""" + keys = [] + window = WindowParams(mode=WindowMode.latest) + + extractors = create_extractors_from_params(keys=keys, window=window, spec=None) + + assert extractors == {} + + def test_window_aggregation_parameters_passed_correctly(self): + """Test that window aggregation parameters result in correct behavior.""" + keys = ['key1'] + window = WindowParams( + mode=WindowMode.window, + window_duration_seconds=10.5, + aggregation=WindowAggregation.mean, + ) + + extractors = create_extractors_from_params(keys=keys, window=window, spec=None) + + extractor = extractors['key1'] + assert isinstance(extractor, WindowAggregatingExtractor) + # Verify timespan through public interface + assert extractor.get_required_timespan() == 10.5 + + # Verify aggregation behavior by extracting data + data = sc.DataArray( + sc.array(dims=['time', 'x'], values=[[2, 4], [4, 6]], unit='m'), + coords={ + 'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s'), + 'x': sc.arange('x', 2, unit='m'), + }, + ) + result = extractor.extract(data) + # Mean of [2, 4] and [4, 6] = [3, 5], verifying mean aggregation was used + assert sc.allclose( + result.data, sc.array(dims=['x'], values=[3.0, 5.0], unit='m') + ) diff --git a/tests/dashboard/roi_detector_plot_factory_test.py b/tests/dashboard/roi_detector_plot_factory_test.py index b73654d4f..cb2be977d 100644 --- a/tests/dashboard/roi_detector_plot_factory_test.py +++ b/tests/dashboard/roi_detector_plot_factory_test.py @@ -156,7 +156,6 @@ def test_create_roi_detector_plot_components_returns_detector_and_spectrum( detector_with_boxes, roi_spectrum, plot_state = ( roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) @@ -192,7 +191,6 @@ def test_create_roi_detector_plot_components_with_only_detector( detector_with_boxes, roi_spectrum, plot_state = ( roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) @@ -235,7 +233,6 @@ def test_create_roi_detector_plot_components_returns_valid_components( detector_dmap, roi_dmap, plot_state = ( roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) @@ -276,7 +273,6 @@ def test_roi_detector_plot_publishes_roi_on_box_edit( _detector_dmap, _roi_dmap, plot_state = ( roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) @@ -326,7 +322,6 @@ def test_roi_detector_plot_only_publishes_changed_rois( _detector_dmap, _roi_dmap, plot_state = ( roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) @@ -371,7 +366,6 @@ def test_roi_detector_plot_without_publisher_does_not_crash( detector_with_boxes, roi_spectrum, plot_state = ( 
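# Usage sketch for the factory exercised by the plot_params tests above; it is
# standalone, uses only names appearing in those tests, and spec=None means no
# plotter-imposed extractor requirement:
from ess.livedata.dashboard.extractors import WindowAggregatingExtractor
from ess.livedata.dashboard.plot_params import (
    WindowAggregation,
    WindowMode,
    WindowParams,
    create_extractors_from_params,
)

window = WindowParams(
    mode=WindowMode.window,
    window_duration_seconds=5.0,
    aggregation=WindowAggregation.nansum,
)
extractors = create_extractors_from_params(
    keys=['det1', 'det2'], window=window, spec=None
)
# One WindowAggregatingExtractor per key, each requiring 5 s of buffered history.
assert all(isinstance(e, WindowAggregatingExtractor) for e in extractors.values())
assert all(e.get_required_timespan() == 5.0 for e in extractors.values())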
roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) @@ -554,6 +548,9 @@ def test_create_roi_plot_with_initial_rois( 1: RectangleROI(x=Interval(min=7.0, max=9.0), y=Interval(min=4.0, max=6.0)), } + # Populate DataService with detector data + data_service[detector_key] = detector_data + # Inject ROI readback data into DataService - this simulates backend publishing ROIs roi_readback_key = detector_key.model_copy(update={"output_name": "roi_rectangle"}) roi_readback_data = ROI.to_concatenated_data_array(initial_rois) @@ -564,7 +561,6 @@ def test_create_roi_plot_with_initial_rois( _detector_with_boxes, _roi_dmap, plot_state = ( roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) @@ -581,7 +577,9 @@ def test_create_roi_plot_with_initial_rois( assert box_data['y1'][0] == 5.0 -def test_custom_max_roi_count(roi_plot_factory, detector_data, workflow_id, job_number): +def test_custom_max_roi_count( + roi_plot_factory, data_service, detector_data, workflow_id, job_number +): """Test that max_roi_count parameter is correctly applied to BoxEdit.""" detector_key = ResultKey( workflow_id=workflow_id, @@ -589,6 +587,9 @@ def test_custom_max_roi_count(roi_plot_factory, detector_data, workflow_id, job_ output_name='current', ) + # Populate DataService with detector data + data_service[detector_key] = detector_data + # Create params with custom max_roi_count params = PlotParamsROIDetector(plot_scale=PlotScaleParams2d()) params.roi_options.max_roi_count = 5 @@ -596,7 +597,6 @@ def test_custom_max_roi_count(roi_plot_factory, detector_data, workflow_id, job_ _detector_with_boxes, _roi_dmap, plot_state = ( roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) @@ -607,7 +607,7 @@ def test_custom_max_roi_count(roi_plot_factory, detector_data, workflow_id, job_ def test_stale_readback_filtering( - roi_plot_factory, detector_data, workflow_id, job_number + roi_plot_factory, data_service, detector_data, workflow_id, job_number ): """Test that backend is the source of truth for ROI state (state-based sync).""" from ess.livedata.dashboard.roi_publisher import FakeROIPublisher @@ -622,12 +622,14 @@ def test_stale_readback_filtering( output_name='current', ) + # Populate DataService with detector data + data_service[detector_key] = detector_data + params = PlotParamsROIDetector(plot_scale=PlotScaleParams2d()) _detector_with_boxes, _roi_dmap, plot_state = ( roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) @@ -688,7 +690,7 @@ def test_stale_readback_filtering( def test_backend_update_from_another_view( - roi_plot_factory, detector_data, workflow_id, job_number + roi_plot_factory, data_service, detector_data, workflow_id, job_number ): """Test that backend updates from other clients are applied correctly.""" from ess.livedata.dashboard.roi_publisher import FakeROIPublisher @@ -703,12 +705,14 @@ def test_backend_update_from_another_view( output_name='current', ) + # Populate DataService with detector data + data_service[detector_key] = detector_data + params = PlotParamsROIDetector(plot_scale=PlotScaleParams2d()) _detector_with_boxes, _roi_dmap, plot_state = ( roi_plot_factory.create_roi_detector_plot_components( detector_key=detector_key, - detector_data=detector_data, params=params, ) ) diff --git 
a/tests/dashboard/stream_manager_test.py b/tests/dashboard/stream_manager_test.py index 6e3e736f4..8ae7037a3 100644 --- a/tests/dashboard/stream_manager_test.py +++ b/tests/dashboard/stream_manager_test.py @@ -7,6 +7,7 @@ import pytest import scipp as sc +from scipp.testing import assert_identical from ess.livedata.config.workflow_spec import JobId, ResultKey, WorkflowId from ess.livedata.dashboard.data_service import DataService @@ -14,6 +15,7 @@ Pipe, StreamAssembler, ) +from ess.livedata.dashboard.extractors import LatestValueExtractor from ess.livedata.dashboard.stream_manager import StreamManager @@ -76,6 +78,33 @@ def sample_data() -> sc.DataArray: ) +def assert_dict_equal_with_scipp(actual: dict, expected: dict) -> None: + """ + Assert two dictionaries are equal, using scipp.testing for scipp objects. + + Parameters + ---------- + actual: + The actual dictionary received. + expected: + The expected dictionary. + """ + assert ( + actual.keys() == expected.keys() + ), f"Keys differ: {actual.keys()} != {expected.keys()}" + for key in expected: + actual_value = actual[key] + expected_value = expected[key] + # Check if this is a scipp object by checking the module + if type(expected_value).__module__.startswith('scipp'): + # Use scipp.testing.assert_identical for scipp DataArrays and Datasets + assert_identical(actual_value, expected_value) + else: + assert ( + actual_value == expected_value + ), f"Value for key {key} differs: {actual_value} != {expected_value}" + + class TestStreamManager: """Test cases for base StreamManager class.""" @@ -95,8 +124,9 @@ def test_make_merging_stream_creates_pipe_and_registers_subscriber( job_id=JobId(source_name="source1", job_number=uuid.uuid4()), ) } + extractors = {key: LatestValueExtractor() for key in keys} - pipe = manager.make_merging_stream(keys) + pipe = manager.make_merging_stream(extractors) assert isinstance(pipe, FakePipe) assert fake_pipe_factory.call_count == 1 @@ -122,7 +152,8 @@ def test_partial_data_updates(self, data_service, fake_pipe_factory, sample_data ) keys = {key1, key2} - pipe = manager.make_merging_stream(keys) + extractors = {key: LatestValueExtractor() for key in keys} + pipe = manager.make_merging_stream(extractors) # Send data for only one key data_service[key1] = sample_data @@ -131,7 +162,7 @@ def test_partial_data_updates(self, data_service, fake_pipe_factory, sample_data assert len(pipe.send_calls) == 1 assert key1 in pipe.send_calls[0] assert key2 not in pipe.send_calls[0] - assert sc.identical(pipe.send_calls[0][key1], sample_data) + assert_identical(pipe.send_calls[0][key1], sample_data) def test_stream_independence(self, data_service, fake_pipe_factory, sample_data): """Test that multiple streams operate independently.""" @@ -153,8 +184,10 @@ def test_stream_independence(self, data_service, fake_pipe_factory, sample_data) job_id=JobId(source_name="source2", job_number=uuid.uuid4()), ) - pipe1 = manager.make_merging_stream({key1}) - pipe2 = manager.make_merging_stream({key2}) + extractors1 = {key1: LatestValueExtractor()} + extractors2 = {key2: LatestValueExtractor()} + pipe1 = manager.make_merging_stream(extractors1) + pipe2 = manager.make_merging_stream(extractors2) # Send data for key1 data_service[key1] = sample_data @@ -184,15 +217,16 @@ def test_single_source_data_flow( ), job_id=JobId(source_name="source1", job_number=uuid.uuid4()), ) + extractors = {key: LatestValueExtractor()} - pipe = manager.make_merging_stream({key}) + pipe = manager.make_merging_stream(extractors) # Publish data data_service[key] = 
sample_data # Verify data received assert len(pipe.send_calls) == 1 - assert pipe.send_calls[0] == {key: sample_data} + assert_dict_equal_with_scipp(pipe.send_calls[0], {key: sample_data}) def test_multiple_sources_data_flow( self, data_service, fake_pipe_factory, sample_data @@ -216,7 +250,8 @@ def test_multiple_sources_data_flow( ) keys = {key1, key2} - pipe = manager.make_merging_stream(keys) + extractors = {key: LatestValueExtractor() for key in keys} + pipe = manager.make_merging_stream(extractors) # Publish data for both keys sample_data2 = sc.DataArray( @@ -230,9 +265,11 @@ def test_multiple_sources_data_flow( # Should receive data for both keys assert len(pipe.send_calls) == 2 # First call has only key1 - assert pipe.send_calls[0] == {key1: sample_data} + assert_dict_equal_with_scipp(pipe.send_calls[0], {key1: sample_data}) # Second call has both keys - assert pipe.send_calls[1] == {key1: sample_data, key2: sample_data2} + assert_dict_equal_with_scipp( + pipe.send_calls[1], {key1: sample_data, key2: sample_data2} + ) def test_incremental_updates(self, data_service, fake_pipe_factory, sample_data): """Test that incremental updates flow through correctly.""" @@ -246,8 +283,9 @@ def test_incremental_updates(self, data_service, fake_pipe_factory, sample_data) ), job_id=JobId(source_name="source1", job_number=uuid.uuid4()), ) + extractors = {key: LatestValueExtractor()} - pipe = manager.make_merging_stream({key}) + pipe = manager.make_merging_stream(extractors) # Send initial data data_service[key] = sample_data @@ -261,8 +299,8 @@ def test_incremental_updates(self, data_service, fake_pipe_factory, sample_data) # Should receive both updates assert len(pipe.send_calls) == 2 - assert pipe.send_calls[0] == {key: sample_data} - assert pipe.send_calls[1] == {key: updated_data} + assert_dict_equal_with_scipp(pipe.send_calls[0], {key: sample_data}) + assert_dict_equal_with_scipp(pipe.send_calls[1], {key: updated_data}) def test_empty_source_set(self, data_service, fake_pipe_factory): """Test behavior with empty source set.""" @@ -271,7 +309,7 @@ def test_empty_source_set(self, data_service, fake_pipe_factory): ) # Create stream with empty key set - pipe = manager.make_merging_stream(set()) + pipe = manager.make_merging_stream([]) # Publish some data key = ResultKey( @@ -302,8 +340,9 @@ def test_shared_source_triggering( ) # Create two streams that both include the shared key - pipe1 = manager.make_merging_stream({shared_key}) - pipe2 = manager.make_merging_stream({shared_key}) + extractors = {shared_key: LatestValueExtractor()} + pipe1 = manager.make_merging_stream(extractors) + pipe2 = manager.make_merging_stream(extractors) # Publish data to shared key data_service[shared_key] = sample_data @@ -311,8 +350,8 @@ def test_shared_source_triggering( # Both pipes should receive the data assert len(pipe1.send_calls) == 1 assert len(pipe2.send_calls) == 1 - assert pipe1.send_calls[0] == {shared_key: sample_data} - assert pipe2.send_calls[0] == {shared_key: sample_data} + assert_dict_equal_with_scipp(pipe1.send_calls[0], {shared_key: sample_data}) + assert_dict_equal_with_scipp(pipe2.send_calls[0], {shared_key: sample_data}) def test_unrelated_key_filtering( self, data_service, fake_pipe_factory, sample_data @@ -336,8 +375,9 @@ def test_unrelated_key_filtering( ), job_id=JobId(source_name="unrelated_source", job_number=uuid.uuid4()), ) + extractors = {target_key: LatestValueExtractor()} - pipe = manager.make_merging_stream({target_key}) + pipe = manager.make_merging_stream(extractors) # Publish data 
for unrelated key data_service[unrelated_key] = sample_data @@ -350,7 +390,7 @@ def test_unrelated_key_filtering( # Should receive data now assert len(pipe.send_calls) == 1 - assert pipe.send_calls[0] == {target_key: sample_data} + assert_dict_equal_with_scipp(pipe.send_calls[0], {target_key: sample_data}) def test_complex_multi_stream_scenario(self, data_service, fake_pipe_factory): """Test complex scenario with multiple streams and overlapping keys.""" @@ -379,9 +419,12 @@ def test_complex_multi_stream_scenario(self, data_service, fake_pipe_factory): ) # Create streams with overlapping keys - pipe1 = manager.make_merging_stream({key_a, key_b}) # a, b - pipe2 = manager.make_merging_stream({key_b, key_c}) # b, c - pipe3 = manager.make_merging_stream({key_a}) # a only + extractors1 = {key_a: LatestValueExtractor(), key_b: LatestValueExtractor()} + extractors2 = {key_b: LatestValueExtractor(), key_c: LatestValueExtractor()} + extractors3 = {key_a: LatestValueExtractor()} + pipe1 = manager.make_merging_stream(extractors1) # a, b + pipe2 = manager.make_merging_stream(extractors2) # b, c + pipe3 = manager.make_merging_stream(extractors3) # a only # Create sample data data_a = sc.DataArray(data=sc.array(dims=[], values=[1])) @@ -395,26 +438,30 @@ def test_complex_multi_stream_scenario(self, data_service, fake_pipe_factory): # Verify pipe1 (keys a, b) assert len(pipe1.send_calls) == 2 - assert pipe1.send_calls[0] == {key_a: data_a} - assert pipe1.send_calls[1] == {key_a: data_a, key_b: data_b} + assert_dict_equal_with_scipp(pipe1.send_calls[0], {key_a: data_a}) + assert_dict_equal_with_scipp( + pipe1.send_calls[1], {key_a: data_a, key_b: data_b} + ) # Verify pipe2 (keys b, c) assert len(pipe2.send_calls) == 2 - assert pipe2.send_calls[0] == {key_b: data_b} - assert pipe2.send_calls[1] == {key_b: data_b, key_c: data_c} + assert_dict_equal_with_scipp(pipe2.send_calls[0], {key_b: data_b}) + assert_dict_equal_with_scipp( + pipe2.send_calls[1], {key_b: data_b, key_c: data_c} + ) # Verify pipe3 (key a only) assert len(pipe3.send_calls) == 1 - assert pipe3.send_calls[0] == {key_a: data_a} + assert_dict_equal_with_scipp(pipe3.send_calls[0], {key_a: data_a}) -class TestStreamManagerMakeMergingStreamFromKeys: - """Test cases for make_merging_stream_from_keys method.""" +class TestStreamManagerMakeMergingStreamWithoutInitialData: + """Test cases for make_merging_stream when subscribing to keys without data.""" - def test_make_merging_stream_from_keys_initializes_with_empty_dict( + def test_make_merging_stream_initializes_with_empty_dict( self, data_service, fake_pipe_factory ): - """Test that make_merging_stream_from_keys initializes pipe with empty dict.""" + """Test that make_merging_stream initializes pipe with empty dict.""" manager = StreamManager( data_service=data_service, pipe_factory=fake_pipe_factory ) @@ -427,14 +474,14 @@ def test_make_merging_stream_from_keys_initializes_with_empty_dict( ) keys = [key] - pipe = manager.make_merging_stream_from_keys(keys) + pipe = manager.make_merging_stream(keys) # Should initialize with empty dict assert isinstance(pipe, FakePipe) assert pipe.data == {} assert len(data_service._subscribers) == 1 - def test_make_merging_stream_from_keys_receives_data_when_available( + def test_make_merging_stream_receives_data_when_available( self, data_service, fake_pipe_factory, sample_data ): """Test that pipe receives data when it becomes available.""" @@ -451,7 +498,7 @@ def test_make_merging_stream_from_keys_receives_data_when_available( keys = [key] # Create stream first 
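# Condensed usage sketch of the StreamManager API these tests migrate to:
# make_merging_stream now takes a mapping of key -> extractor (the tests also
# pass bare iterables of keys). The string key and the ListPipe class are
# illustrative stand-ins for ResultKey and a real plotting pipe.
import scipp as sc

from ess.livedata.dashboard.data_service import DataService
from ess.livedata.dashboard.extractors import LatestValueExtractor
from ess.livedata.dashboard.stream_manager import StreamManager


class ListPipe:
    """Records what it is created with and what is sent to it."""

    def __init__(self, data):
        self.data = data
        self.send_calls = []

    def send(self, data):
        self.send_calls.append(data)


service = DataService()
manager = StreamManager(data_service=service, pipe_factory=ListPipe)
pipe = manager.make_merging_stream({'det': LatestValueExtractor()})
service['det'] = sc.DataArray(
    sc.scalar(1.0, unit='counts'), coords={'time': sc.scalar(0.0, unit='s')}
)
assert len(pipe.send_calls) == 1  # one update delivered for the subscribed key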
(no data yet) - pipe = manager.make_merging_stream_from_keys(keys) + pipe = manager.make_merging_stream(keys) # Initially empty assert pipe.data == {} @@ -461,9 +508,9 @@ def test_make_merging_stream_from_keys_receives_data_when_available( # Should receive data assert len(pipe.send_calls) == 1 - assert pipe.send_calls[0] == {key: sample_data} + assert_dict_equal_with_scipp(pipe.send_calls[0], {key: sample_data}) - def test_make_merging_stream_from_keys_with_multiple_keys( + def test_make_merging_stream_with_multiple_keys( self, data_service, fake_pipe_factory, sample_data ): """Test subscribing to multiple keys that don't have data yet.""" @@ -485,7 +532,7 @@ def test_make_merging_stream_from_keys_with_multiple_keys( ) keys = [key1, key2] - pipe = manager.make_merging_stream_from_keys(keys) + pipe = manager.make_merging_stream(keys) # Initially empty assert pipe.data == {} @@ -498,7 +545,7 @@ def test_make_merging_stream_from_keys_with_multiple_keys( assert key1 in pipe.send_calls[0] assert key2 not in pipe.send_calls[0] - def test_make_merging_stream_from_keys_uses_default_assembler( + def test_make_merging_stream_uses_default_assembler( self, data_service, fake_pipe_factory, sample_data ): """Test that default assembler is MergingStreamAssembler.""" @@ -513,24 +560,22 @@ def test_make_merging_stream_from_keys_uses_default_assembler( job_id=JobId(source_name="source", job_number=uuid.uuid4()), ) - pipe = manager.make_merging_stream_from_keys([key]) + pipe = manager.make_merging_stream([key]) # Publish data data_service[key] = sample_data # Should receive data (verifies default assembler works) assert len(pipe.send_calls) == 1 - assert pipe.send_calls[0] == {key: sample_data} + assert_dict_equal_with_scipp(pipe.send_calls[0], {key: sample_data}) - def test_make_merging_stream_from_keys_with_empty_list( - self, data_service, fake_pipe_factory - ): + def test_make_merging_stream_with_empty_list(self, data_service, fake_pipe_factory): """Test with empty keys list.""" manager = StreamManager( data_service=data_service, pipe_factory=fake_pipe_factory ) - pipe = manager.make_merging_stream_from_keys([]) + pipe = manager.make_merging_stream([]) # Should initialize with empty dict assert pipe.data == {} @@ -547,7 +592,7 @@ def test_make_merging_stream_from_keys_with_empty_list( # Should not receive any data assert len(pipe.send_calls) == 0 - def test_make_merging_stream_from_keys_roi_spectrum_use_case( + def test_make_merging_stream_roi_spectrum_use_case( self, data_service, fake_pipe_factory ): """Test ROI spectrum subscription (subscribe upfront, data comes later).""" @@ -571,7 +616,7 @@ def test_make_merging_stream_from_keys_roi_spectrum_use_case( for i in range(3) ] - pipe = manager.make_merging_stream_from_keys(keys) + pipe = manager.make_merging_stream(keys) # Initially empty assert pipe.data == {} diff --git a/tests/dashboard/temporal_buffer_manager_test.py b/tests/dashboard/temporal_buffer_manager_test.py new file mode 100644 index 000000000..89ae257a4 --- /dev/null +++ b/tests/dashboard/temporal_buffer_manager_test.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +import pytest +import scipp as sc + +from ess.livedata.dashboard.extractors import ( + FullHistoryExtractor, + LatestValueExtractor, + WindowAggregatingExtractor, +) +from ess.livedata.dashboard.temporal_buffer_manager import TemporalBufferManager +from ess.livedata.dashboard.temporal_buffers import 
SingleValueBuffer, TemporalBuffer + + +class TestTemporalBufferManager: + """Tests for TemporalBufferManager.""" + + def test_create_buffer_with_only_latest_extractors_uses_single_value_buffer(self): + """ + Test that SingleValueBuffer is used with all LatestValueExtractor. + """ + manager = TemporalBufferManager() + extractors = [LatestValueExtractor(), LatestValueExtractor()] + + manager.create_buffer('test', extractors) + + assert isinstance(manager['test'], SingleValueBuffer) + + def test_create_buffer_with_mixed_extractors_uses_temporal_buffer(self): + """ + Test that TemporalBuffer is used with mixed extractors. + """ + manager = TemporalBufferManager() + extractors = [LatestValueExtractor(), FullHistoryExtractor()] + + manager.create_buffer('test', extractors) + + assert isinstance(manager['test'], TemporalBuffer) + + def test_create_buffer_with_window_extractor_uses_temporal_buffer(self): + """Test that TemporalBuffer is used with WindowAggregatingExtractor.""" + manager = TemporalBufferManager() + extractors = [WindowAggregatingExtractor(window_duration_seconds=1.0)] + + manager.create_buffer('test', extractors) + + assert isinstance(manager['test'], TemporalBuffer) + + def test_create_buffer_with_no_extractors_uses_single_value_buffer(self): + """ + Test that SingleValueBuffer is used by default with no extractors. + """ + manager = TemporalBufferManager() + + manager.create_buffer('test', []) + + assert isinstance(manager['test'], SingleValueBuffer) + + def test_create_buffer_raises_error_for_duplicate_key(self): + """Test that creating a buffer with existing key raises ValueError.""" + manager = TemporalBufferManager() + extractors = [LatestValueExtractor()] + + manager.create_buffer('test', extractors) + + with pytest.raises(ValueError, match="already exists"): + manager.create_buffer('test', extractors) + + def test_update_buffer_adds_data(self): + """Test that update_buffer adds data to the buffer.""" + manager = TemporalBufferManager() + extractors = [LatestValueExtractor()] + data = sc.DataArray( + sc.scalar(42, unit='counts'), + coords={'time': sc.scalar(1.0, unit='s')}, + ) + + manager.create_buffer('test', extractors) + manager.update_buffer('test', data) + + result = manager.get_buffered_data('test') + assert sc.identical(result, data) + + def test_update_buffer_raises_error_for_missing_key(self): + """Test that updating non-existent buffer raises KeyError.""" + manager = TemporalBufferManager() + data = sc.DataArray( + sc.scalar(42, unit='counts'), + coords={'time': sc.scalar(1.0, unit='s')}, + ) + + with pytest.raises(KeyError, match="No buffer found"): + manager.update_buffer('test', data) + + def test_add_extractor_keeps_same_buffer_type(self): + """Test that adding compatible extractor keeps same buffer type.""" + manager = TemporalBufferManager() + extractors = [LatestValueExtractor()] + + manager.create_buffer('test', extractors) + original_buffer = manager['test'] + + manager.add_extractor('test', LatestValueExtractor()) + + assert manager['test'] is original_buffer + assert isinstance(manager['test'], SingleValueBuffer) + + def test_add_extractor_switches_to_temporal_buffer(self): + """Test that switching buffer types preserves existing data.""" + manager = TemporalBufferManager() + extractors = [LatestValueExtractor()] + data = sc.DataArray( + sc.array(dims=['x'], values=[1.0, 2.0], unit='counts'), + coords={ + 'x': sc.arange('x', 2, unit='m'), + 'time': sc.scalar(1.0, unit='s'), + }, + ) + + manager.create_buffer('test', extractors) + 
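# The buffer-type choice pinned down by the create_buffer tests above, together
# with the timespan rule exercised by the propagation tests further below,
# sketched as two helpers (an assumed formulation; the TemporalBufferManager's
# real code is not part of this diff and may organise this differently):
from ess.livedata.dashboard.extractors import UpdateExtractor


def required_timespan(extractors: list[UpdateExtractor]) -> float:
    # No extractors means no history is needed; otherwise the most demanding
    # extractor wins, and FullHistoryExtractor's inf dominates any finite window.
    return max((e.get_required_timespan() for e in extractors), default=0.0)


def needs_temporal_buffer(extractors: list[UpdateExtractor]) -> bool:
    # A SingleValueBuffer suffices only when no extractor needs any history.
    return required_timespan(extractors) > 0.0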
manager.update_buffer('test', data) + + # Add full history extractor - should trigger buffer type switch + manager.add_extractor('test', FullHistoryExtractor()) + + # Data should be preserved when switching + result = manager.get_buffered_data('test') + assert result is not None + # After switching, buffer transforms scalar time coord to time dimension + assert 'time' in result.dims + assert result.sizes['time'] == 1 + # Verify the data values are preserved + assert sc.allclose(result['time', 0].data, data.data) + + def test_set_extractors_switches_to_single_value_buffer(self): + """Test that switching buffer types preserves latest data.""" + manager = TemporalBufferManager() + extractors = [WindowAggregatingExtractor(window_duration_seconds=1.0)] + + manager.create_buffer('test', extractors) + + # Add multiple time slices + for i in range(3): + data = sc.DataArray( + sc.array(dims=['x'], values=[float(i)] * 2, unit='counts'), + coords={ + 'x': sc.arange('x', 2, unit='m'), + 'time': sc.scalar(float(i), unit='s'), + }, + ) + manager.update_buffer('test', data) + + # Verify we have temporal data with 3 time points + result = manager.get_buffered_data('test') + assert result is not None + assert 'time' in result.dims + assert result.sizes['time'] == 3 + + # Replace extractors - this should trigger buffer type switch + manager.set_extractors('test', [LatestValueExtractor()]) + + # Verify the latest time slice is preserved after transition + result = manager.get_buffered_data('test') + assert result is not None + # The last slice should have values [2.0, 2.0] and time=2.0 + expected_data = sc.array(dims=['x'], values=[2.0, 2.0], unit='counts') + assert sc.allclose(result.data, expected_data) + assert result.coords['time'].value == 2.0 + + def test_add_extractor_raises_error_for_missing_key(self): + """Test that adding extractor to non-existent buffer raises KeyError.""" + manager = TemporalBufferManager() + + with pytest.raises(KeyError): + manager.add_extractor('test', LatestValueExtractor()) + + def test_delete_buffer_removes_buffer(self): + """Test that delete_buffer removes the buffer.""" + manager = TemporalBufferManager() + extractors = [LatestValueExtractor()] + + manager.create_buffer('test', extractors) + assert 'test' in manager + + manager.delete_buffer('test') + assert 'test' not in manager + + def test_delete_buffer_nonexistent_key_does_nothing(self): + """Test that deleting non-existent buffer doesn't raise error.""" + manager = TemporalBufferManager() + manager.delete_buffer('nonexistent') # Should not raise + + def test_mapping_interface(self): + """Test that manager implements Mapping interface correctly.""" + manager = TemporalBufferManager() + extractors = [LatestValueExtractor()] + + manager.create_buffer('key1', extractors) + manager.create_buffer('key2', extractors) + + assert len(manager) == 2 + assert 'key1' in manager + assert 'key2' in manager + assert list(manager) == ['key1', 'key2'] + + +class TestTemporalBufferManagerTimespanPropagation: + """Tests for timespan requirement propagation.""" + + def test_window_extractor_sets_timespan_on_buffer(self): + """Test that WindowAggregatingExtractor's timespan is set on buffer.""" + manager = TemporalBufferManager() + window_duration = 5.0 + extractors = [ + WindowAggregatingExtractor(window_duration_seconds=window_duration) + ] + + manager.create_buffer('test', extractors) + + buffer = manager['test'] + assert isinstance(buffer, TemporalBuffer) + assert buffer.get_required_timespan() == window_duration + + def 
test_multiple_window_extractors_use_max_timespan(self): + """Test that maximum timespan from multiple extractors is used.""" + manager = TemporalBufferManager() + extractors = [ + WindowAggregatingExtractor(window_duration_seconds=3.0), + WindowAggregatingExtractor(window_duration_seconds=5.0), + WindowAggregatingExtractor(window_duration_seconds=2.0), + ] + + manager.create_buffer('test', extractors) + + buffer = manager['test'] + assert buffer.get_required_timespan() == 5.0 + + def test_latest_extractor_does_not_set_timespan(self): + """Test that LatestValueExtractor doesn't set a timespan.""" + manager = TemporalBufferManager() + extractors = [LatestValueExtractor()] + + manager.create_buffer('test', extractors) + + buffer = manager['test'] + assert isinstance(buffer, SingleValueBuffer) + assert buffer.get_required_timespan() == 0.0 + + def test_mixed_extractors_use_window_timespan(self): + """Test that timespan is set when mixing Latest and Window extractors.""" + manager = TemporalBufferManager() + extractors = [ + LatestValueExtractor(), + WindowAggregatingExtractor(window_duration_seconds=4.0), + ] + + manager.create_buffer('test', extractors) + + buffer = manager['test'] + assert isinstance(buffer, TemporalBuffer) + assert buffer.get_required_timespan() == 4.0 + + def test_adding_extractor_updates_timespan(self): + """Test that adding an extractor updates the buffer's timespan.""" + manager = TemporalBufferManager() + extractors = [WindowAggregatingExtractor(window_duration_seconds=2.0)] + + manager.create_buffer('test', extractors) + buffer = manager['test'] + assert buffer.get_required_timespan() == 2.0 + + # Add extractor with larger timespan + manager.add_extractor( + 'test', WindowAggregatingExtractor(window_duration_seconds=10.0) + ) + + assert buffer.get_required_timespan() == 10.0 + + def test_full_history_extractor_infinite_timespan(self): + """Test that FullHistoryExtractor sets infinite timespan.""" + manager = TemporalBufferManager() + extractors = [FullHistoryExtractor()] + + manager.create_buffer('test', extractors) + + buffer = manager['test'] + assert isinstance(buffer, TemporalBuffer) + assert buffer.get_required_timespan() == float('inf') + + def test_full_history_with_window_uses_infinite(self): + """Test that mixing FullHistory with Window uses infinite timespan.""" + manager = TemporalBufferManager() + extractors = [ + WindowAggregatingExtractor(window_duration_seconds=5.0), + FullHistoryExtractor(), + ] + + manager.create_buffer('test', extractors) + + buffer = manager['test'] + assert isinstance(buffer, TemporalBuffer) + # max(5.0, inf) = inf + assert buffer.get_required_timespan() == float('inf') + + +class TestTemporalBufferManagerWithRealData: + """Integration tests with real scipp data.""" + + def test_single_value_buffer_workflow(self): + """Test complete workflow with SingleValueBuffer.""" + manager = TemporalBufferManager() + extractors = [LatestValueExtractor()] + + manager.create_buffer('stream', extractors) + + # Add multiple values + for i in range(3): + data = sc.DataArray( + sc.array(dims=['x'], values=[float(i)] * 2, unit='counts'), + coords={ + 'x': sc.arange('x', 2, unit='m'), + 'time': sc.scalar(float(i), unit='s'), + }, + ) + manager.update_buffer('stream', data) + + # Should only have latest value + result = manager.get_buffered_data('stream') + assert result is not None + # Extract using the extractor + extracted = extractors[0].extract(result) + expected = sc.array(dims=['x'], values=[2.0, 2.0], unit='counts') + assert 
sc.allclose(extracted.data, expected) + + def test_temporal_buffer_workflow(self): + """Test complete workflow with TemporalBuffer.""" + manager = TemporalBufferManager() + extractors = [WindowAggregatingExtractor(window_duration_seconds=5.0)] + + manager.create_buffer('stream', extractors) + + # Add multiple time slices + for i in range(3): + data = sc.DataArray( + sc.array(dims=['x'], values=[float(i)] * 2, unit='counts'), + coords={ + 'x': sc.arange('x', 2, unit='m'), + 'time': sc.scalar(float(i), unit='s'), + }, + ) + manager.update_buffer('stream', data) + + # Should have all data concatenated + result = manager.get_buffered_data('stream') + assert result is not None + assert 'time' in result.dims + assert result.sizes['time'] == 3 diff --git a/tests/dashboard/temporal_buffers_test.py b/tests/dashboard/temporal_buffers_test.py new file mode 100644 index 000000000..9dce08d1e --- /dev/null +++ b/tests/dashboard/temporal_buffers_test.py @@ -0,0 +1,575 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +import pytest +import scipp as sc + +from ess.livedata.dashboard.temporal_buffers import ( + SingleValueBuffer, + TemporalBuffer, + VariableBuffer, +) + + +# Fixtures for common test data +@pytest.fixture +def single_slice_2element(): + """Create a single time slice with 2 x-elements.""" + return sc.DataArray( + sc.array(dims=['x'], values=[1.0, 2.0], unit='counts'), + coords={ + 'x': sc.arange('x', 2, unit='m'), + 'time': sc.scalar(0.0, unit='s'), + }, + ) + + +@pytest.fixture +def thick_slice_2x2(): + """Create a thick slice with 2 time points and 2 x-elements.""" + return sc.DataArray( + sc.array(dims=['time', 'x'], values=[[1.0, 2.0], [3.0, 4.0]], unit='counts'), + coords={ + 'x': sc.arange('x', 2, unit='m'), + 'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s'), + }, + ) + + +# Helper functions for creating test data +def make_single_slice(x_values, time_value): + """Create a single time slice DataArray.""" + return sc.DataArray( + sc.array(dims=['x'], values=x_values, unit='counts'), + coords={ + 'x': sc.arange('x', len(x_values), unit='m'), + 'time': sc.scalar(time_value, unit='s'), + }, + ) + + +def make_thick_slice(x_size, time_values): + """Create a thick slice DataArray with multiple time points.""" + n_times = len(time_values) + return sc.DataArray( + sc.array( + dims=['time', 'x'], + values=[[float(i)] * x_size for i in range(n_times)], + unit='counts', + ), + coords={ + 'x': sc.arange('x', x_size, unit='m'), + 'time': sc.array(dims=['time'], values=time_values, unit='s'), + }, + ) + + +def assert_buffer_has_time_data(buffer, expected_size): + """Assert buffer contains time-dimensioned data of expected size.""" + result = buffer.get() + assert result is not None + assert 'time' in result.dims + assert result.sizes['time'] == expected_size + + +class TestSingleValueBuffer: + """Tests for SingleValueBuffer.""" + + def test_add_and_get_scalar(self): + """Test adding and retrieving a scalar value.""" + buffer = SingleValueBuffer() + data = sc.scalar(42, unit='counts') + + buffer.add(data) + result = buffer.get() + + assert result == data + + def test_add_replaces_previous_value(self): + """Test that add replaces the previous value.""" + buffer = SingleValueBuffer() + data1 = sc.scalar(10, unit='counts') + data2 = sc.scalar(20, unit='counts') + + buffer.add(data1) + buffer.add(data2) + result = buffer.get() + + assert result == data2 + + def 
test_get_empty_buffer_returns_none(self): + """Test that get returns None for empty buffer.""" + buffer = SingleValueBuffer() + assert buffer.get() is None + + def test_clear_removes_value(self): + """Test that clear removes the stored value.""" + buffer = SingleValueBuffer() + data = sc.scalar(42, unit='counts') + + buffer.add(data) + buffer.clear() + result = buffer.get() + + assert result is None + + def test_set_required_timespan_does_not_error(self): + """Test that set_required_timespan can be called without error.""" + buffer = SingleValueBuffer() + buffer.set_required_timespan(10.0) # Should not raise + + def test_set_max_memory_does_not_error(self): + """Test that set_max_memory can be called without error.""" + buffer = SingleValueBuffer() + buffer.set_max_memory(1000) # Should not raise + + def test_add_dataarray_with_dimensions(self): + """Test adding a DataArray with dimensions.""" + buffer = SingleValueBuffer() + data = sc.DataArray( + sc.array(dims=['x'], values=[1.0, 2.0, 3.0], unit='counts'), + coords={'x': sc.arange('x', 3, unit='m')}, + ) + + buffer.add(data) + result = buffer.get() + + assert sc.identical(result, data) + + +class TestTemporalBuffer: + """Tests for TemporalBuffer.""" + + @pytest.mark.parametrize( + ('data_creator', 'expected_time_size'), + [ + (lambda: make_single_slice([1.0, 2.0, 3.0], 0.0), 1), + (lambda: make_thick_slice(2, [0.0, 1.0]), 2), + ], + ids=['single_slice', 'thick_slice'], + ) + def test_add_data_creates_time_dimension(self, data_creator, expected_time_size): + """Test that adding data creates buffer with time dimension.""" + buffer = TemporalBuffer() + buffer.add(data_creator()) + assert_buffer_has_time_data(buffer, expected_time_size) + + def test_add_multiple_single_slices(self): + """Test concatenating multiple single slices.""" + buffer = TemporalBuffer() + + for i in range(3): + buffer.add(make_single_slice([float(i)] * 2, float(i))) + + assert_buffer_has_time_data(buffer, 3) + + def test_add_multiple_thick_slices(self): + """Test concatenating multiple thick slices.""" + buffer = TemporalBuffer() + buffer.add(make_thick_slice(2, [0.0, 1.0])) + buffer.add(make_thick_slice(2, [2.0, 3.0])) + assert_buffer_has_time_data(buffer, 4) + + def test_add_mixed_single_and_thick_slices(self): + """Test concatenating mixed single and thick slices.""" + buffer = TemporalBuffer() + buffer.add(make_single_slice([1.0, 2.0], 0.0)) + buffer.add(make_thick_slice(2, [1.0, 2.0])) + assert_buffer_has_time_data(buffer, 3) + + def test_add_without_time_coord_raises_error(self): + """Test that adding data without time coordinate raises ValueError.""" + buffer = TemporalBuffer() + data = sc.DataArray( + sc.array(dims=['x'], values=[1.0, 2.0, 3.0], unit='counts'), + coords={'x': sc.arange('x', 3, unit='m')}, + ) + + with pytest.raises(ValueError, match="requires data with 'time' coordinate"): + buffer.add(data) + + def test_get_empty_buffer_returns_none(self): + """Test that get returns None for empty buffer.""" + buffer = TemporalBuffer() + assert buffer.get() is None + + def test_clear_removes_all_data(self): + """Test that clear removes all buffered data.""" + buffer = TemporalBuffer() + buffer.add(make_single_slice([1.0, 2.0], 0.0)) + buffer.clear() + assert buffer.get() is None + + def test_timespan_trimming_on_capacity_failure(self): + """Test that old data is trimmed when capacity is reached.""" + buffer = TemporalBuffer() + buffer.set_required_timespan(5.0) # Keep last 5 seconds + buffer.set_max_memory(100) # Small memory limit to trigger trimming + + # 
Add data at t=0 + buffer.add(make_single_slice([1.0, 2.0], 0.0)) + initial_capacity = buffer._data_buffer.max_capacity + + # Fill buffer close to capacity with data at t=1, 2, 3, 4 + for t in range(1, initial_capacity): + buffer.add(make_single_slice([float(t), float(t)], float(t))) + + # Add data at t=10 (outside timespan from t=0-4) + # This should trigger trimming of old data + buffer.add(make_single_slice([10.0, 10.0], 10.0)) + + result = buffer.get() + # Only data from t >= 5.0 should remain (t=10 - 5.0) + # So only t=10 should be in buffer (since we only added up to t=capacity-1) + assert result.coords['time'].values[0] >= 5.0 + + def test_no_trimming_when_capacity_available(self): + """Test that trimming doesn't occur when there's available capacity.""" + buffer = TemporalBuffer() + buffer.set_required_timespan(2.0) # Keep last 2 seconds + + # Add data at t=0, 1, 2, 3, 4, 5 + for t in range(6): + buffer.add(make_single_slice([float(t), float(t)], float(t))) + + result = buffer.get() + # With default large capacity (10000), no trimming should occur + # All 6 time points should be present despite timespan=2.0 + assert result.sizes['time'] == 6 + assert result.coords['time'].values[0] == 0.0 + + def test_trim_drops_all_old_data(self): + """Test trimming when all buffered data is older than timespan.""" + buffer = TemporalBuffer() + buffer.set_required_timespan(1.0) + buffer.set_max_memory(50) # Very small to trigger trim quickly + + # Add data at t=0 + buffer.add(make_single_slice([1.0, 2.0], 0.0)) + + # Fill to capacity + capacity = buffer._data_buffer.max_capacity + for t in range(1, capacity): + buffer.add(make_single_slice([float(t), float(t)], float(t))) + + # Add data far in future, all previous data should be dropped + buffer.add(make_single_slice([99.0, 99.0], 100.0)) + + result = buffer.get() + # Only data >= 99.0 should remain (100 - 1.0 timespan) + assert result.coords['time'].values[0] >= 99.0 + + def test_capacity_exceeded_even_after_trimming_raises(self): + """Test that ValueError is raised if data exceeds capacity even after trim.""" + buffer = TemporalBuffer() + buffer.set_required_timespan(1.0) + buffer.set_max_memory(20) # Very small capacity (~ 1 element) + + # Add first data point + buffer.add(make_single_slice([1.0, 2.0], 0.0)) + + # Try to add thick slice that exceeds capacity + large_data = make_thick_slice(2, list(range(10))) + + with pytest.raises(ValueError, match="exceeds buffer capacity even after"): + buffer.add(large_data) + + def test_timespan_zero_trims_all_old_data_on_overflow(self): + """Test that timespan=0.0 trims all data to make room for new data.""" + buffer = TemporalBuffer() + buffer.set_required_timespan(0.0) # Keep only latest value + buffer.set_max_memory(100) # Small memory limit to force overflow + + # Add first data point + buffer.add(make_single_slice([1.0, 2.0], 0.0)) + initial_capacity = buffer._data_buffer.max_capacity + + # Fill buffer to capacity + for t in range(1, initial_capacity): + buffer.add(make_single_slice([float(t), float(t)], float(t))) + + # Buffer is now full, verify it has all data + result = buffer.get() + assert result.sizes['time'] == initial_capacity + + # Add one more data point - should trigger trimming + # With timespan=0.0, should drop ALL old data to make room + buffer.add(make_single_slice([999.0, 999.0], 999.0)) + + # Should only have the latest value + result = buffer.get() + assert result.sizes['time'] == 1 + assert result.coords['time'].values[0] == 999.0 + assert result['time', 0].values[0] == 999.0 + + 
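# For orientation between the test classes: a minimal sketch of the trimming
# policy the TemporalBuffer tests above encode, written against a simple
# frame-list representation (the real buffer preallocates along the concat
# dimension). The helper name `trim_for_append` and the frame-count notion of
# capacity are assumptions for illustration, not the implementation under test.
import scipp as sc


def trim_for_append(
    frames: list[sc.DataArray],
    new: sc.DataArray,
    capacity: int,
    required_timespan: float,
) -> list[sc.DataArray]:
    """Trim frames older than the required timespan only when capacity would overflow."""
    if len(frames) + 1 <= capacity:
        # Plenty of room: keep everything, even frames older than the timespan.
        return frames + [new]
    cutoff = new.coords['time'].value - required_timespan
    kept = [f for f in frames if f.coords['time'].value >= cutoff]
    if len(kept) + 1 > capacity:
        raise ValueError("data exceeds buffer capacity even after trimming")
    return kept + [new]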
+class TestVariableBuffer: + """Tests for VariableBuffer.""" + + def test_init_with_single_slice(self): + """Test initialization with single slice (no concat_dim).""" + data = sc.array(dims=['x'], values=[1.0, 2.0, 3.0], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='time') + + assert buffer.size == 1 + assert buffer.max_capacity == 10 + result = buffer.get() + assert result.sizes['time'] == 1 + assert sc.identical(result['time', 0], data) + + def test_init_with_thick_slice(self, thick_slice_2x2): + """Test initialization with thick slice (has concat_dim).""" + data = thick_slice_2x2.data # Extract the raw array + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='time') + + assert buffer.size == 2 + assert buffer.max_capacity == 10 + result = buffer.get() + assert result.sizes['time'] == 2 + assert sc.identical(result, data) + + def test_dimension_ordering_single_slice_makes_concat_dim_outer(self): + """Test that concat_dim becomes outer dimension for single slice.""" + # 2D image without time dimension + image = sc.array(dims=['y', 'x'], values=[[1, 2], [3, 4]], unit='counts') + buffer = VariableBuffer(data=image, max_capacity=10, concat_dim='time') + + # Buffer should have dims: time, y, x (time is outer) + assert list(buffer._buffer.dims) == ['time', 'y', 'x'] + + def test_dimension_ordering_thick_slice_preserves_order(self): + """Test that existing dimension order is preserved for thick slice.""" + # Data with time already in the middle + data = sc.array( + dims=['y', 'time', 'x'], values=[[[1, 2], [3, 4]]], unit='counts' + ) + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='time') + + # Buffer should preserve dimension order + assert list(buffer._buffer.dims) == ['y', 'time', 'x'] + + def test_append_single_slice(self): + """Test appending single slices.""" + data = sc.array(dims=['x'], values=[1.0, 2.0], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='time') + + # Append more slices + data2 = sc.array(dims=['x'], values=[3.0, 4.0], unit='counts') + data3 = sc.array(dims=['x'], values=[5.0, 6.0], unit='counts') + + assert buffer.append(data2) + assert buffer.append(data3) + + assert buffer.size == 3 + result = buffer.get() + assert result.sizes['time'] == 3 + assert sc.identical(result['time', 0], data) + assert sc.identical(result['time', 1], data2) + assert sc.identical(result['time', 2], data3) + + def test_append_thick_slice(self): + """Test appending thick slices.""" + data = sc.array(dims=['time', 'x'], values=[[1.0, 2.0]], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='time') + + # Append thick slice + data2 = sc.array( + dims=['time', 'x'], values=[[3.0, 4.0], [5.0, 6.0]], unit='counts' + ) + assert buffer.append(data2) + + assert buffer.size == 3 + result = buffer.get() + assert result.sizes['time'] == 3 + + def test_capacity_expansion(self): + """Test that buffer capacity expands as needed.""" + data = sc.array(dims=['x'], values=[1.0], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=100, concat_dim='time') + + initial_capacity = buffer.capacity + assert initial_capacity == 16 # min(16, max_capacity) + + # Append until we exceed initial capacity + for i in range(20): + assert buffer.append(sc.array(dims=['x'], values=[float(i)], unit='counts')) + + # Capacity should have expanded + assert buffer.capacity > initial_capacity + assert buffer.size == 21 + + def test_large_append_requires_multiple_expansions(self): + """Test appending 
data much larger than current capacity (bug fix).""" + data = sc.array(dims=['x'], values=[1.0], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=200, concat_dim='time') + + assert buffer.capacity == 16 + + # Append 100 elements at once (requires multiple doublings: 16->32->64->128) + large_data = sc.array( + dims=['time', 'x'], values=[[float(i)] for i in range(100)], unit='counts' + ) + assert buffer.append(large_data) + + assert buffer.size == 101 + assert buffer.capacity >= 101 + result = buffer.get() + assert result.sizes['time'] == 101 + + def test_append_exceeding_max_capacity_fails(self): + """Test that append fails when exceeding max_capacity.""" + data = sc.array(dims=['x'], values=[1.0], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=5, concat_dim='time') + + # Fill to max_capacity + for i in range(4): + assert buffer.append(sc.array(dims=['x'], values=[float(i)], unit='counts')) + + assert buffer.size == 5 + + # Next append should fail + assert not buffer.append(sc.array(dims=['x'], values=[99.0], unit='counts')) + assert buffer.size == 5 # Size unchanged + + def test_init_exceeding_max_capacity_raises(self): + """Test that initialization with data exceeding max_capacity raises.""" + data = sc.array(dims=['time', 'x'], values=[[1.0], [2.0], [3.0]], unit='counts') + + with pytest.raises(ValueError, match="exceeds max_capacity"): + VariableBuffer(data=data, max_capacity=2, concat_dim='time') + + def test_get_returns_valid_data_only(self): + """Test that get returns only valid data, not full buffer capacity.""" + data = sc.array(dims=['x'], values=[1.0], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=100, concat_dim='time') + + # Capacity is 16, but size is 1 + assert buffer.capacity == 16 + assert buffer.size == 1 + + result = buffer.get() + assert result.sizes['time'] == 1 # Not 16 + + def test_drop_from_start(self): + """Test dropping data from the start.""" + data = sc.array( + dims=['time', 'x'], + values=[[1.0], [2.0], [3.0], [4.0], [5.0]], + unit='counts', + ) + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='time') + + # Drop first 2 elements + buffer.drop(2) + + assert buffer.size == 3 + result = buffer.get() + assert result.sizes['time'] == 3 + assert result.values[0, 0] == 3.0 + assert result.values[2, 0] == 5.0 + + def test_drop_all(self): + """Test dropping all data.""" + data = sc.array(dims=['time', 'x'], values=[[1.0], [2.0]], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='time') + + buffer.drop(5) # Drop more than size + + assert buffer.size == 0 + + def test_drop_zero_does_nothing(self): + """Test that dropping zero elements does nothing.""" + data = sc.array(dims=['time', 'x'], values=[[1.0], [2.0]], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='time') + + buffer.drop(0) + + assert buffer.size == 2 + + def test_drop_negative_does_nothing(self): + """Test that dropping negative index does nothing.""" + data = sc.array(dims=['time', 'x'], values=[[1.0], [2.0]], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='time') + + buffer.drop(-1) + + assert buffer.size == 2 + + def test_multidimensional_data(self): + """Test with multidimensional data (images).""" + # 2D image + image1 = sc.array(dims=['y', 'x'], values=[[1, 2], [3, 4]], unit='counts') + buffer = VariableBuffer(data=image1, max_capacity=10, concat_dim='time') + + image2 = sc.array(dims=['y', 'x'], values=[[5, 6], [7, 8]], unit='counts') + 
buffer.append(image2) + + result = buffer.get() + assert result.sizes == {'time': 2, 'y': 2, 'x': 2} + assert result.values[0, 0, 0] == 1 + assert result.values[1, 1, 1] == 8 + + def test_properties(self): + """Test buffer properties.""" + data = sc.array(dims=['time', 'x'], values=[[1.0], [2.0]], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=50, concat_dim='time') + + assert buffer.size == 2 + assert buffer.max_capacity == 50 + assert buffer.capacity == 16 # Initial allocation + + def test_custom_concat_dim(self): + """Test using a custom concat dimension.""" + data = sc.array(dims=['x'], values=[1.0, 2.0], unit='counts') + buffer = VariableBuffer(data=data, max_capacity=10, concat_dim='event') + + assert buffer.size == 1 + result = buffer.get() + assert 'event' in result.dims + assert result.sizes['event'] == 1 + + def test_scalar_to_1d(self): + """Test stacking scalars into 1D array.""" + scalar = sc.scalar(42.0, unit='counts') + buffer = VariableBuffer(data=scalar, max_capacity=10, concat_dim='time') + + buffer.append(sc.scalar(43.0, unit='counts')) + buffer.append(sc.scalar(44.0, unit='counts')) + + result = buffer.get() + assert result.sizes == {'time': 3} + assert list(result.values) == [42.0, 43.0, 44.0] + + def test_append_with_variances_preserves_variances_on_expansion(self): + """Test that variances are preserved when buffer expands. + + Regression test for bug where expanding the buffer would lose variances + because sc.empty() wasn't called with with_variances flag. + """ + # Create data with variances + data = sc.array( + dims=['x'], values=[1.0, 2.0], variances=[0.1, 0.2], unit='counts' + ) + buffer = VariableBuffer(data=data, max_capacity=100, concat_dim='time') + + # Verify initial data has variances + assert buffer.get().variances is not None + + # Append enough single slices to trigger multiple buffer expansions + # Initial capacity is 16, so we need > 16 appends + for i in range(20): + new_data = sc.array( + dims=['x'], + values=[float(i + 3), float(i + 4)], + variances=[0.3 + i * 0.01, 0.4 + i * 0.01], + unit='counts', + ) + assert buffer.append(new_data), f"Failed to append slice {i}" + + # Verify final result still has variances after expansion + result = buffer.get() + assert result.variances is not None + assert result.sizes['time'] == 21 + # Verify variances were preserved from at least one slice + assert result.variances[0, 0] > 0 diff --git a/tests/handlers/detector_view_test.py b/tests/handlers/detector_view_test.py index 862c0248c..e72025885 100644 --- a/tests/handlers/detector_view_test.py +++ b/tests/handlers/detector_view_test.py @@ -169,19 +169,90 @@ class TestDetectorViewBasics: """Basic tests for DetectorView without ROI.""" def test_detector_view_initialization( - self, mock_rolling_view: RollingDetectorView + self, + mock_rolling_view: RollingDetectorView, + sample_detector_events: sc.DataArray, ) -> None: """Test that DetectorView can be initialized with default parameters.""" params = DetectorViewParams() view = DetectorView(params=params, detector_view=mock_rolling_view) assert view is not None + # Accumulate some detector data before finalize + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) # Verify no ROI results are present when no ROI configured result = view.finalize() assert 'cumulative' in result assert 'current' in result assert not any(key.startswith('roi_') for key in result) + def test_current_has_time_coord( + self, + mock_rolling_view: RollingDetectorView, + 
sample_detector_events: sc.DataArray, + ) -> None: + """Test that 'current' result has time coord from first accumulate call.""" + params = DetectorViewParams() + view = DetectorView(params=params, detector_view=mock_rolling_view) + + # Accumulate with specific start_time + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) + result = view.finalize() + + # Verify time coord is present on current + assert 'time' in result['current'].coords + assert result['current'].coords['time'].value == 1000 + assert result['current'].coords['time'].unit == 'ns' + # cumulative should not have time coord + assert 'time' not in result['cumulative'].coords + + def test_time_coord_tracks_first_accumulate( + self, + mock_rolling_view: RollingDetectorView, + sample_detector_events: sc.DataArray, + ) -> None: + """Test that time coord uses first accumulate start_time, not later ones.""" + params = DetectorViewParams() + view = DetectorView(params=params, detector_view=mock_rolling_view) + + # First accumulate with start_time=1000 + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) + # Second accumulate with different start_time + view.accumulate( + {'detector': sample_detector_events}, start_time=3000, end_time=4000 + ) + + result = view.finalize() + + # Time should be from first accumulate + assert result['current'].coords['time'].value == 1000 + + # After finalize, next accumulate should track new start_time + view.accumulate( + {'detector': sample_detector_events}, start_time=5000, end_time=6000 + ) + result2 = view.finalize() + assert result2['current'].coords['time'].value == 5000 + + def test_finalize_without_accumulate_raises( + self, mock_rolling_view: RollingDetectorView + ) -> None: + """Test that finalize raises if called without accumulate.""" + params = DetectorViewParams() + view = DetectorView(params=params, detector_view=mock_rolling_view) + + with pytest.raises( + RuntimeError, + match="finalize called without any detector data accumulated", + ): + view.finalize() + def test_accumulate_detector_data_without_roi( self, mock_rolling_view: RollingDetectorView, @@ -192,7 +263,9 @@ def test_accumulate_detector_data_without_roi( view = DetectorView(params=params, detector_view=mock_rolling_view) # Accumulate detector data - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) # Finalize should return cumulative and current counts result = view.finalize() @@ -214,31 +287,45 @@ def test_clear_resets_state( params = DetectorViewParams() view = DetectorView(params=params, detector_view=mock_rolling_view) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result1 = view.finalize() assert sc.sum(result1['current']).value == TOTAL_SAMPLE_EVENTS # Clear the view view.clear() - # After clear, finalize should return zero counts + # After clear, accumulate new data then finalize should return zero counts + view.accumulate( + {'detector': sample_detector_events}, start_time=2000, end_time=3000 + ) result2 = view.finalize() - assert sc.sum(result2['current']).value == 0 - assert sc.sum(result2['cumulative']).value == 0 + assert sc.sum(result2['current']).value == TOTAL_SAMPLE_EVENTS + assert sc.sum(result2['cumulative']).value == TOTAL_SAMPLE_EVENTS class TestDetectorViewROIMechanism: """Tests for ROI configuration and histogram 
accumulation.""" def test_roi_configuration_via_accumulate( - self, mock_rolling_view: RollingDetectorView, standard_roi: RectangleROI + self, + mock_rolling_view: RollingDetectorView, + standard_roi: RectangleROI, + sample_detector_events: sc.DataArray, ) -> None: """Test that ROI configuration can be set via accumulate.""" params = DetectorViewParams() view = DetectorView(params=params, detector_view=mock_rolling_view) # Send ROI configuration - view.accumulate(roi_to_accumulate_data(standard_roi)) + view.accumulate( + roi_to_accumulate_data(standard_roi), start_time=1000, end_time=2000 + ) + # Add detector data before finalize + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) # Verify ROI configuration is active via finalize output result = view.finalize() @@ -260,23 +347,32 @@ def test_roi_configuration_via_accumulate( assert echoed_roi.y.max == 25.0 def test_roi_only_does_not_process_events( - self, mock_rolling_view: RollingDetectorView, standard_roi: RectangleROI + self, + mock_rolling_view: RollingDetectorView, + standard_roi: RectangleROI, + sample_detector_events: sc.DataArray, ) -> None: - """Test that sending only ROI (no detector data) produces empty histograms.""" + """Test that ROI configuration persists when detector data arrives.""" params = DetectorViewParams() view = DetectorView(params=params, detector_view=mock_rolling_view) # Send ROI configuration without detector data - view.accumulate(roi_to_accumulate_data(standard_roi)) + view.accumulate( + roi_to_accumulate_data(standard_roi), start_time=1000, end_time=2000 + ) - # ROI should be configured but no histogram data accumulated yet + # Now send detector data + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) + + # ROI should be configured and histogram data accumulated result = view.finalize() - # ROI results should be present (even if empty/zero) once ROI is configured + # ROI results should be present assert_has_roi_results(result) assert_roi_config_published(result) - # All counts should be zero since no events were accumulated - assert sc.sum(result['roi_cumulative_0']).value == 0 - assert sc.sum(result['roi_current_0']).value == 0 + # Counts should match events within ROI + assert_roi_event_count(result, STANDARD_ROI_EVENTS) def test_accumulate_with_roi_produces_histogram( self, @@ -290,10 +386,14 @@ def test_accumulate_with_roi_produces_histogram( view = DetectorView(params=params, detector_view=mock_rolling_view) # Configure ROI first - view.accumulate(roi_to_accumulate_data(standard_roi)) + view.accumulate( + roi_to_accumulate_data(standard_roi), start_time=1000, end_time=2000 + ) # Now accumulate detector events - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result = view.finalize() @@ -311,6 +411,42 @@ def test_accumulate_with_roi_produces_histogram( # Verify expected event count assert_roi_event_count(result, STANDARD_ROI_EVENTS) + def test_roi_current_has_time_coord( + self, + mock_rolling_view: RollingDetectorView, + sample_detector_events: sc.DataArray, + standard_roi: RectangleROI, + standard_toa_edges: TOAEdges, + ) -> None: + """Test that ROI current results have time coord matching detector view.""" + params = DetectorViewParams(toa_edges=standard_toa_edges) + view = DetectorView(params=params, detector_view=mock_rolling_view) + + # Configure ROI + view.accumulate( + roi_to_accumulate_data(standard_roi), 
start_time=1000, end_time=2000 + ) + + # Accumulate detector events + view.accumulate( + {'detector': sample_detector_events}, start_time=2500, end_time=3000 + ) + + result = view.finalize() + + # Verify time coord on ROI current + assert 'time' in result['roi_current_0'].coords + assert result['roi_current_0'].coords['time'].value == 2500 + assert result['roi_current_0'].coords['time'].unit == 'ns' + + # Verify it matches detector view current time coord + assert ( + result['roi_current_0'].coords['time'] == result['current'].coords['time'] + ) + + # ROI cumulative should not have time coord + assert 'time' not in result['roi_cumulative_0'].coords + def test_roi_cumulative_accumulation( self, mock_rolling_view: RollingDetectorView, @@ -323,15 +459,21 @@ def test_roi_cumulative_accumulation( view = DetectorView(params=params, detector_view=mock_rolling_view) # Configure ROI - view.accumulate(roi_to_accumulate_data(standard_roi)) + view.accumulate( + roi_to_accumulate_data(standard_roi), start_time=1000, end_time=2000 + ) # First accumulation - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result1 = view.finalize() assert_roi_event_count(result1, STANDARD_ROI_EVENTS) # Second accumulation with same events - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result2 = view.finalize() # Both should have same current counts @@ -350,14 +492,20 @@ def test_roi_published_only_on_update( view = DetectorView(params=params, detector_view=mock_rolling_view) # Configure ROI - view.accumulate(roi_to_accumulate_data(standard_roi)) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + roi_to_accumulate_data(standard_roi), start_time=1000, end_time=2000 + ) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result1 = view.finalize() assert_roi_config_published(result1) # Second accumulation without ROI update - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result2 = view.finalize() assert 'roi_rectangle' not in result2 # Not published again @@ -372,8 +520,12 @@ def test_clear_resets_roi_state( view = DetectorView(params=params, detector_view=mock_rolling_view) # Configure ROI and accumulate - view.accumulate(roi_to_accumulate_data(standard_roi)) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + roi_to_accumulate_data(standard_roi), start_time=1000, end_time=2000 + ) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result1 = view.finalize() # Verify we accumulated some events @@ -382,11 +534,15 @@ def test_clear_resets_roi_state( # Clear should reset cumulative view.clear() - # After clear, ROI cumulative should be reset to zero + # After clear, accumulate new data + view.accumulate( + {'detector': sample_detector_events}, start_time=2000, end_time=3000 + ) result2 = view.finalize() assert_has_roi_results(result2) # ROI config still active - assert sc.sum(result2['roi_cumulative_0']).value == 0 - assert sc.sum(result2['roi_current_0']).value == 0 + # Cumulative should be reset (only contains events from after clear) + assert sc.sum(result2['roi_cumulative_0']).value == STANDARD_ROI_EVENTS + assert sc.sum(result2['roi_current_0']).value == STANDARD_ROI_EVENTS def 
test_roi_change_resets_cumulative( self, @@ -401,16 +557,24 @@ def test_roi_change_resets_cumulative( view = DetectorView(params=params, detector_view=mock_rolling_view) # Configure first ROI covering pixels 5, 6, 10 - view.accumulate(roi_to_accumulate_data(standard_roi)) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + roi_to_accumulate_data(standard_roi), start_time=1000, end_time=2000 + ) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result1 = view.finalize() assert_roi_event_count(result1, STANDARD_ROI_EVENTS, view='current') assert_roi_event_count(result1, STANDARD_ROI_EVENTS, view='cumulative') # Now change ROI to cover different pixels (1, 2, 5, 6) - view.accumulate(roi_to_accumulate_data(wide_roi)) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + roi_to_accumulate_data(wide_roi), start_time=1000, end_time=2000 + ) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result2 = view.finalize() assert_roi_event_count(result2, WIDE_ROI_EVENTS, view='current') @@ -444,7 +608,9 @@ def test_accumulate_with_both_roi_and_detector_in_same_call( { **roi_to_accumulate_data(standard_roi), 'detector': sample_detector_events, - } + }, + start_time=1000, + end_time=2000, ) result = view.finalize() @@ -481,7 +647,9 @@ def test_accumulate_both_then_detector_only( { 'roi': RectangleROI.to_concatenated_data_array({0: roi}), 'detector': sample_detector_events, - } + }, + start_time=1000, + end_time=2000, ) result1 = view.finalize() @@ -490,7 +658,9 @@ def test_accumulate_both_then_detector_only( assert 'roi_rectangle' in result1 # Published on first update # Second call: detector data only (ROI should persist) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result2 = view.finalize() assert sc.sum(result2['roi_current_0']).value == expected_events_in_roi @@ -509,7 +679,9 @@ def test_accumulate_detector_then_both_roi_and_detector( view = DetectorView(params=params, detector_view=mock_rolling_view) # First: just detector data (no ROI) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result1 = view.finalize() assert 'roi_cumulative' not in result1 @@ -528,7 +700,9 @@ def test_accumulate_detector_then_both_roi_and_detector( { 'roi': RectangleROI.to_concatenated_data_array({0: roi}), 'detector': sample_detector_events, - } + }, + start_time=1000, + end_time=2000, ) result2 = view.finalize() @@ -562,7 +736,9 @@ def test_accumulate_roi_change_with_detector_in_same_call( { 'roi': RectangleROI.to_concatenated_data_array({0: roi1}), 'detector': sample_detector_events, - } + }, + start_time=1000, + end_time=2000, ) result1 = view.finalize() @@ -583,7 +759,9 @@ def test_accumulate_roi_change_with_detector_in_same_call( { 'roi': RectangleROI.to_concatenated_data_array({0: roi2}), 'detector': sample_detector_events, - } + }, + start_time=1000, + end_time=2000, ) result2 = view.finalize() @@ -598,20 +776,27 @@ class TestDetectorViewEdgeCases: """Edge cases and error conditions.""" def test_accumulate_empty_dict_does_nothing( - self, mock_rolling_view: RollingDetectorView + self, + mock_rolling_view: RollingDetectorView, + sample_detector_events: sc.DataArray, ) -> None: """Test that accumulate with empty dict returns early without error.""" params = DetectorViewParams() view = 
DetectorView(params=params, detector_view=mock_rolling_view) # Empty dict should not raise - consistent with roi-only behavior - view.accumulate({}) + view.accumulate({}, start_time=1000, end_time=2000) + + # Add detector data before finalize + view.accumulate( + {'detector': sample_detector_events}, start_time=2000, end_time=3000 + ) result = view.finalize() assert 'cumulative' in result assert 'current' in result - # No events accumulated - assert sc.sum(result['cumulative']).value == 0 + # Events from sample_detector_events + assert sc.sum(result['cumulative']).value == TOTAL_SAMPLE_EVENTS def test_accumulate_multiple_detector_keys_raises_error( self, @@ -631,13 +816,17 @@ def test_accumulate_multiple_detector_keys_raises_error( { 'detector1': sample_detector_events, 'detector2': sample_detector_events, - } + }, + start_time=1000, + end_time=2000, ) def test_multiple_roi_updates_without_detector_data( - self, mock_rolling_view: RollingDetectorView + self, + mock_rolling_view: RollingDetectorView, + sample_detector_events: sc.DataArray, ) -> None: - """Test multiple ROI updates without any detector data.""" + """Test multiple ROI updates followed by detector data.""" params = DetectorViewParams() view = DetectorView(params=params, detector_view=mock_rolling_view) @@ -645,21 +834,36 @@ def test_multiple_roi_updates_without_detector_data( roi1 = make_rectangle_roi( x_min=5.0, x_max=25.0, y_min=5.0, y_max=25.0, x_unit='mm', y_unit='mm' ) - view.accumulate({'roi': RectangleROI.to_concatenated_data_array({0: roi1})}) + view.accumulate( + {'roi': RectangleROI.to_concatenated_data_array({0: roi1})}, + start_time=1000, + end_time=2000, + ) roi2 = make_rectangle_roi( x_min=10.0, x_max=20.0, y_min=10.0, y_max=20.0, x_unit='mm', y_unit='mm' ) - view.accumulate({'roi': RectangleROI.to_concatenated_data_array({0: roi2})}) + view.accumulate( + {'roi': RectangleROI.to_concatenated_data_array({0: roi2})}, + start_time=1000, + end_time=2000, + ) + + # Now add detector data + view.accumulate( + {'detector': sample_detector_events}, start_time=2000, end_time=3000 + ) result = view.finalize() - # Should have ROI configured but with zero events + # Should have ROI configured with accumulated events assert 'roi_cumulative_0' in result assert 'roi_current_0' in result assert 'roi_rectangle' in result - assert sc.sum(result['roi_cumulative_0']).value == 0 - assert sc.sum(result['roi_current_0']).value == 0 + # roi2 only covers pixel 10 (single event at x=20mm, y=20mm) + expected_events = 1 + assert sc.sum(result['roi_cumulative_0']).value == expected_events + assert sc.sum(result['roi_current_0']).value == expected_events def test_detector_data_with_no_events_in_roi( self, @@ -673,8 +877,12 @@ def test_detector_data_with_no_events_in_roi( view = DetectorView(params=params, detector_view=mock_rolling_view) # Configure ROI that covers only pixel 15 at (30mm, 30mm) - view.accumulate(roi_to_accumulate_data(corner_roi)) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + roi_to_accumulate_data(corner_roi), start_time=1000, end_time=2000 + ) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result = view.finalize() @@ -699,13 +907,17 @@ def test_roi_published_when_updated_with_detector_data( { 'roi': RectangleROI.to_concatenated_data_array({0: roi1}), 'detector': sample_detector_events, - } + }, + start_time=1000, + end_time=2000, ) result1 = view.finalize() assert 'roi_rectangle' in result1 # Just detector (no ROI update) - 
view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result2 = view.finalize() assert 'roi_rectangle' not in result2 @@ -717,7 +929,9 @@ def test_roi_published_when_updated_with_detector_data( { 'roi': RectangleROI.to_concatenated_data_array({0: roi2}), 'detector': sample_detector_events, - } + }, + start_time=1000, + end_time=2000, ) result3 = view.finalize() assert 'roi_rectangle' in result3 @@ -753,9 +967,13 @@ def test_roi_deletion_is_published( x_min=25.0, x_max=35.0, y_min=25.0, y_max=35.0, x_unit='mm', y_unit='mm' ) view.accumulate( - {'roi': RectangleROI.to_concatenated_data_array({0: roi0, 1: roi1})} + {'roi': RectangleROI.to_concatenated_data_array({0: roi0, 1: roi1})}, + start_time=1000, + end_time=2000, + ) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 ) - view.accumulate({'detector': sample_detector_events}) result1 = view.finalize() # Both ROIs should be published @@ -767,15 +985,23 @@ def test_roi_deletion_is_published( assert 1 in published_rois # Accumulate more data to build up cumulative - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result_no_change = view.finalize() assert 'roi_rectangle' not in result_no_change # ROI 0 should have accumulated events from both rounds assert sc.sum(result_no_change['roi_cumulative_0']).value == 2 * 3 # 6 events # Now delete ROI 1, keeping only ROI 0 - view.accumulate({'roi': RectangleROI.to_concatenated_data_array({0: roi0})}) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'roi': RectangleROI.to_concatenated_data_array({0: roi0})}, + start_time=1000, + end_time=2000, + ) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result2 = view.finalize() # ROI config should be published to signal deletion @@ -820,13 +1046,19 @@ def test_roi_deletion_with_index_renumbering_clears_all( 'roi': RectangleROI.to_concatenated_data_array( {0: standard_roi, 1: corner_roi} ) - } + }, + start_time=1000, + end_time=2000, ) # Accumulate events twice to build up cumulative - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) view.finalize() - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result = view.finalize() # Verify both ROIs have accumulated @@ -836,9 +1068,13 @@ def test_roi_deletion_with_index_renumbering_clears_all( # Now simulate deleting ROI 0 in UI: ROI 1 gets renumbered to ROI 0 # From backend perspective: index 0 changes from standard_roi to corner_roi view.accumulate( - {'roi': RectangleROI.to_concatenated_data_array({0: corner_roi})} + {'roi': RectangleROI.to_concatenated_data_array({0: corner_roi})}, + start_time=1000, + end_time=2000, + ) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 ) - view.accumulate({'detector': sample_detector_events}) result2 = view.finalize() # CRITICAL: The renumbered ROI should be cleared @@ -866,22 +1102,36 @@ def test_unchanged_roi_resend_unnecessarily_resets_cumulative( roi0 = make_rectangle_roi( x_min=5.0, x_max=25.0, y_min=5.0, y_max=25.0, x_unit='mm', y_unit='mm' ) - view.accumulate({'roi': RectangleROI.to_concatenated_data_array({0: roi0})}) + view.accumulate( 
+ {'roi': RectangleROI.to_concatenated_data_array({0: roi0})}, + start_time=1000, + end_time=2000, + ) # Accumulate data twice to build up cumulative - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) view.finalize() expected_events = 3 # pixels 5, 6, 10 in ROI - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result2 = view.finalize() # Cumulative should have doubled assert sc.sum(result2['roi_cumulative_0']).value == 2 * expected_events # Now resend the SAME ROI configuration (no actual change) - view.accumulate({'roi': RectangleROI.to_concatenated_data_array({0: roi0})}) - view.accumulate({'detector': sample_detector_events}) + view.accumulate( + {'roi': RectangleROI.to_concatenated_data_array({0: roi0})}, + start_time=1000, + end_time=2000, + ) + view.accumulate( + {'detector': sample_detector_events}, start_time=1000, end_time=2000 + ) result3 = view.finalize() # BUG: Current implementation resets cumulative even though ROI didn't change! diff --git a/tests/handlers/monitor_data_handler_test.py b/tests/handlers/monitor_data_handler_test.py index 1f2f93110..cb6e8f631 100644 --- a/tests/handlers/monitor_data_handler_test.py +++ b/tests/handlers/monitor_data_handler_test.py @@ -69,7 +69,7 @@ def test_initialization(self, edges): processor = MonitorStreamProcessor(edges) # Test public behavior: processor should be able to accumulate data toa_data = np.array([10e6]) # Test with minimal data - processor.accumulate({"det1": toa_data}) + processor.accumulate({"det1": toa_data}, start_time=1000, end_time=2000) result = processor.finalize() assert "cumulative" in result assert "current" in result @@ -82,7 +82,7 @@ def test_accumulate_numpy_array(self, processor): ) # 10, 25, 45, 75, 95 ms in ns data = {"detector1": toa_data} - processor.accumulate(data) + processor.accumulate(data, start_time=1000, end_time=2000) # Test by finalizing and checking the result result = processor.finalize() @@ -100,7 +100,7 @@ def test_accumulate_scipp_dataarray(self, processor): hist_data = sc.DataArray(data=counts, coords={"time": tof_coords}) data = {"detector1": hist_data} - processor.accumulate(data) + processor.accumulate(data, start_time=1000, end_time=2000) # Test by finalizing and checking the result result = processor.finalize() @@ -112,14 +112,14 @@ def test_accumulate_multiple_calls(self, processor): """Test multiple accumulate calls add data correctly.""" # First accumulation toa_data1 = np.array([10e6, 25e6]) # 10, 25 ms in ns - processor.accumulate({"det1": toa_data1}) + processor.accumulate({"det1": toa_data1}, start_time=1000, end_time=2000) first_result = processor.finalize() first_sum = first_result["current"].sum().value # Second accumulation - need new processor since finalize clears current processor2 = MonitorStreamProcessor(processor._edges) toa_data2 = np.array([35e6, 45e6]) # 35, 45 ms in ns - processor2.accumulate({"det1": toa_data2}) + processor2.accumulate({"det1": toa_data2}, start_time=1000, end_time=2000) second_result = processor2.finalize() second_sum = second_result["current"].sum().value @@ -132,36 +132,38 @@ def test_accumulate_wrong_number_of_items(self, processor): data = {"det1": np.array([10e6]), "det2": np.array([20e6])} with pytest.raises(ValueError, match="exactly one data item"): - processor.accumulate(data) + processor.accumulate(data, start_time=1000, end_time=2000) def 
test_finalize_first_time(self, processor): """Test finalize on first call.""" toa_data = np.array([10e6, 25e6, 45e6]) - processor.accumulate({"det1": toa_data}) + processor.accumulate({"det1": toa_data}, start_time=1000, end_time=2000) result = processor.finalize() assert "cumulative" in result assert "current" in result - assert_identical(result["cumulative"], result["current"]) - - # After finalize, we can finalize again without new data, since empty batches - # will be committed. - empty_result = processor.finalize() - assert empty_result["current"].sum().value == 0 - assert ( - empty_result["cumulative"].sum().value == result["cumulative"].sum().value - ) + # Check cumulative data (excluding time coord which current has) + assert_identical(result["cumulative"], result["current"].drop_coords("time")) + + # Verify time coordinate is present + assert "time" in result["current"].coords + assert result["current"].coords["time"].value == 1000 + assert result["current"].coords["time"].unit == "ns" def test_finalize_subsequent_calls(self, processor): """Test finalize accumulates over multiple calls.""" # First round - processor.accumulate({"det1": np.array([10e6, 25e6])}) + processor.accumulate( + {"det1": np.array([10e6, 25e6])}, start_time=1000, end_time=2000 + ) first_result = processor.finalize() first_cumulative_sum = first_result["cumulative"].sum().value # Second round - processor.accumulate({"det1": np.array([35e6, 45e6])}) + processor.accumulate( + {"det1": np.array([35e6, 45e6])}, start_time=1000, end_time=2000 + ) second_result = processor.finalize() second_cumulative_sum = second_result["cumulative"].sum().value @@ -174,9 +176,42 @@ def test_finalize_without_data(self, processor): with pytest.raises(ValueError, match="No data has been added"): processor.finalize() + def test_finalize_without_accumulate(self, processor): + """Test finalize raises error without accumulate since last finalize.""" + processor.accumulate( + {"det1": np.array([10e6, 25e6])}, start_time=1000, end_time=2000 + ) + processor.finalize() + + # After finalize, calling finalize again without accumulate should fail + with pytest.raises( + RuntimeError, + match="finalize called without any data accumulated via accumulate", + ): + processor.finalize() + + def test_time_coordinate_tracks_first_accumulate(self, processor): + """Test time coordinate uses start_time of the first accumulate call.""" + # First accumulate with start_time=1000 + processor.accumulate({"det1": np.array([10e6])}, start_time=1000, end_time=2000) + # Second accumulate with start_time=3000 (should be ignored) + processor.accumulate({"det1": np.array([20e6])}, start_time=3000, end_time=4000) + + result = processor.finalize() + + # Time coordinate should use the first start_time + assert result["current"].coords["time"].value == 1000 + + # After finalize, the next accumulate should set a new start_time + processor.accumulate({"det1": np.array([30e6])}, start_time=5000, end_time=6000) + result2 = processor.finalize() + assert result2["current"].coords["time"].value == 5000 + def test_clear(self, processor): """Test clear method resets processor state.""" - processor.accumulate({"det1": np.array([10e6, 25e6])}) + processor.accumulate( + {"det1": np.array([10e6, 25e6])}, start_time=1000, end_time=2000 + ) processor.finalize() processor.clear() @@ -192,7 +227,7 @@ def test_coordinate_unit_conversion(self, processor): counts = sc.ones(dims=["time"], shape=[9], unit="counts") hist_data = sc.DataArray(data=counts, coords={"time": tof_coords}) - 
processor.accumulate({"det1": hist_data}) + processor.accumulate({"det1": hist_data}, start_time=1000, end_time=2000) result = processor.finalize() assert "current" in result diff --git a/tests/handlers/stream_processor_workflow_test.py b/tests/handlers/stream_processor_workflow_test.py index 46723aace..7cd491501 100644 --- a/tests/handlers/stream_processor_workflow_test.py +++ b/tests/handlers/stream_processor_workflow_test.py @@ -80,11 +80,11 @@ def test_accumulate_and_finalize(self, base_workflow_with_context): ) # Set context data - workflow.accumulate({'context': Context(5)}) + workflow.accumulate({'context': Context(5)}, start_time=1000, end_time=2000) # Accumulate dynamic data - workflow.accumulate({'streamed': Streamed(10)}) - workflow.accumulate({'streamed': Streamed(20)}) + workflow.accumulate({'streamed': Streamed(10)}, start_time=1000, end_time=2000) + workflow.accumulate({'streamed': Streamed(20)}, start_time=1000, end_time=2000) # Finalize and check result result = workflow.finalize() @@ -102,15 +102,15 @@ def test_clear_workflow(self, base_workflow_with_context): ) # Accumulate some data - workflow.accumulate({'context': Context(5)}) - workflow.accumulate({'streamed': Streamed(10)}) + workflow.accumulate({'context': Context(5)}, start_time=1000, end_time=2000) + workflow.accumulate({'streamed': Streamed(10)}, start_time=1000, end_time=2000) # Clear and start fresh workflow.clear() # Set new context and data - workflow.accumulate({'context': Context(2)}) - workflow.accumulate({'streamed': Streamed(15)}) + workflow.accumulate({'context': Context(2)}, start_time=1000, end_time=2000) + workflow.accumulate({'streamed': Streamed(15)}, start_time=1000, end_time=2000) result = workflow.finalize() # Expected: context (2) * static (2) = 4, streamed: 15, final: 15 + 4 = 19 @@ -127,13 +127,13 @@ def test_partial_data_accumulation(self, base_workflow_with_context): ) # Accumulate with only context - workflow.accumulate({'context': Context(3)}) + workflow.accumulate({'context': Context(3)}, start_time=1000, end_time=2000) # Accumulate with only streamed data - workflow.accumulate({'streamed': Streamed(7)}) + workflow.accumulate({'streamed': Streamed(7)}, start_time=1000, end_time=2000) # Accumulate with unknown keys (should be ignored) - workflow.accumulate({'unknown': 42}) + workflow.accumulate({'unknown': 42}, start_time=1000, end_time=2000) result = workflow.finalize() # Expected: context (3) * static (2) = 6, streamed: 7, final: 7 + 6 = 13 @@ -149,8 +149,8 @@ def test_target_keys_with_simplified_names(self, base_workflow_with_context): accumulators=(ProcessedStreamed,), ) - workflow.accumulate({'context': Context(4)}) - workflow.accumulate({'streamed': Streamed(5)}) + workflow.accumulate({'context': Context(4)}, start_time=1000, end_time=2000) + workflow.accumulate({'streamed': Streamed(5)}, start_time=1000, end_time=2000) result = workflow.finalize() # Expected: context (4) * static (2) = 8, streamed: 5, final: 5 + 8 = 13 @@ -167,7 +167,7 @@ def test_no_context_keys(self, base_workflow_no_context): ) # Only accumulate dynamic data - workflow.accumulate({'streamed': Streamed(25)}) + workflow.accumulate({'streamed': Streamed(25)}, start_time=1000, end_time=2000) result = workflow.finalize() # Expected: streamed (25) + static (2) = 27 diff --git a/tests/integration/helpers_test.py b/tests/integration/helpers_test.py index 670a25f7c..aa4ffc59d 100644 --- a/tests/integration/helpers_test.py +++ b/tests/integration/helpers_test.py @@ -11,6 +11,7 @@ from typing import Any import pytest 
+import scipp as sc from ess.livedata.config.workflow_spec import JobId, JobNumber, ResultKey, WorkflowId from ess.livedata.core.job_manager import JobState, JobStatus @@ -24,6 +25,18 @@ ) +def make_test_result(value: str) -> sc.DataArray: + """Create a scipp DataArray representing a result value. + + The returned DataArray has a time dimension so that LatestValueExtractor + will extract the scalar value from it. + """ + return sc.DataArray( + sc.array(dims=['time'], values=[value], unit='dimensionless'), + coords={'time': sc.array(dims=['time'], values=[0.0], unit='s')}, + ) + + def make_workflow_id(name: str, version: int = 1) -> WorkflowId: """Helper to create WorkflowId for tests.""" return WorkflowId( @@ -121,7 +134,7 @@ def simulate_data_arrival(): job_id=job_id, output_name='output1', ) - ] = 'data' + ] = make_test_result('data') backend.on_update_callbacks.append(simulate_data_arrival) @@ -131,7 +144,7 @@ def simulate_data_arrival(): # Should return dict mapping JobId to job_data assert job_id in result assert 'source1' in result[job_id] - assert result[job_id]['source1']['output1'] == 'data' + assert result[job_id]['source1']['output1'].value == 'data' def test_succeeds_when_data_arrives_for_all_jobs( self, job_service: JobService, data_service: DataService[ResultKey, Any] @@ -155,7 +168,7 @@ def simulate_data_arrival(): job_id=job_id, output_name='output1', ) - ] = 'data' + ] = make_test_result('data') backend.on_update_callbacks.append(simulate_data_arrival) @@ -182,7 +195,7 @@ def test_must_not_succeed_if_only_some_jobs_have_data( job_id=job_ids[0], output_name='output1', ) - ] = 'data' + ] = make_test_result('data') with pytest.raises(WaitTimeout): wait_for_job_data(backend, job_ids, timeout=0.3, poll_interval=0.05) @@ -204,7 +217,7 @@ def test_must_not_succeed_for_wrong_job( job_id=job_id_other, output_name='output1', ) - ] = 'data' + ] = make_test_result('data') with pytest.raises(WaitTimeout): wait_for_job_data( diff --git a/tests/services/timeseries_test.py b/tests/services/timeseries_test.py index 19c03531e..5a2b643d0 100644 --- a/tests/services/timeseries_test.py +++ b/tests/services/timeseries_test.py @@ -67,7 +67,7 @@ def test_updates_are_published_immediately( app.publish_log_message(source_name=source_name, time=1, value=1.5) service.step() - # Each workflow call returns one result, cumulative + # Each workflow call returns only new data since last finalize (delta) assert len(sink.messages) == 1 assert sink.messages[-1].value.values.sum() == 1.5 # No data -> no data published @@ -78,4 +78,5 @@ def test_updates_are_published_immediately( app.publish_log_message(source_name=source_name, time=1.0001, value=0.5) service.step() assert len(sink.messages) == 2 - assert sink.messages[-1].value.values.sum() == 2.0 + # Expect only the new data point (delta), not cumulative + assert sink.messages[-1].value.values.sum() == 0.5
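The detector, monitor, and timeseries hunks above all rely on the same bookkeeping: `accumulate` now takes `start_time`/`end_time`, the first `start_time` since the last `finalize` is attached as a `time` coordinate (ns) on the `current` output only, `current` is reset after each finalize so every publish carries only the delta, and finalizing without accumulated data raises. Below is a minimal sketch of that contract using a toy counter rather than the project's processors; the class name `TimeTaggingAccumulator` is hypothetical.

from __future__ import annotations

import scipp as sc


class TimeTaggingAccumulator:
    """Toy counter mirroring the start_time/finalize contract asserted above."""

    def __init__(self) -> None:
        self._start_time: int | None = None
        self._current = sc.scalar(0, unit='counts')
        self._cumulative = sc.scalar(0, unit='counts')

    def accumulate(self, counts: int, *, start_time: int, end_time: int) -> None:
        if self._start_time is None:
            # Only the first accumulate since the last finalize sets the time tag.
            self._start_time = start_time
        self._current = self._current + sc.scalar(counts, unit='counts')

    def finalize(self) -> dict[str, sc.DataArray]:
        if self._start_time is None:
            raise RuntimeError("finalize called without any data accumulated")
        self._cumulative = self._cumulative + self._current
        result = {
            # 'current' carries the delta since the last finalize plus a time coord.
            'current': sc.DataArray(
                self._current,
                coords={'time': sc.scalar(self._start_time, unit='ns')},
            ),
            # 'cumulative' keeps growing and carries no time coord.
            'cumulative': sc.DataArray(self._cumulative),
        }
        self._current = sc.scalar(0, unit='counts')
        self._start_time = None
        return result

For example, `accumulate(3, start_time=1000, end_time=2000)` followed by `finalize()` yields a `current` tagged with `time == 1000` ns, and a second `finalize()` without new data raises, matching the behaviour the tests above assert for the real processors.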