Commit 67daf52

feat: add workflow helper for combining files
1 parent a6f14eb commit 67daf52

4 files changed: +150 -10 lines changed

src/ess/reflectometry/conversions.py

Lines changed: 7 additions & 1 deletion
@@ -9,6 +9,7 @@
 from .types import (
     BeamDivergenceLimits,
     DataWithScatteringCoordinates,
+    DetectorRotation,
     Gravity,
     IncidentBeam,
     MaskedData,
@@ -109,6 +110,7 @@ def specular_reflection(
     incident_beam: IncidentBeam[RunType],
     sample_position: SamplePosition[RunType],
     sample_rotation: SampleRotation[RunType],
+    detector_rotation: DetectorRotation[RunType],
     gravity: Gravity,
 ) -> SpecularReflectionCoordTransformGraph[RunType]:
     """
@@ -127,6 +129,7 @@ def specular_reflection(
         "incident_beam": lambda: incident_beam,
         "sample_position": lambda: sample_position,
         "sample_rotation": lambda: sample_rotation,
+        "detector_rotation": lambda: detector_rotation,
         "gravity": lambda: gravity,
     }
     return SpecularReflectionCoordTransformGraph(graph)
@@ -136,7 +139,10 @@ def add_coords(
     da: ReducibleDetectorData[RunType],
     graph: SpecularReflectionCoordTransformGraph[RunType],
 ) -> DataWithScatteringCoordinates[RunType]:
-    da = da.transform_coords(["theta", "wavelength", "Q"], graph=graph)
+    da = da.transform_coords(
+        ["theta", "wavelength", "Q", "detector_rotation"], graph=graph
+    )
+    da.coords.set_aligned('detector_rotation', False)
     da.coords["z_index"] = sc.arange(
         "row", 0, da.sizes["blade"] * da.sizes["wire"], unit=None
     ).fold("row", sizes={dim: da.sizes[dim] for dim in ("blade", "wire")})
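
For context, this is the scipp pattern the change relies on: a coordinate-transformation graph may contain zero-argument callables that inject externally supplied values (here the detector rotation) as coordinates, and set_aligned(..., False) marks such a coordinate as unaligned so later operations do not require it to match element-wise. A minimal sketch with placeholder names and values, not the actual ESS graph:

    import scipp as sc

    da = sc.DataArray(
        sc.ones(dims=['event'], shape=[4]),
        coords={'two_theta': sc.linspace('event', 1.0, 2.0, 4, unit='deg')},
    )
    graph = {
        # Constant provided from outside the data, as in specular_reflection above.
        'detector_rotation': lambda: sc.scalar(2.5, unit='deg'),
        # Coordinate derived from an existing one.
        'theta': lambda two_theta: 0.5 * two_theta,
    }
    da = da.transform_coords(['theta', 'detector_rotation'], graph=graph)
    # Unalign the constant so that concatenating events from several runs later
    # does not demand element-wise agreement of this coordinate.
    da.coords.set_aligned('detector_rotation', False)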

src/ess/reflectometry/orso.py

Lines changed: 11 additions & 2 deletions
@@ -59,6 +59,9 @@
 OrsoSample = NewType("OrsoSample", data_source.Sample)
 """ORSO sample."""
 
+OrsoSampleFilenames = NewType("OrsoSampleFilenames", list[orso_base.File])
+"""Collection of filenames used to create the ORSO file"""
+
 
 def parse_orso_experiment(filename: Filename[SampleRun]) -> OrsoExperiment:
     """Parse ORSO experiment metadata from raw NeXus data."""
@@ -107,8 +110,13 @@ def parse_orso_sample(filename: Filename[SampleRun]) -> OrsoSample:
     )
 
 
+def orso_data_files(filename: Filename[SampleRun]) -> OrsoSampleFilenames:
+    '''Collects names of files used in the experiment'''
+    return [orso_base.File(file=os.path.basename(filename))]
+
+
 def build_orso_measurement(
-    sample_filename: Filename[SampleRun],
+    sample_filenames: OrsoSampleFilenames,
     reference_filename: Filename[ReferenceRun],
     instrument: OrsoInstrument,
 ) -> OrsoMeasurement:
@@ -127,7 +135,7 @@ def build_orso_measurement(
     return OrsoMeasurement(
         data_source.Measurement(
             instrument_settings=instrument,
-            data_files=[orso_base.File(file=os.path.basename(sample_filename))],
+            data_files=sample_filenames,
             additional_files=additional_files,
         )
     )
@@ -220,4 +228,5 @@ def find_corrections(task_graph: TaskGraph) -> list[str]:
     parse_orso_experiment,
     parse_orso_owner,
     parse_orso_sample,
+    orso_data_files,
 )
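
The new OrsoSampleFilenames type exists so that the ORSO data_files entry can be assembled from several runs: each run's orso_data_files provider returns a one-element list, and the workflow helper below flattens those lists before build_orso_measurement consumes them. A small sketch of that flattening (the filenames are made up, and the orso_base alias is assumed to be orsopy.fileio.base, matching this module's usage):

    import os.path
    from itertools import chain

    from orsopy.fileio import base as orso_base

    # One single-entry list per run, mirroring what orso_data_files() returns.
    per_run = [
        [orso_base.File(file=os.path.basename(f))]
        for f in ('raw/sample_608.hdf', 'raw/sample_609.hdf')
    ]
    # The reduction step in workflow.py concatenates these lists into one.
    data_files = list(chain(*per_run))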

src/ess/reflectometry/workflow.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+from collections.abc import Hashable, Sequence
+from itertools import chain
+
+import pandas as pd
+import sciline
+import scipp as sc
+
+from ess.amor.types import RawChopper
+from ess.reflectometry.orso import (
+    OrsoExperiment,
+    OrsoOwner,
+    OrsoSample,
+    OrsoSampleFilenames,
+)
+from ess.reflectometry.types import (
+    Filename,
+    FootprintCorrectedData,
+    RunType,
+    SampleRotation,
+    SampleRun,
+)
+
+
+def _concatenate_event_lists(*das):
+    return (
+        sc.reduce(das)
+        .bins.concat()
+        .assign_coords(
+            {
+                name: das[0].coords[name]
+                for name in ('position', 'sample_rotation', 'detector_rotation')
+            }
+        )
+    )
+
+
+def _any_value(x, *_):
+    return x
+
+
+def _concatenate_lists(*x):
+    return list(chain(*x))
+
+
+def with_filenames(
+    workflow, runtype: Hashable, runs: Sequence[Filename[RunType]]
+) -> sciline.Pipeline:
+    axis_name = f'{str(runtype).lower()}_runs'
+    df = pd.DataFrame({Filename[runtype]: runs}).rename_axis(axis_name)
+    wf = workflow.copy()
+
+    mapped = wf.map(df)
+
+    wf[FootprintCorrectedData[runtype]] = mapped[
+        FootprintCorrectedData[runtype]
+    ].reduce(index=axis_name, func=_concatenate_event_lists)
+    wf[RawChopper[runtype]] = mapped[RawChopper[runtype]].reduce(
+        index=axis_name, func=_any_value
+    )
+    wf[SampleRotation[runtype]] = mapped[SampleRotation[runtype]].reduce(
+        index=axis_name, func=_any_value
+    )
+
+    if runtype is SampleRun:
+        wf[OrsoSample] = mapped[OrsoSample].reduce(index=axis_name, func=_any_value)
+        wf[OrsoExperiment] = mapped[OrsoExperiment].reduce(
+            index=axis_name, func=_any_value
+        )
+        wf[OrsoOwner] = mapped[OrsoOwner].reduce(index=axis_name, func=lambda x, *_: x)
+        wf[OrsoSampleFilenames] = mapped[OrsoSampleFilenames].reduce(
+            # When we don't map over filenames
+            # each OrsoSampleFilenames is a list with a single entry.
+            index=axis_name,
+            func=_concatenate_lists,
+        )
+    return wf
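
Usage sketch for the new helper: given an already configured pipeline (such as the amor_pipeline fixture in the tests below), with_filenames maps the file-dependent part of the task graph over several runs and merges the results back into a single pipeline. make_amor_pipeline is a hypothetical stand-in for that setup, and the import locations are assumed:

    from ess import amor
    from ess.reflectometry.types import ReflectivityOverQ, SampleRun
    from ess.reflectometry.workflow import with_filenames

    pipeline = make_amor_pipeline()  # hypothetical helper; see the test fixture below
    runs = [amor.data.amor_sample_run(608), amor.data.amor_sample_run(609)]
    merged = with_filenames(pipeline, SampleRun, runs)
    # Events from both runs are concatenated before the reflectivity is computed.
    reflectivity = merged.compute(ReflectivityOverQ)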

tests/amor/pipeline_test.py

Lines changed: 56 additions & 7 deletions
@@ -7,6 +7,7 @@
 import sciline
 import scipp as sc
 from orsopy import fileio
+from scipp.testing import assert_allclose
 
 from ess import amor
 from ess.amor import data  # noqa: F401
@@ -23,6 +24,7 @@
     YIndexLimits,
     ZIndexLimits,
 )
+from ess.reflectometry.workflow import with_filenames
 
 
 @pytest.fixture
@@ -34,6 +36,9 @@ def amor_pipeline() -> sciline.Pipeline:
     pl[WavelengthBins] = sc.geomspace("wavelength", 2.8, 12, 300, unit="angstrom")
     pl[YIndexLimits] = sc.scalar(11, unit=None), sc.scalar(41, unit=None)
     pl[ZIndexLimits] = sc.scalar(80, unit=None), sc.scalar(370, unit=None)
+    pl[QBins] = sc.geomspace(
+        dim="Q", start=0.005, stop=0.115, num=391, unit="1/angstrom"
+    )
 
     # The sample rotation value in the file is slightly off, so we set it manually
     pl[SampleRotation[ReferenceRun]] = sc.scalar(0.65, unit="deg")
@@ -46,19 +51,15 @@ def amor_pipeline() -> sciline.Pipeline:
             contact="[email protected]",
         )
     )
-
-    # The sample rotation value in the file is slightly off, so we set it manually
-    pl[SampleRotation[SampleRun]] = sc.scalar(0.85, unit="deg")
-    pl[Filename[SampleRun]] = amor.data.amor_sample_run(608)
-    pl[QBins] = sc.geomspace(
-        dim="Q", start=0.005, stop=0.115, num=391, unit="1/angstrom"
-    )
     return pl
 
 
 @pytest.mark.filterwarnings("ignore:Failed to convert .* into a transformation")
 @pytest.mark.filterwarnings("ignore:Invalid transformation, missing attribute")
 def test_run_data_pipeline(amor_pipeline: sciline.Pipeline):
+    # The sample rotation value in the file is slightly off, so we set it manually
+    amor_pipeline[SampleRotation[SampleRun]] = sc.scalar(0.85, unit="deg")
+    amor_pipeline[Filename[SampleRun]] = amor.data.amor_sample_run(608)
     res = amor_pipeline.compute(ReflectivityOverQ)
     assert "Q" in res.coords
     assert "Q_resolution" in res.coords
@@ -67,6 +68,8 @@ def test_run_data_pipeline(amor_pipeline: sciline.Pipeline):
 @pytest.mark.filterwarnings("ignore:Failed to convert .* into a transformation")
 @pytest.mark.filterwarnings("ignore:Invalid transformation, missing attribute")
 def test_run_full_pipeline(amor_pipeline: sciline.Pipeline):
+    amor_pipeline[SampleRotation[SampleRun]] = sc.scalar(0.85, unit="deg")
+    amor_pipeline[Filename[SampleRun]] = amor.data.amor_sample_run(608)
     res = amor_pipeline.compute(orso.OrsoIofQDataset)
     assert res.info.data_source.experiment.instrument == "Amor"
     assert res.info.reduction.software.name == "ess.reflectometry"
@@ -75,7 +78,53 @@ def test_run_full_pipeline(amor_pipeline: sciline.Pipeline):
     assert np.all(res.data[:, 1] >= 0)
 
 
+@pytest.mark.filterwarnings("ignore:Failed to convert .* into a transformation")
+@pytest.mark.filterwarnings("ignore:Invalid transformation, missing attribute")
+def test_pipeline_can_compute_reflectivity_merging_events_from_multiple_runs(
+    amor_pipeline: sciline.Pipeline,
+):
+    sample_runs = [
+        amor.data.amor_sample_run(608),
+        amor.data.amor_sample_run(609),
+    ]
+    pipeline = with_filenames(amor_pipeline, SampleRun, sample_runs)
+    pipeline[SampleRotation[SampleRun]] = pipeline.compute(
+        SampleRotation[SampleRun]
+    ) + sc.scalar(0.05, unit="deg")
+    result = pipeline.compute(ReflectivityOverQ)
+    assert result.dims == ('Q',)
+
+
+@pytest.mark.filterwarnings("ignore:Failed to convert .* into a transformation")
+@pytest.mark.filterwarnings("ignore:Invalid transformation, missing attribute")
+def test_pipeline_merging_events_result_unchanged(amor_pipeline: sciline.Pipeline):
+    sample_runs = [
+        amor.data.amor_sample_run(608),
+    ]
+    pipeline = with_filenames(amor_pipeline, SampleRun, sample_runs)
+    pipeline[SampleRotation[SampleRun]] = pipeline.compute(
+        SampleRotation[SampleRun]
+    ) + sc.scalar(0.05, unit="deg")
+    result = pipeline.compute(ReflectivityOverQ).hist()
+    sample_runs = [
+        amor.data.amor_sample_run(608),
+        amor.data.amor_sample_run(608),
+    ]
+    pipeline = with_filenames(amor_pipeline, SampleRun, sample_runs)
+    pipeline[SampleRotation[SampleRun]] = pipeline.compute(
+        SampleRotation[SampleRun]
+    ) + sc.scalar(0.05, unit="deg")
+    result2 = pipeline.compute(ReflectivityOverQ).hist()
+    assert_allclose(
+        2 * sc.values(result.data), sc.values(result2.data), rtol=sc.scalar(1e-6)
+    )
+    assert_allclose(
+        2 * sc.variances(result.data), sc.variances(result2.data), rtol=sc.scalar(1e-6)
+    )
+
+
 def test_find_corrections(amor_pipeline: sciline.Pipeline):
+    amor_pipeline[Filename[SampleRun]] = amor.data.amor_sample_run(608)
     graph = amor_pipeline.get(orso.OrsoIofQDataset)
     # In topological order
     assert orso.find_corrections(graph) == [
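
The merging-invariance test above rests on a simple property of concatenated event lists: reducing the same run twice doubles the histogrammed counts. A standalone sketch of that property using scipp's synthetic event table (not part of the commit; the reduction line mirrors _concatenate_event_lists in workflow.py):

    import scipp as sc

    table = sc.data.table_xyz(100)   # random event-like table shipped with scipp
    binned = table.bin(x=4)          # bin the events along x
    # Concatenate the event lists of two copies of the same binned data.
    merged = sc.reduce([binned, binned]).bins.concat()
    # Histogramming the merged events yields twice the counts per bin.
    assert sc.allclose(merged.hist().data, 2 * binned.hist().data)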
