Skip to content

Commit 12aa4e8

Browse files
SimonHeybrockclaude
andcommitted
Refactor: Move create_extractors_from_params to plot_params and inline redundant wrapper
- Move create_extractors_from_params() from extractors.py to plot_params.py where parameter types are defined. This eliminates awkward local import of WindowMode and improves cohesion. - Move TestCreateExtractorsFromParams tests to new plot_params_test.py file - Inline PlottingController._create_extractors() as it's a thin wrapper with single call site; the 2-line function call is self-explanatory - Remove unused UpdateExtractor import from plotting_controller.py - Move UpdateExtractor from TYPE_CHECKING to regular import in temporal_buffer_manager.py as it's used at runtime Original request: create_extractors_from_params might belong to a different file - see awkward import handling. Can you find a better place and move it as well as its tests? Follow-up: The PlottingController._create_extractors methods feels redundant, just inline? 🤖 Generated with Claude Code Co-Authored-By: Claude <[email protected]>
1 parent 3aa1329 commit 12aa4e8

File tree

7 files changed

+216
-225
lines changed

7 files changed

+216
-225
lines changed

src/ess/livedata/dashboard/extractors.py

Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,12 @@
44

55
from abc import ABC, abstractmethod
66
from collections.abc import Callable
7-
from typing import TYPE_CHECKING, Any
7+
from typing import Any
88

99
import scipp as sc
1010

1111
from .plot_params import WindowAggregation
1212

13-
if TYPE_CHECKING:
14-
from ess.livedata.config.workflow_spec import ResultKey
15-
16-
from .plot_params import WindowParams
17-
from .plotting import PlotterSpec
18-
1913

2014
class UpdateExtractor(ABC):
2115
"""Extracts a specific view of buffered data."""
@@ -169,52 +163,3 @@ def extract(self, data: sc.DataArray) -> Any:
169163
raise ValueError(f"Unknown aggregation method: {self._aggregation}")
170164

171165
return self._aggregator(windowed_data, self._concat_dim)
172-
173-
174-
def create_extractors_from_params(
175-
keys: list[ResultKey],
176-
window: WindowParams | None,
177-
spec: PlotterSpec | None = None,
178-
) -> dict[ResultKey, UpdateExtractor]:
179-
"""
180-
Create extractors based on plotter spec and window configuration.
181-
182-
Parameters
183-
----------
184-
keys:
185-
Result keys to create extractors for.
186-
window:
187-
Window parameters for extraction mode and aggregation.
188-
If None, falls back to LatestValueExtractor.
189-
spec:
190-
Optional plotter specification. If provided and contains a required
191-
extractor, that extractor type is used.
192-
193-
Returns
194-
-------
195-
:
196-
Dictionary mapping result keys to extractor instances.
197-
"""
198-
# Avoid circular import by importing here
199-
from .plot_params import WindowMode
200-
201-
if spec is not None and spec.data_requirements.required_extractor is not None:
202-
# Plotter requires specific extractor (e.g., TimeSeriesPlotter)
203-
extractor_type = spec.data_requirements.required_extractor
204-
return {key: extractor_type() for key in keys}
205-
206-
# No fixed requirement - check if window params provided
207-
if window is not None:
208-
if window.mode == WindowMode.latest:
209-
return {key: LatestValueExtractor() for key in keys}
210-
else: # mode == WindowMode.window
211-
return {
212-
key: WindowAggregatingExtractor(
213-
window_duration_seconds=window.window_duration_seconds,
214-
aggregation=window.aggregation,
215-
)
216-
for key in keys
217-
}
218-
219-
# Fallback to latest value extractor
220-
return {key: LatestValueExtractor() for key in keys}

src/ess/livedata/dashboard/plot_params.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,22 @@
22
# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
33
"""Param models for configuring plotters via widgets."""
44

5+
from __future__ import annotations
6+
57
import enum
68
from enum import StrEnum
9+
from typing import TYPE_CHECKING
710

811
import pydantic
912

1013
from ..config.roi_names import get_roi_mapper
1114

