Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/scippneutron/chopper/disk_chopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,9 @@ def from_nexus(
)
return DiskChopper(
axle_position=chopper['position'],
frequency=_get_1d_variable(chopper, 'rotation_speed'),
beam_position=_get_1d_variable(chopper, 'beam_position'),
phase=_get_1d_variable(chopper, 'phase'),
frequency=_get_0d_variable(chopper, 'rotation_speed'),
beam_position=_get_0d_variable(chopper, 'beam_position'),
phase=_get_0d_variable(chopper, 'phase'),
slit_height=chopper.get('slit_height'),
radius=chopper.get('radius'),
**_get_edges_from_nexus(chopper),
Expand Down Expand Up @@ -557,6 +557,10 @@ def _source_phase_factor(self, pulse_frequency: sc.Variable) -> int:
# of the slits, so use `max` here:
return round(max(quot.value, 1))

def as_dict(self) -> dict[str, Any]:
"""Return the DiskChopper fields as a dictionary."""
return dataclasses.asdict(self)


def _field_eq(a: Any, b: Any) -> bool:
if isinstance(a, sc.Variable | sc.DataArray):
Expand Down Expand Up @@ -627,13 +631,7 @@ def _get_edges_from_nexus(
}


def _len_or_1(x: sc.Variable) -> int:
if x.ndim == 0:
return 1
return len(x)


def _get_1d_variable(
def _get_0d_variable(
dg: Mapping[str, sc.Variable | sc.DataArray], name: str
) -> sc.Variable:
if (val := dg.get(name)) is None:
Expand All @@ -642,9 +640,11 @@ def _get_1d_variable(
msg = (
"Chopper field '{name}' must be a scalar variable, {got}. "
"See the chopper user-guide for more information: "
"https://scipp.github.io/scippneutron/user-guide/chopper/pre-processing.html"
"https://scipp.github.io/scippneutron/user-guide/chopper/processing-nexus-choppers.html"
)

if isinstance(val, sc.DataArray):
val = val.data
if not isinstance(val, sc.Variable):
raise TypeError(msg.format(name=name, got=f'got a {type(val)}'))
if val.ndim != 0:
Expand Down
Loading