Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/ess/reduce/time_of_flight/eto_to_tof.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ def __call__(


def _time_of_flight_data_histogram(
da: sc.DataArray, lookup: sc.DataArray, ltotal: sc.Variable
da: sc.DataArray, lookup: sc.DataGroup, ltotal: sc.Variable
) -> sc.DataArray:
# In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files,
# it may be called 'tof' or 'frame_time'.
key = next(iter(set(da.coords.keys()) & {"time_of_flight", "tof", "frame_time"}))
raw_eto = da.coords[key].to(dtype=float, copy=False)
eto_unit = raw_eto.unit
pulse_period = lookup.coords["pulse_period"].to(unit=eto_unit)
pulse_period = lookup["pulse_period"].to(unit=eto_unit)

# In histogram mode, because there is a wrap around at the end of the pulse, we
# need to insert a bin edge at that exact location to avoid having the last bin
Expand All @@ -117,7 +117,9 @@ def _time_of_flight_data_histogram(
etos = rebinned.coords[key]

# Create linear interpolator
interp = TofInterpolator(lookup, distance_unit=ltotal.unit, time_unit=eto_unit)
interp = TofInterpolator(
lookup["data"], distance_unit=ltotal.unit, time_unit=eto_unit
)

# Compute time-of-flight of the bin edges using the interpolator
tofs = interp(
Expand Down Expand Up @@ -199,7 +201,7 @@ def _guess_pulse_stride_offset(

def _prepare_tof_interpolation_inputs(
da: sc.DataArray,
lookup: sc.DataArray,
lookup: sc.DataGroup,
ltotal: sc.Variable,
pulse_stride_offset: int | None,
) -> dict:
Expand Down Expand Up @@ -227,15 +229,17 @@ def _prepare_tof_interpolation_inputs(
eto_unit = elem_unit(etos)

# Create linear interpolator
interp = TofInterpolator(lookup, distance_unit=ltotal.unit, time_unit=eto_unit)
interp = TofInterpolator(
lookup["data"], distance_unit=ltotal.unit, time_unit=eto_unit
)

# Operate on events (broadcast distances to all events)
ltotal = sc.bins_like(etos, ltotal).bins.constituents["data"]
etos = etos.bins.constituents["data"]

pulse_index = None
pulse_period = lookup.coords["pulse_period"].to(unit=eto_unit)
pulse_stride = lookup.coords["pulse_stride"].value
pulse_period = lookup["pulse_period"].to(unit=eto_unit)
pulse_stride = lookup["pulse_stride"].value

if pulse_stride > 1:
# Compute a pulse index for every event: it is the index of the pulse within a
Expand Down Expand Up @@ -291,7 +295,7 @@ def _prepare_tof_interpolation_inputs(

def _time_of_flight_data_events(
da: sc.DataArray,
lookup: sc.DataArray,
lookup: sc.DataGroup,
ltotal: sc.Variable,
pulse_stride_offset: int | None,
) -> sc.DataArray:
Expand Down Expand Up @@ -375,7 +379,7 @@ def monitor_ltotal_from_straight_line_approximation(

def _compute_tof_data(
da: sc.DataArray,
lookup: sc.DataArray,
lookup: sc.DataGroup,
ltotal: sc.Variable,
pulse_stride_offset: int,
) -> sc.DataArray:
Expand Down Expand Up @@ -503,7 +507,7 @@ def detector_time_of_arrival_data(
result = detector_data.bins.assign_coords(
toa=sc.bins(**parts, validate_indices=False)
)
return result
return ToaDetector[RunType](result)


def providers() -> tuple[Callable]:
Expand Down
19 changes: 6 additions & 13 deletions src/ess/reduce/time_of_flight/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,13 @@ def __init__(
self.source = source(pulses=self.npulses)

# Convert the choppers to tof.Chopper
self.choppers = [
tof_pkg.Chopper(
frequency=abs(ch.frequency),
direction=tof_pkg.AntiClockwise
if (ch.frequency.value > 0.0)
else tof_pkg.Clockwise,
open=ch.slit_begin,
close=ch.slit_end,
phase=ch.phase if ch.frequency.value > 0.0 else -ch.phase,
distance=sc.norm(ch.axle_position - source_position),
name=name,
self.choppers = []
for name, ch in choppers.items():
chop = tof_pkg.Chopper.from_diskchopper(ch, name=name)
chop.distance = sc.norm(
ch.axle_position - source_position.to(unit=ch.axle_position.unit)
)
for name, ch in choppers.items()
]
self.choppers.append(chop)

# Add detectors
self.monitors = [
Expand Down
46 changes: 30 additions & 16 deletions src/ess/reduce/time_of_flight/lut.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ class SimulationResults:
For a ``tof`` simulation, this is just the position of the detector where the
events are recorded. For a ``McStas`` simulation, this is the distance between
the source and the event monitor.
choppers:
The parameters of the choppers used in the simulation (if any).
"""

time_of_arrival: sc.Variable
speed: sc.Variable
wavelength: sc.Variable
weight: sc.Variable
distance: sc.Variable
choppers: DiskChoppers[AnyRun] | None = None


NumberOfSimulatedNeutrons = NewType("NumberOfSimulatedNeutrons", int)
Expand Down Expand Up @@ -376,7 +379,25 @@ def make_tof_lookup_table(
# In-place masking for better performance
_mask_large_uncertainty(table, error_threshold)

return TimeOfFlightLookupTable(table)
out = sc.DataGroup(
{
"data": table,
"pulse_period": pulse_period,
"pulse_stride": sc.scalar(pulse_stride, unit=None),
"distance_resolution": table.coords["distance"][1]
- table.coords["distance"][0],
"time_resolution": table.coords["event_time_offset"][1]
- table.coords["event_time_offset"][0],
"error_threshold": sc.scalar(error_threshold),
}
)

if simulation.choppers is not None:
out['choppers'] = sc.DataGroup(
{k: sc.DataGroup(ch.as_dict()) for k, ch in simulation.choppers.items()}
)

return TimeOfFlightLookupTable(out)


def simulate_chopper_cascade_using_tof(
Expand Down Expand Up @@ -412,22 +433,14 @@ def simulate_chopper_cascade_using_tof(
"""
import tof

tof_choppers = [
tof.Chopper(
frequency=abs(ch.frequency),
direction=tof.AntiClockwise
if (ch.frequency.value > 0.0)
else tof.Clockwise,
open=ch.slit_begin,
close=ch.slit_end,
phase=ch.phase if ch.frequency.value > 0.0 else -ch.phase,
distance=sc.norm(
ch.axle_position - source_position.to(unit=ch.axle_position.unit)
),
name=name,
tof_choppers = []
for name, ch in choppers.items():
chop = tof.Chopper.from_diskchopper(ch, name=name)
chop.distance = sc.norm(
ch.axle_position - source_position.to(unit=ch.axle_position.unit)
)
for name, ch in choppers.items()
]
tof_choppers.append(chop)

source = tof.Source(
facility=facility, neutrons=neutrons, pulses=pulse_stride, seed=seed
)
Expand All @@ -454,6 +467,7 @@ def simulate_chopper_cascade_using_tof(
wavelength=events.coords["wavelength"],
weight=events.data,
distance=furthest_chopper.distance,
choppers=choppers,
)


Expand Down
2 changes: 1 addition & 1 deletion src/ess/reduce/time_of_flight/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""Filename of the time-of-flight lookup table."""


TimeOfFlightLookupTable = NewType("TimeOfFlightLookupTable", sc.DataArray)
TimeOfFlightLookupTable = NewType("TimeOfFlightLookupTable", sc.DataGroup)
"""
Lookup table giving time-of-flight as a function of distance and time of arrival.
"""
Expand Down
38 changes: 19 additions & 19 deletions tests/time_of_flight/lut_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def test_lut_workflow_computes_table():

table = wf.compute(time_of_flight.TimeOfFlightLookupTable)

assert table.coords['distance'].min() < lmin
assert table.coords['distance'].max() > lmax
assert table.coords['event_time_offset'].max() == sc.scalar(1 / 14, unit='s').to(
unit=table.coords['event_time_offset'].unit
)
assert sc.isclose(table.coords['distance_resolution'], dres)
assert table['data'].coords['distance'].min() < lmin
assert table['data'].coords['distance'].max() > lmax
assert table['data'].coords['event_time_offset'].max() == sc.scalar(
1 / 14, unit='s'
).to(unit=table['data'].coords['event_time_offset'].unit)
assert sc.isclose(table['distance_resolution'], dres)
# Note that the time resolution is not exactly preserved since we want the table to
# span exactly the frame period.
assert sc.isclose(table.coords['time_resolution'], tres, rtol=sc.scalar(0.01))
assert sc.isclose(table['time_resolution'], tres, rtol=sc.scalar(0.01))


def test_lut_workflow_computes_table_in_chunks():
Expand All @@ -58,15 +58,15 @@ def test_lut_workflow_computes_table_in_chunks():

table = wf.compute(time_of_flight.TimeOfFlightLookupTable)

assert table.coords['distance'].min() < lmin
assert table.coords['distance'].max() > lmax
assert table.coords['event_time_offset'].max() == sc.scalar(1 / 14, unit='s').to(
unit=table.coords['event_time_offset'].unit
)
assert sc.isclose(table.coords['distance_resolution'], dres)
assert table['data'].coords['distance'].min() < lmin
assert table['data'].coords['distance'].max() > lmax
assert table['data'].coords['event_time_offset'].max() == sc.scalar(
1 / 14, unit='s'
).to(unit=table['data'].coords['event_time_offset'].unit)
assert sc.isclose(table['distance_resolution'], dres)
# Note that the time resolution is not exactly preserved since we want the table to
# span exactly the frame period.
assert sc.isclose(table.coords['time_resolution'], tres, rtol=sc.scalar(0.01))
assert sc.isclose(table['time_resolution'], tres, rtol=sc.scalar(0.01))


def test_lut_workflow_pulse_skipping():
Expand All @@ -87,9 +87,9 @@ def test_lut_workflow_pulse_skipping():

table = wf.compute(time_of_flight.TimeOfFlightLookupTable)

assert table.coords['event_time_offset'].max() == 2 * sc.scalar(
assert table['data'].coords['event_time_offset'].max() == 2 * sc.scalar(
1 / 14, unit='s'
).to(unit=table.coords['event_time_offset'].unit)
).to(unit=table['data'].coords['event_time_offset'].unit)


def test_lut_workflow_non_exact_distance_range():
Expand All @@ -110,6 +110,6 @@ def test_lut_workflow_non_exact_distance_range():

table = wf.compute(time_of_flight.TimeOfFlightLookupTable)

assert table.coords['distance'].min() < lmin
assert table.coords['distance'].max() > lmax
assert sc.isclose(table.coords['distance_resolution'], dres)
assert table['data'].coords['distance'].min() < lmin
assert table['data'].coords['distance'].max() > lmax
assert sc.isclose(table['distance_resolution'], dres)
6 changes: 5 additions & 1 deletion tests/time_of_flight/workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ def test_TofLookupTableWorkflow_can_compute_tof_lut():
)
wf[time_of_flight.SourcePosition] = fakes.source_position()
lut = wf.compute(time_of_flight.TimeOfFlightLookupTable)
assert isinstance(lut, sc.DataArray)
assert "data" in lut
assert "distance_resolution" in lut
assert "time_resolution" in lut
assert "pulse_stride" in lut
assert "pulse_period" in lut


def test_GenericTofWorkflow_with_tof_lut_from_tof_simulation(
Expand Down
Loading