15+
if TYPE_CHECKING:
16+
from ess.livedata.config.workflow_spec import ResultKey
17+
18+
from .extractors import UpdateExtractor
19+
from .plotting import PlotterSpec
20+
1221

1322
def _get_default_max_roi_count() -> int:
1423
"""Get the default maximum ROI count from the mapper configuration."""
@@ -226,3 +235,55 @@ class PlotParamsROIDetector(PlotParams2d):
226235
default_factory=ROIOptions,
227236
description="Options for ROI selection and display.",
228237
)
238+
239+
240+
def create_extractors_from_params(
241+
keys: list[ResultKey],
242+
window: WindowParams | None,
243+
spec: PlotterSpec | None = None,
244+
) -> dict[ResultKey, UpdateExtractor]:
245+
"""
246+
Create extractors based on plotter spec and window configuration.
247+
248+
Parameters
249+
----------
250+
keys:
251+
Result keys to create extractors for.
252+
window:
253+
Window parameters for extraction mode and aggregation.
254+
If None, falls back to LatestValueExtractor.
255+
spec:
256+
Optional plotter specification. If provided and contains a required
257+
extractor, that extractor type is used.
258+
259+
Returns
260+
-------
261+
:
262+
Dictionary mapping result keys to extractor instances.
263+
"""
264+
# Import here to avoid circular imports at module level
265+
from .extractors import (
266+
LatestValueExtractor,
267+
WindowAggregatingExtractor,
268+
)
269+
270+
if spec is not None and spec.data_requirements.required_extractor is not None:
271+
# Plotter requires specific extractor (e.g., TimeSeriesPlotter)
272+
extractor_type = spec.data_requirements.required_extractor
273+
return {key: extractor_type() for key in keys}
274+
275+
# No fixed requirement - check if window params provided
276+
if window is not None:
277+
if window.mode == WindowMode.latest:
278+
return {key: LatestValueExtractor() for key in keys}
279+
else: # mode == WindowMode.window
280+
return {
281+
key: WindowAggregatingExtractor(
282+
window_duration_seconds=window.window_duration_seconds,
283+
aggregation=window.aggregation,
284+
)
285+
for key in keys
286+
}
287+
288+
# Fallback to latest value extractor
289+
return {key: LatestValueExtractor() for key in keys}

src/ess/livedata/dashboard/plotting_controller.py

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@
1818

1919
from .config_store import ConfigStore
2020
from .configuration_adapter import ConfigurationState
21-
from .extractors import (
22-
UpdateExtractor,
23-
create_extractors_from_params,
24-
)
2521
from .job_service import JobService
22+
from .plot_params import create_extractors_from_params
2623
from .plotting import PlotterSpec, plotter_registry
2724
from .roi_detector_plot_factory import ROIDetectorPlotFactory
2825
from .roi_publisher import ROIPublisher
@@ -232,32 +229,6 @@ def _save_plotting_config(
232229
)
233230
self._config_store[plotter_id] = config_state.model_dump()
234231

