diff --git a/src/ess/reduce/time_of_flight/eto_to_tof.py b/src/ess/reduce/time_of_flight/eto_to_tof.py index caed57cb..9b735121 100644 --- a/src/ess/reduce/time_of_flight/eto_to_tof.py +++ b/src/ess/reduce/time_of_flight/eto_to_tof.py @@ -96,14 +96,14 @@ def __call__( def _time_of_flight_data_histogram( - da: sc.DataArray, lookup: sc.DataArray, ltotal: sc.Variable + da: sc.DataArray, lookup: sc.DataGroup, ltotal: sc.Variable ) -> sc.DataArray: # In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files, # it may be called 'tof' or 'frame_time'. key = next(iter(set(da.coords.keys()) & {"time_of_flight", "tof", "frame_time"})) raw_eto = da.coords[key].to(dtype=float, copy=False) eto_unit = raw_eto.unit - pulse_period = lookup.coords["pulse_period"].to(unit=eto_unit) + pulse_period = lookup["pulse_period"].to(unit=eto_unit) # In histogram mode, because there is a wrap around at the end of the pulse, we # need to insert a bin edge at that exact location to avoid having the last bin @@ -117,7 +117,9 @@ def _time_of_flight_data_histogram( etos = rebinned.coords[key] # Create linear interpolator - interp = TofInterpolator(lookup, distance_unit=ltotal.unit, time_unit=eto_unit) + interp = TofInterpolator( + lookup["data"], distance_unit=ltotal.unit, time_unit=eto_unit + ) # Compute time-of-flight of the bin edges using the interpolator tofs = interp( @@ -199,7 +201,7 @@ def _guess_pulse_stride_offset( def _prepare_tof_interpolation_inputs( da: sc.DataArray, - lookup: sc.DataArray, + lookup: sc.DataGroup, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> dict: @@ -227,15 +229,17 @@ def _prepare_tof_interpolation_inputs( eto_unit = elem_unit(etos) # Create linear interpolator - interp = TofInterpolator(lookup, distance_unit=ltotal.unit, time_unit=eto_unit) + interp = TofInterpolator( + lookup["data"], distance_unit=ltotal.unit, time_unit=eto_unit + ) # Operate on events (broadcast distances to all events) ltotal = sc.bins_like(etos, ltotal).bins.constituents["data"] etos = etos.bins.constituents["data"] pulse_index = None - pulse_period = lookup.coords["pulse_period"].to(unit=eto_unit) - pulse_stride = lookup.coords["pulse_stride"].value + pulse_period = lookup["pulse_period"].to(unit=eto_unit) + pulse_stride = lookup["pulse_stride"].value if pulse_stride > 1: # Compute a pulse index for every event: it is the index of the pulse within a @@ -291,7 +295,7 @@ def _prepare_tof_interpolation_inputs( def _time_of_flight_data_events( da: sc.DataArray, - lookup: sc.DataArray, + lookup: sc.DataGroup, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> sc.DataArray: @@ -375,7 +379,7 @@ def monitor_ltotal_from_straight_line_approximation( def _compute_tof_data( da: sc.DataArray, - lookup: sc.DataArray, + lookup: sc.DataGroup, ltotal: sc.Variable, pulse_stride_offset: int, ) -> sc.DataArray: @@ -503,7 +507,7 @@ def detector_time_of_arrival_data( result = detector_data.bins.assign_coords( toa=sc.bins(**parts, validate_indices=False) ) - return result + return ToaDetector[RunType](result) def providers() -> tuple[Callable]: diff --git a/src/ess/reduce/time_of_flight/fakes.py b/src/ess/reduce/time_of_flight/fakes.py index 94546abd..cc93023f 100644 --- a/src/ess/reduce/time_of_flight/fakes.py +++ b/src/ess/reduce/time_of_flight/fakes.py @@ -48,20 +48,13 @@ def __init__( self.source = source(pulses=self.npulses) # Convert the choppers to tof.Chopper - self.choppers = [ - tof_pkg.Chopper( - frequency=abs(ch.frequency), - direction=tof_pkg.AntiClockwise - if (ch.frequency.value > 0.0) - else tof_pkg.Clockwise, - open=ch.slit_begin, - close=ch.slit_end, - phase=ch.phase if ch.frequency.value > 0.0 else -ch.phase, - distance=sc.norm(ch.axle_position - source_position), - name=name, + self.choppers = [] + for name, ch in choppers.items(): + chop = tof_pkg.Chopper.from_diskchopper(ch, name=name) + chop.distance = sc.norm( + ch.axle_position - source_position.to(unit=ch.axle_position.unit) ) - for name, ch in choppers.items() - ] + self.choppers.append(chop) # Add detectors self.monitors = [ diff --git a/src/ess/reduce/time_of_flight/lut.py b/src/ess/reduce/time_of_flight/lut.py index 0897364f..18807d28 100644 --- a/src/ess/reduce/time_of_flight/lut.py +++ b/src/ess/reduce/time_of_flight/lut.py @@ -43,6 +43,8 @@ class SimulationResults: For a ``tof`` simulation, this is just the position of the detector where the events are recorded. For a ``McStas`` simulation, this is the distance between the source and the event monitor. + choppers: + The parameters of the choppers used in the simulation (if any). """ time_of_arrival: sc.Variable @@ -50,6 +52,7 @@ class SimulationResults: wavelength: sc.Variable weight: sc.Variable distance: sc.Variable + choppers: DiskChoppers[AnyRun] | None = None NumberOfSimulatedNeutrons = NewType("NumberOfSimulatedNeutrons", int) @@ -376,7 +379,25 @@ def make_tof_lookup_table( # In-place masking for better performance _mask_large_uncertainty(table, error_threshold) - return TimeOfFlightLookupTable(table) + out = sc.DataGroup( + { + "data": table, + "pulse_period": pulse_period, + "pulse_stride": sc.scalar(pulse_stride, unit=None), + "distance_resolution": table.coords["distance"][1] + - table.coords["distance"][0], + "time_resolution": table.coords["event_time_offset"][1] + - table.coords["event_time_offset"][0], + "error_threshold": sc.scalar(error_threshold), + } + ) + + if simulation.choppers is not None: + out['choppers'] = sc.DataGroup( + {k: sc.DataGroup(ch.as_dict()) for k, ch in simulation.choppers.items()} + ) + + return TimeOfFlightLookupTable(out) def simulate_chopper_cascade_using_tof( @@ -412,22 +433,14 @@ def simulate_chopper_cascade_using_tof( """ import tof - tof_choppers = [ - tof.Chopper( - frequency=abs(ch.frequency), - direction=tof.AntiClockwise - if (ch.frequency.value > 0.0) - else tof.Clockwise, - open=ch.slit_begin, - close=ch.slit_end, - phase=ch.phase if ch.frequency.value > 0.0 else -ch.phase, - distance=sc.norm( - ch.axle_position - source_position.to(unit=ch.axle_position.unit) - ), - name=name, + tof_choppers = [] + for name, ch in choppers.items(): + chop = tof.Chopper.from_diskchopper(ch, name=name) + chop.distance = sc.norm( + ch.axle_position - source_position.to(unit=ch.axle_position.unit) ) - for name, ch in choppers.items() - ] + tof_choppers.append(chop) + source = tof.Source( facility=facility, neutrons=neutrons, pulses=pulse_stride, seed=seed ) @@ -454,6 +467,7 @@ def simulate_chopper_cascade_using_tof( wavelength=events.coords["wavelength"], weight=events.data, distance=furthest_chopper.distance, + choppers=choppers, ) diff --git a/src/ess/reduce/time_of_flight/types.py b/src/ess/reduce/time_of_flight/types.py index 060d3142..6a02878b 100644 --- a/src/ess/reduce/time_of_flight/types.py +++ b/src/ess/reduce/time_of_flight/types.py @@ -12,7 +12,7 @@ """Filename of the time-of-flight lookup table.""" -TimeOfFlightLookupTable = NewType("TimeOfFlightLookupTable", sc.DataArray) +TimeOfFlightLookupTable = NewType("TimeOfFlightLookupTable", sc.DataGroup) """ Lookup table giving time-of-flight as a function of distance and time of arrival. """ diff --git a/tests/time_of_flight/lut_test.py b/tests/time_of_flight/lut_test.py index edbaec71..df721a6c 100644 --- a/tests/time_of_flight/lut_test.py +++ b/tests/time_of_flight/lut_test.py @@ -28,15 +28,15 @@ def test_lut_workflow_computes_table(): table = wf.compute(time_of_flight.TimeOfFlightLookupTable) - assert table.coords['distance'].min() < lmin - assert table.coords['distance'].max() > lmax - assert table.coords['event_time_offset'].max() == sc.scalar(1 / 14, unit='s').to( - unit=table.coords['event_time_offset'].unit - ) - assert sc.isclose(table.coords['distance_resolution'], dres) + assert table['data'].coords['distance'].min() < lmin + assert table['data'].coords['distance'].max() > lmax + assert table['data'].coords['event_time_offset'].max() == sc.scalar( + 1 / 14, unit='s' + ).to(unit=table['data'].coords['event_time_offset'].unit) + assert sc.isclose(table['distance_resolution'], dres) # Note that the time resolution is not exactly preserved since we want the table to # span exactly the frame period. - assert sc.isclose(table.coords['time_resolution'], tres, rtol=sc.scalar(0.01)) + assert sc.isclose(table['time_resolution'], tres, rtol=sc.scalar(0.01)) def test_lut_workflow_computes_table_in_chunks(): @@ -58,15 +58,15 @@ def test_lut_workflow_computes_table_in_chunks(): table = wf.compute(time_of_flight.TimeOfFlightLookupTable) - assert table.coords['distance'].min() < lmin - assert table.coords['distance'].max() > lmax - assert table.coords['event_time_offset'].max() == sc.scalar(1 / 14, unit='s').to( - unit=table.coords['event_time_offset'].unit - ) - assert sc.isclose(table.coords['distance_resolution'], dres) + assert table['data'].coords['distance'].min() < lmin + assert table['data'].coords['distance'].max() > lmax + assert table['data'].coords['event_time_offset'].max() == sc.scalar( + 1 / 14, unit='s' + ).to(unit=table['data'].coords['event_time_offset'].unit) + assert sc.isclose(table['distance_resolution'], dres) # Note that the time resolution is not exactly preserved since we want the table to # span exactly the frame period. - assert sc.isclose(table.coords['time_resolution'], tres, rtol=sc.scalar(0.01)) + assert sc.isclose(table['time_resolution'], tres, rtol=sc.scalar(0.01)) def test_lut_workflow_pulse_skipping(): @@ -87,9 +87,9 @@ def test_lut_workflow_pulse_skipping(): table = wf.compute(time_of_flight.TimeOfFlightLookupTable) - assert table.coords['event_time_offset'].max() == 2 * sc.scalar( + assert table['data'].coords['event_time_offset'].max() == 2 * sc.scalar( 1 / 14, unit='s' - ).to(unit=table.coords['event_time_offset'].unit) + ).to(unit=table['data'].coords['event_time_offset'].unit) def test_lut_workflow_non_exact_distance_range(): @@ -110,6 +110,6 @@ def test_lut_workflow_non_exact_distance_range(): table = wf.compute(time_of_flight.TimeOfFlightLookupTable) - assert table.coords['distance'].min() < lmin - assert table.coords['distance'].max() > lmax - assert sc.isclose(table.coords['distance_resolution'], dres) + assert table['data'].coords['distance'].min() < lmin + assert table['data'].coords['distance'].max() > lmax + assert sc.isclose(table['distance_resolution'], dres) diff --git a/tests/time_of_flight/workflow_test.py b/tests/time_of_flight/workflow_test.py index 2f6ad0aa..2f4bba29 100644 --- a/tests/time_of_flight/workflow_test.py +++ b/tests/time_of_flight/workflow_test.py @@ -71,7 +71,11 @@ def test_TofLookupTableWorkflow_can_compute_tof_lut(): ) wf[time_of_flight.SourcePosition] = fakes.source_position() lut = wf.compute(time_of_flight.TimeOfFlightLookupTable) - assert isinstance(lut, sc.DataArray) + assert "data" in lut + assert "distance_resolution" in lut + assert "time_resolution" in lut + assert "pulse_stride" in lut + assert "pulse_period" in lut def test_GenericTofWorkflow_with_tof_lut_from_tof_simulation(