Skip to content

Commit a73b2bb

Browse files
committed
feat: add Reference row to reduction table to let user specify what reference to use
1 parent 044b4eb commit a73b2bb

File tree

1 file changed

+125
-69
lines changed

1 file changed

+125
-69
lines changed

src/ess/reflectometry/gui.py

Lines changed: 125 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,24 @@
3535
from ess.reflectometry.workflow import with_filenames
3636

3737

38+
def _get_unique_names(df):
39+
# Create labels with Sample name and runs
40+
labels = [
41+
f"{params['Sample']} ({','.join(params['Runs'])})"
42+
for (_, params) in df.iterrows()
43+
]
44+
duplicated_name_counter = {}
45+
unique = []
46+
for i, name in enumerate(labels):
47+
if name not in labels[:i]:
48+
unique.append(name)
49+
else:
50+
duplicated_name_counter.setdefault(name, 0)
51+
duplicated_name_counter[name] += 1
52+
unique.append(f'{name}_{duplicated_name_counter[name]}')
53+
return unique
54+
55+
3856
def _get_selected_rows(grid):
3957
return (
4058
pd.concat(
@@ -386,8 +404,8 @@ def sync(self, *_):
386404
db["user_reference"] = self.reference_table.data
387405

388406
db["user_runs"] = self.sync_runs_table(db)
389-
db["user_reduction"] = self.sync_reduction_table(db)
390407
db["user_reference"] = self.sync_reference_table(db)
408+
db["user_reduction"] = self.sync_reduction_table(db)
391409

392410
self.runs_table.data = db["user_runs"]
393411
self.reduction_table.data = db["user_reduction"]
@@ -449,7 +467,7 @@ def _init_runs_table_component(self):
449467
self.run_number_min.observe(self.sync, names='value')
450468
self.run_number_max.observe(self.sync, names='value')
451469
run_number_filter = widgets.HBox(
452-
[self.run_number_min, widgets.Label("<=Run<="), self.run_number_max]
470+
[self.run_number_min, widgets.Label("&le;Run&le;"), self.run_number_max]
453471
)
454472
self.runs_table_component = widgets.VBox(
455473
[
@@ -491,6 +509,9 @@ def add_row(_):
491509
'QBins': 391,
492510
'QStart': 0.01,
493511
'QStop': 0.3,
512+
'Reference': self.reference_table.data.iloc[0]['Sample']
513+
if len(self.reference_table.data) > 0
514+
else pd.NA,
494515
'Scale': 1.0,
495516
}
496517
]
@@ -557,6 +578,16 @@ def _init_display_component(self):
557578
]
558579
)
559580

581+
def _init_settings_component(self):
582+
self.settings_component = widgets.VBox(
583+
[
584+
widgets.Label("This is the settings tab"),
585+
widgets.Label("Reference runs"),
586+
self.reference_table,
587+
],
588+
layout={"width": "100%"},
589+
)
590+
560591
def __init__(self):
561592
self.text_log = widgets.VBox([])
562593
self._path = None
@@ -609,6 +640,7 @@ def __init__(self):
609640
self._init_runs_table_component()
610641
self._init_reduction_table_component()
611642
self._init_display_component()
643+
self._init_settings_component()
612644

613645
tab_data = widgets.VBox(
614646
[
@@ -621,14 +653,7 @@ def __init__(self):
621653
self.display_component,
622654
]
623655
)
624-
tab_settings = widgets.VBox(
625-
[
626-
widgets.Label("This is the settings tab"),
627-
widgets.Label("Reference runs"),
628-
self.reference_table,
629-
],
630-
layout={"width": "100%"},
631-
)
656+
tab_settings = self.settings_component
632657
tab_log = widgets.VBox(
633658
[widgets.Label("Messages"), self.text_log],
634659
layout={"width": "100%"},
@@ -746,14 +771,17 @@ def sync_runs_table(self, db):
746771
db['run_number_max'] >= df['Run'].astype(int)
747772
]
748773
self._setdefault(df, "Exclude", False)
774+
self._setdefault(df, "Reference", False)
749775
self._setdefault(df, "Comment", "") # Add default empty comment
750-
df = self._ordercolumns(df, 'Run', 'Sample', 'Angle', 'Exclude', 'Comment')
776+
df = self._ordercolumns(
777+
df, 'Run', 'Sample', 'Angle', 'Exclude', 'Reference', 'Comment'
778+
)
751779
return df.sort_values(by='Run')
752780