235-
def _create_extractors(
236-
self,
237-
keys: list[ResultKey],
238-
spec: PlotterSpec,
239-
params: pydantic.BaseModel,
240-
) -> dict[ResultKey, UpdateExtractor]:
241-
"""
242-
Create extractors based on plotter requirements and parameters.
243-
244-
Parameters
245-
----------
246-
keys:
247-
Result keys to create extractors for.
248-
spec:
249-
Plotter specification containing data requirements.
250-
params:
251-
Plotter parameters potentially containing window configuration.
252-
253-
Returns
254-
-------
255-
:
256-
Dictionary mapping result keys to extractor instances.
257-
"""
258-
window = getattr(params, 'window', None)
259-
return create_extractors_from_params(keys, window, spec)
260-
261232
def create_plot(
262233
self,
263234
job_number: JobNumber,
@@ -325,7 +296,8 @@ def create_plot(
325296

326297
# Create extractors based on plotter requirements and params
327298
spec = plotter_registry.get_spec(plot_name)
328-
extractors = self._create_extractors(keys, spec, params)
299+
window = getattr(params, 'window', None)
300+
extractors = create_extractors_from_params(keys, window, spec)
329301

330302
pipe = self._stream_manager.make_merging_stream(extractors)
331303
plotter = plotter_registry.create_plotter(plot_name, params=params)

src/ess/livedata/dashboard/roi_detector_plot_factory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@
1919
DataSubscriber,
2020
MergingStreamAssembler,
2121
)
22-
from .extractors import LatestValueExtractor, create_extractors_from_params
23-
from .plot_params import LayoutParams, PlotParamsROIDetector
22+
from .extractors import LatestValueExtractor
23+
from .plot_params import (
24+
LayoutParams,
25+
PlotParamsROIDetector,
26+
create_extractors_from_params,
27+
)
2428
from .plots import ImagePlotter, LinePlotter, PlotAspect, PlotAspectType
2529
from .roi_publisher import ROIPublisher
2630
from .stream_manager import StreamManager

src/ess/livedata/dashboard/temporal_buffer_manager.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,13 @@
77
import logging
88
from collections.abc import Hashable, Iterator, Mapping
99
from dataclasses import dataclass, field
10-
from typing import TYPE_CHECKING, Generic, TypeVar
10+
from typing import Generic, TypeVar
1111

1212
import scipp as sc
1313

14-
from .extractors import LatestValueExtractor
14+
from .extractors import LatestValueExtractor, UpdateExtractor
1515
from .temporal_buffers import BufferProtocol, SingleValueBuffer, TemporalBuffer
1616

17-
if TYPE_CHECKING:
18-
from .extractors import UpdateExtractor
19-
2017
logger = logging.getLogger(__name__)
2118

2219
K = TypeVar('K', bound=Hashable)

tests/dashboard/extractors_test.py

Lines changed: 1 addition & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
33
from __future__ import annotations
44

5-
from unittest.mock import Mock
6-
75
import pytest
86
import scipp as sc
97

@@ -12,13 +10,8 @@
1210
LatestValueExtractor,
1311
UpdateExtractor,
1412
WindowAggregatingExtractor,
15-
create_extractors_from_params,
16-
)
17-
from ess.livedata.dashboard.plot_params import (
18-
WindowAggregation,
19-
WindowMode,
20-
WindowParams,
2113
)
14+
from ess.livedata.dashboard.plot_params import WindowAggregation
2215

2316

2417
class TestLatestValueExtractor:
@@ -505,129 +498,6 @@ def test_consistent_frame_count_with_perfect_timing(self):
505498
assert sc.allclose(result.data, expected_sum)
506499

507500

