Skip to content

Commit c826b4d

Browse files
committed
Return corrections in topological order
1 parent 37f8aa6 commit c826b4d

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

src/essreflectometry/orso.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
of reference runs and only use the metadata of the sample run.
77
"""
88

9+
import graphlib
910
import os
1011
import platform
1112
from datetime import datetime, timezone
@@ -183,15 +184,19 @@ def find_corrections(task_graph: TaskGraph) -> list[str]:
183184
Returns
184185
-------
185186
:
186-
List of corrections.
187+
List of corrections in the order they are applied in.
187188
"""
188-
return sorted(
189-
[
190-
c
191-
for key in task_graph.keys()
192-
if (c := _CORRECTIONS_BY_GRAPH_KEY.get(key, None)) is not None
193-
]
189+
toposort = graphlib.TopologicalSorter(
190+
{
191+
key: tuple(provider.arg_spec.keys())
192+
for key, provider in task_graph._graph.items()
193+
}
194194
)
195+
return [
196+
c
197+
for key in toposort.static_order()
198+
if (c := _CORRECTIONS_BY_GRAPH_KEY.get(key, None)) is not None
199+
]
195200

196201

197202
providers = (

tests/amor/pipeline_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ def test_run_pipeline(amor_pipeline: sciline.Pipeline):
5353

5454
def test_find_corrections(amor_pipeline: sciline.Pipeline):
5555
graph = amor_pipeline.get(orso.OrsoIofQDataset)
56-
assert sorted(orso.find_corrections(graph)) == [
56+
# In topological order
57+
assert orso.find_corrections(graph) == [
58+
'supermirror calibration',
5759
'chopper ToF correction',
5860
'footprint correction',
59-
'supermirror calibration',
6061
'total counts',
6162
]

0 commit comments

Comments
 (0)