753781
def sync_reduction_table(self, db):
754782
df = db["user_runs"]
755783
df = (
756-
df[df["Sample"] != "sm5"][~df["Exclude"]]
784+
df[~df["Reference"]][~df["Exclude"]]
757785
.groupby(["Sample", "Angle"], as_index=False)
758786
.agg(Runs=("Run", tuple))
759787
.sort_values(["Sample", "Angle"])
@@ -768,13 +796,20 @@ def sync_reduction_table(self, db):
768796
self._setdefault(df, "QStart", 0.01)
769797
self._setdefault(df, "QStop", 0.3)
770798
self._setdefault(df, "Scale", 1.0)
799+
self._setdefault(
800+
df,
801+
"Reference",
802+
db['user_reference'].iloc[0]['Sample']
803+
if len(db['user_reference']) > 0
804+
else pd.NA,
805+
)
771806
df = self._ordercolumns(df, 'Sample', 'Angle', 'Runs')
772807
return df.sort_values(["Sample", "Angle"])
773808

774809
def sync_reference_table(self, db):
775810
df = db["user_runs"]
776811
df = (
777-
df[df["Sample"] == "sm5"][~df["Exclude"]]
812+
df[df["Reference"]][~df["Exclude"]]
778813
.groupby(["Sample", "Angle"], as_index=False)
779814
.agg(Runs=("Run", tuple))
780815
.sort_values(["Sample", "Angle"])
@@ -816,30 +851,21 @@ def display_results(self):
816851
for _, row in df.iterrows()
817852
if (key := self.get_row_key(row)) in self.results
818853
]
854+
labels = [
855+
label
856+
for label, (_, row) in zip(
857+
_get_unique_names(df), df.iterrows(), strict=True
858+
)
859+
if (key := self.get_row_key(row)) in self.results
860+
]
861+
819862
if len(results) == len(df):
820863
break
821864
# No results were found for some of the selected rows.
822865
# It hasn't been computed yet, so compute it and try again.
823866
self.run_workflow()
824867

825-
def get_unique_names(df):
826-
# Create labels with Sample name and runs
827-
labels = [
828-
f"{params['Sample']} ({','.join(params['Runs'])})"
829-
for (_, params) in df.iterrows()
830-
]
831-
duplicated_name_counter = {}
832-
unique = []
833-
for i, name in enumerate(labels):
834-
if name not in labels[:i]:
835-
unique.append(name)
836-
else:
837-
duplicated_name_counter.setdefault(name, 0)
838-
duplicated_name_counter[name] += 1
839-
unique.append(f'{name}_{duplicated_name_counter[name]}')
840-
return unique
841-
842-
results = dict(zip(get_unique_names(df), results, strict=True))
868+
results = dict(zip(labels, results, strict=True))
843869

844870
q4toggle = widgets.ToggleButton(value=False, description="R*Q^4")
845871
plot_box = widgets.VBox(
@@ -922,13 +948,22 @@ def get_filepath_from_run(self, run):
922948
)
923949
return os.path.join(self.path, fname)
924950

951+
def get_reference_for_row(self, row):
952+
if 'Reference' in row:
953+
return self.reference_table.data[
954+
self.reference_table.data['Sample'] == row['Reference']
955+
]
956+
else:
957+
return None
958+
925959
def get_row_key(self, row):
926-
reference_metadata = (
927-
tuple(self.reference_table.data.iloc[0])
928-
if len(self.reference_table.data) > 0
929-
else (None,)
960+
reference = self.get_reference_for_row(row)
961+
return (
962+
tuple(row),
963+
tuple(tuple(row) for _, row in reference.iterrows())
964+
if reference is not None
965+
else None,
930966
)
931-
return (tuple(row), tuple(reference_metadata))
932967

933968
def get_selected_rows(self):
934969
chunks = [
@@ -943,7 +978,7 @@ def get_selected_rows(self):
943978

944979
def run_workflow(self):
945980
sample_df = self.get_selected_rows()
946-
reference_df = self.reference_table.data.iloc[0]
981+
used_references = set(sample_df['Reference'])
947982

948983
workflow = amor.AmorWorkflow()
949984
workflow[SampleSize[SampleRun]] = sc.scalar(10, unit='mm')
@@ -952,46 +987,67 @@ def run_workflow(self):
952987
workflow[ChopperPhase[ReferenceRun]] = sc.scalar(7.5, unit='deg')
953988
workflow[ChopperPhase[SampleRun]] = sc.scalar(7.5, unit='deg')
954989

955-
workflow[WavelengthBins] = sc.geomspace(
956-
'wavelength',
957-
reference_df['Lmin'],
958-
reference_df['Lmax'],
959-
2001,
960-
unit='angstrom',
961-
)
962-
963-
workflow[YIndexLimits] = (
964-
sc.scalar(reference_df['Ymin']),
965-
sc.scalar(reference_df['Ymax']),
966-
)
967-
workflow[ZIndexLimits] = (
968-
sc.scalar(reference_df['Zmin']),
969-
sc.scalar(reference_df['Zmax']),
970-
)
971-
972-
progress = widgets.IntProgress(min=0, max=len(sample_df))
990+
progress = widgets.IntProgress(min=0, max=len(sample_df) + len(used_references))
973991
self.log_progress(progress)
974992

975-
if (key := self.get_row_key(reference_df)) in self.results:
976-
reference_result = self.results[key]
977-
else:
978-
reference_result = with_filenames(
979-
workflow,
980-
ReferenceRun,
981-
list(map(self.get_filepath_from_run, reference_df["Runs"])),
982-
).compute(ReducedReference)
983-
self.set_result(reference_df, reference_result)
984-
985-
workflow[ReducedReference] = reference_result
986-
progress.value += 1
987-
988993
for _, params in sample_df.iterrows():
989994
if (key := self.get_row_key(params)) in self.results:
990995
progress.value += 1
991996
continue
992997

998+
reference_rows_matching_selected_reference = self.get_reference_for_row(
999+
params
1000+
)
1001+
if len(reference_rows_matching_selected_reference) < 1:
1002+
self.log(
1003+
f'Reference "{params["Reference"]}" '
1004+
'does not exist in the reference list.'
1005+
)
1006+
continue
1007+
if len(reference_rows_matching_selected_reference) > 1:
1008+
self.log(
1009+
f'Reference "{params["Reference"]}" does not refer to a unique '
1010+
'refererence! Make sure that the reference '
1011+
'"Sample" names are unique.'
1012+
)
1013+
1014+
reference_row = reference_rows_matching_selected_reference.iloc[0]
1015+
1016+
wf = workflow.copy()
1017+
wf[WavelengthBins] = sc.geomspace(
1018+
'wavelength',
1019+
reference_row['Lmin'],
1020+
reference_row['Lmax'],
1021+
2001,
1022+
unit='angstrom',
1023+
)
1024+
wf[YIndexLimits] = (
1025+
sc.scalar(reference_row['Ymin']),
1026+
sc.scalar(reference_row['Ymax']),
1027+
)
1028+
wf[ZIndexLimits] = (
1029+
sc.scalar(reference_row['Zmin']),
1030+
sc.scalar(reference_row['Zmax']),
1031+
)
1032+
1033+
if (key := self.get_row_key(reference_row)) in self.results:
1034+
reference_result = self.results[key]
1035+
else:
1036+
reference_result = with_filenames(
1037+
wf,
1038+
ReferenceRun,
1039+
list(map(self.get_filepath_from_run, reference_row["Runs"])),
1040+
).compute(ReducedReference)
1041+
self.set_result(reference_row, reference_result)
1042+
1043+
if params['Reference'] in used_references:
1044+
progress.value += 1
1045+
used_references.remove(params['Reference'])
1046+
1047+
wf[ReducedReference] = reference_result
1048+
9931049
wf = with_filenames(
994-
workflow,
1050+
wf,
9951051
SampleRun,
9961052
list(map(self.get_filepath_from_run, params['Runs'])),
9971053
)

0 commit comments

Comments
 (0)