508-
class TestCreateExtractorsFromParams:
509-
"""Tests for create_extractors_from_params factory function."""
510-
511-
def test_fallback_to_latest_value_when_no_params(self):
512-
"""Test fallback to LatestValueExtractor when no window params provided."""
513-
keys = ['key1', 'key2']
514-
515-
extractors = create_extractors_from_params(keys=keys, window=None, spec=None)
516-
517-
assert len(extractors) == 2
518-
assert all(isinstance(ext, LatestValueExtractor) for ext in extractors.values())
519-
assert set(extractors.keys()) == {'key1', 'key2'}
520-
521-
def test_create_latest_value_extractors_with_window_mode_latest(self):
522-
"""Test creation of LatestValueExtractor when window mode is 'latest'."""
523-
keys = ['key1']
524-
window = WindowParams(mode=WindowMode.latest)
525-
526-
extractors = create_extractors_from_params(keys=keys, window=window, spec=None)
527-
528-
assert len(extractors) == 1
529-
assert isinstance(extractors['key1'], LatestValueExtractor)
530-
531-
def test_create_window_aggregating_extractors_with_window_mode_window(self):
532-
"""Test creation of WindowAggregatingExtractor when window mode is 'window'."""
533-
keys = ['key1', 'key2']
534-
window = WindowParams(
535-
mode=WindowMode.window,
536-
window_duration_seconds=5.0,
537-
aggregation=WindowAggregation.nansum,
538-
)
539-
540-
extractors = create_extractors_from_params(keys=keys, window=window, spec=None)
541-
542-
assert len(extractors) == 2
543-
assert all(
544-
isinstance(ext, WindowAggregatingExtractor) for ext in extractors.values()
545-
)
546-
547-
# Verify behavior through public interface
548-
extractor = extractors['key1']
549-
assert extractor.get_required_timespan() == 5.0
550-
551-
def test_spec_required_extractor_overrides_window_params(self):
552-
"""Test that plotter spec's required extractor overrides window params."""
553-
keys = ['key1', 'key2']
554-
window = WindowParams(mode=WindowMode.latest)
555-
556-
# Create mock spec with required extractor
557-
spec = Mock()
558-
spec.data_requirements.required_extractor = FullHistoryExtractor
559-
560-
extractors = create_extractors_from_params(keys=keys, window=window, spec=spec)
561-
562-
# Should use FullHistoryExtractor despite window params
563-
assert len(extractors) == 2
564-
assert all(isinstance(ext, FullHistoryExtractor) for ext in extractors.values())
565-
566-
def test_spec_with_no_required_extractor_uses_window_params(self):
567-
"""Test that window params are used when spec has no required extractor."""
568-
keys = ['key1']
569-
window = WindowParams(mode=WindowMode.window, window_duration_seconds=3.0)
570-
571-
# Create mock spec without required extractor
572-
spec = Mock()
573-
spec.data_requirements.required_extractor = None
574-
575-
extractors = create_extractors_from_params(keys=keys, window=window, spec=spec)
576-
577-
assert isinstance(extractors['key1'], WindowAggregatingExtractor)
578-
assert extractors['key1'].get_required_timespan() == 3.0
579-
580-
def test_creates_extractors_for_all_keys(self):
581-
"""Test that extractors are created for all provided keys."""
582-
keys = ['result1', 'result2', 'result3']
583-
window = WindowParams(mode=WindowMode.latest)
584-
585-
extractors = create_extractors_from_params(keys=keys, window=window, spec=None)
586-
587-
assert len(extractors) == 3
588-
assert set(extractors.keys()) == {'result1', 'result2', 'result3'}
589-
assert all(isinstance(ext, LatestValueExtractor) for ext in extractors.values())
590-
591-
def test_empty_keys_returns_empty_dict(self):
592-
"""Test that empty keys list returns empty extractors dict."""
593-
keys = []
594-
window = WindowParams(mode=WindowMode.latest)
595-
596-
extractors = create_extractors_from_params(keys=keys, window=window, spec=None)
597-
598-
assert extractors == {}
599-
600-
def test_window_aggregation_parameters_passed_correctly(self):
601-
"""Test that window aggregation parameters result in correct behavior."""
602-
keys = ['key1']
603-
window = WindowParams(
604-
mode=WindowMode.window,
605-
window_duration_seconds=10.5,
606-
aggregation=WindowAggregation.mean,
607-
)
608-
609-
extractors = create_extractors_from_params(keys=keys, window=window, spec=None)
610-
611-
extractor = extractors['key1']
612-
assert isinstance(extractor, WindowAggregatingExtractor)
613-
# Verify timespan through public interface
614-
assert extractor.get_required_timespan() == 10.5
615-
616-
# Verify aggregation behavior by extracting data
617-
data = sc.DataArray(
618-
sc.array(dims=['time', 'x'], values=[[2, 4], [4, 6]], unit='m'),
619-
coords={
620-
'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s'),
621-
'x': sc.arange('x', 2, unit='m'),
622-
},
623-
)
624-
result = extractor.extract(data)
625-
# Mean of [2, 4] and [4, 6] = [3, 5], verifying mean aggregation was used
626-
assert sc.allclose(
627-
result.data, sc.array(dims=['x'], values=[3.0, 5.0], unit='m')
628-
)
629-
630-
631501
class TestUpdateExtractorInterface:
632502
"""Tests for UpdateExtractor abstract interface."""
633503

0 commit comments

Comments
 (0)