Skip to content

Commit b46edac

Browse files
authored
Merge pull request #276 from scipp/toa-workflow
Add provider to tof workflow to compute toa
2 parents 3f9d087 + 3f69948 commit b46edac

File tree

4 files changed

+179
-9
lines changed

4 files changed

+179
-9
lines changed

src/ess/reduce/time_of_flight/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
PulseStrideOffset,
2929
TimeOfFlightLookupTable,
3030
TimeOfFlightLookupTableFilename,
31+
ToaDetector,
3132
TofDetector,
3233
TofMonitor,
3334
)
@@ -51,6 +52,7 @@
5152
"TimeOfFlightLookupTable",
5253
"TimeOfFlightLookupTableFilename",
5354
"TimeResolution",
55+
"ToaDetector",
5456
"TofDetector",
5557
"TofLookupTableWorkflow",
5658
"TofMonitor",

src/ess/reduce/time_of_flight/eto_to_tof.py

Lines changed: 99 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
MonitorLtotal,
3737
PulseStrideOffset,
3838
TimeOfFlightLookupTable,
39+
ToaDetector,
3940
TofDetector,
4041
TofMonitor,
4142
)
@@ -196,12 +197,32 @@ def _guess_pulse_stride_offset(
196197
return sorted(tofs, key=lambda x: sc.isnan(tofs[x]).sum())[0]
197198

198199

199-
def _time_of_flight_data_events(
200+
def _prepare_tof_interpolation_inputs(
200201
da: sc.DataArray,
201202
lookup: sc.DataArray,
202203
ltotal: sc.Variable,
203-
pulse_stride_offset: int,
204-
) -> sc.DataArray:
204+
pulse_stride_offset: int | None,
205+
) -> dict:
206+
"""
207+
Prepare the inputs required for the time-of-flight interpolation.
208+
This function is used when computing the time-of-flight for event data, and for
209+
computing the time-of-arrival for event data (as they both require guessing the
210+
pulse_stride_offset if not provided).
211+
212+
Parameters
213+
----------
214+
da:
215+
Data array with event data.
216+
lookup:
217+
Lookup table giving time-of-flight as a function of distance and time of
218+
arrival.
219+
ltotal:
220+
Total length of the flight path from the source to the detector.
221+
pulse_stride_offset:
222+
When pulse-skipping, the offset of the first pulse in the stride. This is
223+
typically zero but can be a small integer < pulse_stride.
224+
If None, a guess is made.
225+
"""
205226
etos = da.bins.coords["event_time_offset"].to(dtype=float, copy=False)
206227
eto_unit = elem_unit(etos)
207228

@@ -259,12 +280,34 @@ def _time_of_flight_data_events(
259280
pulse_index += pulse_stride_offset
260281
pulse_index %= pulse_stride
261282

262-
# Compute time-of-flight for all neutrons using the interpolator
263-
tofs = interp(
283+
return {
284+
"eto": etos,
285+
"pulse_index": pulse_index,
286+
"pulse_period": pulse_period,
287+
"interp": interp,
288+
"ltotal": ltotal,
289+
}
290+
291+
292+
def _time_of_flight_data_events(
293+
da: sc.DataArray,
294+
lookup: sc.DataArray,
295+
ltotal: sc.Variable,
296+
pulse_stride_offset: int | None,
297+
) -> sc.DataArray:
298+
inputs = _prepare_tof_interpolation_inputs(
299+
da=da,
300+
lookup=lookup,
264301
ltotal=ltotal,
265-
event_time_offset=etos,
266-
pulse_index=pulse_index,
267-
pulse_period=pulse_period,
302+
pulse_stride_offset=pulse_stride_offset,
303+
)
304+
305+
# Compute time-of-flight for all neutrons using the interpolator
306+
tofs = inputs["interp"](
307+
ltotal=inputs["ltotal"],
308+
event_time_offset=inputs["eto"],
309+
pulse_index=inputs["pulse_index"],
310+
pulse_period=inputs["pulse_period"],
268311
)
269312

270313
parts = da.bins.constituents
@@ -416,6 +459,53 @@ def monitor_time_of_flight_data(
416459
)
417460

418461

462+
def detector_time_of_arrival_data(
463+
detector_data: RawDetector[RunType],
464+
lookup: TimeOfFlightLookupTable,
465+
ltotal: DetectorLtotal[RunType],
466+
pulse_stride_offset: PulseStrideOffset,
467+
) -> ToaDetector[RunType]:
468+
"""
469+
Convert the time-of-flight data to time-of-arrival data using a lookup table.
470+
The output data will have a time-of-arrival coordinate.
471+
The time-of-arrival is the time since the neutron was emitted from the source.
472+
It is basically equal to event_time_offset + pulse_index * pulse_period.
473+
474+
Parameters
475+
----------
476+
da:
477+
Raw detector data loaded from a NeXus file, e.g., NXdetector containing
478+
NXevent_data.
479+
lookup:
480+
Lookup table giving time-of-flight as a function of distance and time of
481+
arrival.
482+
ltotal:
483+
Total length of the flight path from the source to the detector.
484+
pulse_stride_offset:
485+
When pulse-skipping, the offset of the first pulse in the stride. This is
486+
typically zero but can be a small integer < pulse_stride.
487+
"""
488+
if detector_data.bins is None:
489+
raise NotImplementedError(
490+
"Computing time-of-arrival in histogram mode is not implemented yet."
491+
)
492+
inputs = _prepare_tof_interpolation_inputs(
493+
da=detector_data,
494+
lookup=lookup,
495+
ltotal=ltotal,
496+
pulse_stride_offset=pulse_stride_offset,
497+
)
498+
parts = detector_data.bins.constituents
499+
parts["data"] = inputs["eto"]
500+
# The pulse index is None if pulse_stride == 1 (i.e., no pulse skipping)
501+
if inputs["pulse_index"] is not None:
502+
parts["data"] = parts["data"] + inputs["pulse_index"] * inputs["pulse_period"]
503+
result = detector_data.bins.assign_coords(
504+
toa=sc.bins(**parts, validate_indices=False)
505+
)
506+
return result
507+
508+
419509
def providers() -> tuple[Callable]:
420510
"""
421511
Providers of the time-of-flight workflow.
@@ -425,4 +515,5 @@ def providers() -> tuple[Callable]:
425515
monitor_time_of_flight_data,
426516
detector_ltotal_from_straight_line_approximation,
427517
monitor_ltotal_from_straight_line_approximation,
518+
detector_time_of_arrival_data,
428519
)

src/ess/reduce/time_of_flight/types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,17 @@ class TofDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray):
3737
"""Detector data with time-of-flight coordinate."""
3838

3939

40+
class ToaDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray):
41+
"""Detector data with time-of-arrival coordinate.
42+
43+
When the pulse stride is 1 (i.e., no pulse skipping), the time-of-arrival is the
44+
same as the event_time_offset. When pulse skipping is used, the time-of-arrival is
45+
the event_time_offset + pulse_offset * pulse_period.
46+
This means that the time-of-arrival is basically the event_time_offset wrapped
47+
over the frame period instead of the pulse period
48+
(where frame_period = pulse_stride * pulse_period).
49+
"""
50+
51+
4052
class TofMonitor(sl.Scope[RunType, MonitorType, sc.DataArray], sc.DataArray):
4153
"""Monitor data with time-of-flight coordinate."""

tests/time_of_flight/unwrap_test.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99

1010
from ess.reduce import time_of_flight
1111
from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun
12-
from ess.reduce.time_of_flight import GenericTofWorkflow, TofLookupTableWorkflow, fakes
12+
from ess.reduce.time_of_flight import (
13+
GenericTofWorkflow,
14+
PulsePeriod,
15+
TofLookupTableWorkflow,
16+
fakes,
17+
)
1318

1419
sl = pytest.importorskip("sciline")
1520

@@ -441,3 +446,63 @@ def test_unwrap_int(dtype, lut_workflow_psc_choppers) -> None:
441446
_validate_result_events(
442447
tofs=tofs, ref=ref, percentile=100, diff_threshold=0.02, rtol=0.05
443448
)
449+
450+
451+
def test_compute_toa():
452+
distance = sc.scalar(80.0, unit="m")
453+
choppers = fakes.psc_choppers()
454+
455+
lut_wf = make_lut_workflow(
456+
choppers=choppers, neutrons=500_000, seed=1234, pulse_stride=1
457+
)
458+
459+
pl, _ = _make_workflow_event_mode(
460+
distance=distance,
461+
choppers=choppers,
462+
lut_workflow=lut_wf,
463+
seed=2,
464+
pulse_stride_offset=0,
465+
error_threshold=0.1,
466+
)
467+
468+
toas = pl.compute(time_of_flight.ToaDetector[SampleRun])
469+
470+
assert "toa" in toas.bins.coords
471+
raw = pl.compute(RawDetector[SampleRun])
472+
assert sc.allclose(toas.bins.coords["toa"], raw.bins.coords["event_time_offset"])
473+
474+
475+
def test_compute_toa_pulse_skipping():
476+
distance = sc.scalar(100.0, unit="m")
477+
choppers = fakes.pulse_skipping_choppers()
478+
479+
lut_wf = make_lut_workflow(
480+
choppers=choppers, neutrons=500_000, seed=1234, pulse_stride=2
481+
)
482+
483+
pl, _ = _make_workflow_event_mode(
484+
distance=distance,
485+
choppers=choppers,
486+
lut_workflow=lut_wf,
487+
seed=2,
488+
pulse_stride_offset=1,
489+
error_threshold=0.1,
490+
)
491+
492+
raw = pl.compute(RawDetector[SampleRun])
493+
494+
toas = pl.compute(time_of_flight.ToaDetector[SampleRun])
495+
496+
assert "toa" in toas.bins.coords
497+
pulse_period = lut_wf.compute(PulsePeriod)
498+
hist = toas.bins.concat().hist(
499+
toa=sc.array(
500+
dims=["toa"],
501+
values=[0, pulse_period.value, pulse_period.value * 2],
502+
unit=pulse_period.unit,
503+
).to(unit=toas.bins.coords["toa"].unit)
504+
)
505+
# There should be counts in both bins
506+
n = raw.sum().value
507+
assert hist.data[0].value > n / 5
508+
assert hist.data[1].value > n / 5

0 commit comments

Comments
 (0)