Skip to content

Commit a3de58a

Browse files
authored
Merge pull request #93 from usnistgov/2603_APS_12IDB
Updates from 12IDB Beamtime
2 parents 24327d4 + 5b3ca41 commit a3de58a

File tree

2 files changed

+64
-19
lines changed

2 files changed

+64
-19
lines changed

AFL/double_agent/AcquisitionFunction.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
output_prefix: Optional[str] = None,
8888
output_variable: str = "next_compositions",
8989
decision_rtol: float = 0.05,
90+
random_fraction: float = 0.0,
9091
excluded_comps_variables: Optional[List[str]] = None,
9192
excluded_comps_dim: Optional[str] = None,
9293
exclusion_radius: float = 1e-3,
@@ -108,6 +109,11 @@ def __init__(
108109
self.grid_variable = grid_variable
109110
self.grid_dim = grid_dim
110111
self.decision_rtol = decision_rtol
112+
if random_fraction < 0.0 or random_fraction > 1.0:
113+
raise ValueError(
114+
f"random_fraction must be within [0, 1], got {random_fraction}."
115+
)
116+
self.random_fraction = random_fraction
111117
self.exclusion_radius = exclusion_radius
112118

113119
def calculate(self, dataset: xr.Dataset) -> Self:
@@ -225,8 +231,9 @@ def get_next_samples(self, dataset: xr.Dataset) -> None:
225231
"""Choose the next compositions by evaluating the decision surface.
226232
227233
This method finds all compositions that are within decision_rtol of the maximum values
228-
of the decision surface. From this set of compositions, it randomly chooses count
229-
compositions as the next sample compositions.
234+
of the decision surface, then performs epsilon-greedy sampling for each requested pick:
235+
with probability random_fraction, pick from all valid points; otherwise pick from the
236+
top decision_rtol set.
230237
231238
Parameters
232239
----------
@@ -252,28 +259,60 @@ def get_next_samples(self, dataset: xr.Dataset) -> None:
252259
{self.grid_dim: np.arange(dataset.sizes[self.grid_dim])}
253260
)
254261

262+
decision_values = dataset["decision_surface"].values
263+
valid_mask = np.isfinite(decision_values)
264+
if not np.any(valid_mask):
265+
raise AcquisitionError(
266+
"Decision surface does not contain any finite values."
267+
)
268+
269+
all_indices = dataset[self.grid_dim].values
270+
valid_indices = all_indices[valid_mask]
271+
if len(valid_indices) < self.count:
272+
raise AcquisitionError(
273+
(
274+
"Unable to find enough valid gridpoints in decision surface to "
275+
f"sample {self.count} points."
276+
)
277+
)
278+
255279
# find indices of all samples within self.decision_rtol of the maximum
256280
close_mask = np.isclose(
257-
dataset.decision_surface,
258-
dataset.decision_surface.max(),
281+
decision_values,
282+
np.max(decision_values[valid_mask]),
259283
rtol=self.decision_rtol,
260284
atol=0,
261285
)
262-
indices = dataset.sel({self.grid_dim: close_mask})[self.grid_dim].values
286+
close_mask &= valid_mask
287+
close_indices = all_indices[close_mask]
263288

264-
if len(indices) < self.count:
265-
raise AcquisitionError(
266-
(
267-
"""Unable to find gridpoint in decision surface that satisfies all constraints. """
268-
f"""This often occurs when acquisition_rtol (currently {self.decision_rtol}) """
269-
f"""is too low or when the exclusion_radius (currently {self.exclusion_radius}) """
270-
"""is too high for the current problem state."""
289+
all_pool = set(valid_indices.tolist())
290+
top_pool = set(close_indices.tolist())
291+
292+
next_indices = []
293+
for _ in range(self.count):
294+
if not all_pool:
295+
raise AcquisitionError(
296+
"Unable to find enough valid gridpoints to satisfy requested sample count."
297+
)
298+
299+
choose_random = np.random.random() < self.random_fraction
300+
pool = all_pool if choose_random else top_pool
301+
if not pool:
302+
pool = all_pool if all_pool else top_pool
303+
if not pool:
304+
raise AcquisitionError(
305+
(
306+
"Unable to find gridpoint in decision surface that satisfies all constraints. "
307+
f"This can occur when acquisition_rtol (currently {self.decision_rtol}) is "
308+
f"too low or exclusion_radius (currently {self.exclusion_radius}) is too high."
309+
)
271310
)
272-
)
273311

274-
# randomly shuffle and gather the requested number of indices and compositions
275-
np.random.shuffle(indices)
276-
next_indices = indices[: self.count]
312+
selected_index = int(np.random.choice(list(pool)))
313+
next_indices.append(selected_index)
314+
all_pool.discard(selected_index)
315+
top_pool.discard(selected_index)
277316

278317
next_samples = dataset.sel({self.grid_dim: next_indices}).comp_grid
279318
next_samples = next_samples.rename({self.grid_dim: "AF_sample"})
@@ -335,6 +374,7 @@ def __init__(
335374
output_prefix: Optional[str] = None,
336375
output_variable: str = "next_samples",
337376
decision_rtol: float = 0.05,
377+
random_fraction: float = 0.0,
338378
excluded_comps_variables: Optional[str] = None,
339379
excluded_comps_dim: Optional[str] = None,
340380
exclusion_radius: float = 1e-3,
@@ -348,6 +388,7 @@ def __init__(
348388
output_prefix=output_prefix,
349389
output_variable=output_variable,
350390
decision_rtol=decision_rtol,
391+
random_fraction=random_fraction,
351392
excluded_comps_variables=excluded_comps_variables,
352393
excluded_comps_dim=excluded_comps_dim,
353394
exclusion_radius=exclusion_radius,
@@ -465,6 +506,7 @@ def __init__(
465506
output_prefix: Optional[str] = None,
466507
output_variable: str = "next_samples",
467508
decision_rtol: float = 0.05,
509+
random_fraction: float = 0.0,
468510
excluded_comps_variables: Optional[List[str]] = None,
469511
excluded_comps_dim: Optional[str] = None,
470512
exclusion_radius: float = 1e-3,
@@ -478,6 +520,7 @@ def __init__(
478520
output_prefix=output_prefix,
479521
output_variable=output_variable,
480522
decision_rtol=decision_rtol,
523+
random_fraction=random_fraction,
481524
excluded_comps_variables=excluded_comps_variables,
482525
excluded_comps_dim=excluded_comps_dim,
483526
exclusion_radius=exclusion_radius,
@@ -610,6 +653,7 @@ def __init__(
610653
output_prefix: Optional[str] = None,
611654
output_variable: str = "next_samples",
612655
decision_rtol: float = 0.05,
656+
random_fraction: float = 0.0,
613657
excluded_comps_variables: Optional[List[str]] = None,
614658
excluded_comps_dim: Optional[str] = None,
615659
exclusion_radius: float = 1e-3,
@@ -623,6 +667,7 @@ def __init__(
623667
output_prefix=output_prefix,
624668
output_variable=output_variable,
625669
decision_rtol=decision_rtol,
670+
random_fraction=random_fraction,
626671
excluded_comps_variables=excluded_comps_variables,
627672
excluded_comps_dim=excluded_comps_dim,
628673
exclusion_radius=exclusion_radius,
@@ -753,6 +798,7 @@ def __init__(
753798
output_prefix: Optional[str] = None,
754799
output_variable: str = "next_samples",
755800
decision_rtol: float = 0.05,
801+
random_fraction: float = 0.0,
756802
excluded_comps_variables: Optional[List[str]] = None,
757803
excluded_comps_dim: Optional[str] = None,
758804
exclusion_radius: float = 1e-3,
@@ -767,6 +813,7 @@ def __init__(
767813
output_prefix=output_prefix,
768814
output_variable=output_variable,
769815
decision_rtol=decision_rtol,
816+
random_fraction=random_fraction,
770817
excluded_comps_variables=excluded_comps_variables,
771818
excluded_comps_dim=excluded_comps_dim,
772819
exclusion_radius=exclusion_radius,

AFL/double_agent/plotting.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -576,9 +576,7 @@ def plot_scatter_plotly(
576576
fig = go.Figure()
577577

578578
# Define marker symbols for discrete labels
579-
plotly_symbols = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up',
580-
'triangle-down', 'triangle-left', 'triangle-right', 'pentagon',
581-
'hexagon', 'star', 'hexagram', 'star-triangle-up', 'star-triangle-down']
579+
plotly_symbols = ['circle', 'circle-open', 'cross', 'x', 'diamond', 'diamond-open', 'square', 'square-open']
582580

583581
if len(components) == 3 and ternary:
584582
# Extract coordinates

0 commit comments

Comments
 (0)