Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 63 additions & 16 deletions AFL/double_agent/AcquisitionFunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
output_prefix: Optional[str] = None,
output_variable: str = "next_compositions",
decision_rtol: float = 0.05,
random_fraction: float = 0.0,
excluded_comps_variables: Optional[List[str]] = None,
excluded_comps_dim: Optional[str] = None,
exclusion_radius: float = 1e-3,
Expand All @@ -108,6 +109,11 @@ def __init__(
self.grid_variable = grid_variable
self.grid_dim = grid_dim
self.decision_rtol = decision_rtol
if random_fraction < 0.0 or random_fraction > 1.0:
raise ValueError(
f"random_fraction must be within [0, 1], got {random_fraction}."
)
self.random_fraction = random_fraction
self.exclusion_radius = exclusion_radius

def calculate(self, dataset: xr.Dataset) -> Self:
Expand Down Expand Up @@ -225,8 +231,9 @@ def get_next_samples(self, dataset: xr.Dataset) -> None:
"""Choose the next compositions by evaluating the decision surface.

This method finds all compositions that are within decision_rtol of the maximum values
of the decision surface. From this set of compositions, it randomly chooses count
compositions as the next sample compositions.
of the decision surface, then performs epsilon-greedy sampling for each requested pick:
with probability random_fraction, pick from all valid points; otherwise pick from the
top decision_rtol set.

Parameters
----------
Expand All @@ -252,28 +259,60 @@ def get_next_samples(self, dataset: xr.Dataset) -> None:
{self.grid_dim: np.arange(dataset.sizes[self.grid_dim])}
)

decision_values = dataset["decision_surface"].values
valid_mask = np.isfinite(decision_values)
if not np.any(valid_mask):
raise AcquisitionError(
"Decision surface does not contain any finite values."
)

all_indices = dataset[self.grid_dim].values
valid_indices = all_indices[valid_mask]
if len(valid_indices) < self.count:
raise AcquisitionError(
(
"Unable to find enough valid gridpoints in decision surface to "
f"sample {self.count} points."
)
)

# find indices of all samples within self.decision_rtol of the maximum
close_mask = np.isclose(
dataset.decision_surface,
dataset.decision_surface.max(),
decision_values,
np.max(decision_values[valid_mask]),
rtol=self.decision_rtol,
atol=0,
)
indices = dataset.sel({self.grid_dim: close_mask})[self.grid_dim].values
close_mask &= valid_mask
close_indices = all_indices[close_mask]

if len(indices) < self.count:
raise AcquisitionError(
(
"""Unable to find gridpoint in decision surface that satisfies all constraints. """
f"""This often occurs when acquisition_rtol (currently {self.decision_rtol}) """
f"""is too low or when the exclusion_radius (currently {self.exclusion_radius}) """
"""is too high for the current problem state."""
all_pool = set(valid_indices.tolist())
top_pool = set(close_indices.tolist())

next_indices = []
for _ in range(self.count):
if not all_pool:
raise AcquisitionError(
"Unable to find enough valid gridpoints to satisfy requested sample count."
)

choose_random = np.random.random() < self.random_fraction
pool = all_pool if choose_random else top_pool
if not pool:
pool = all_pool if all_pool else top_pool
if not pool:
raise AcquisitionError(
(
"Unable to find gridpoint in decision surface that satisfies all constraints. "
f"This can occur when acquisition_rtol (currently {self.decision_rtol}) is "
f"too low or exclusion_radius (currently {self.exclusion_radius}) is too high."
)
)
)

# randomly shuffle and gather the requested number of indices and compositions
np.random.shuffle(indices)
next_indices = indices[: self.count]
selected_index = int(np.random.choice(list(pool)))
next_indices.append(selected_index)
all_pool.discard(selected_index)
top_pool.discard(selected_index)

next_samples = dataset.sel({self.grid_dim: next_indices}).comp_grid
next_samples = next_samples.rename({self.grid_dim: "AF_sample"})
Expand Down Expand Up @@ -335,6 +374,7 @@ def __init__(
output_prefix: Optional[str] = None,
output_variable: str = "next_samples",
decision_rtol: float = 0.05,
random_fraction: float = 0.0,
excluded_comps_variables: Optional[str] = None,
excluded_comps_dim: Optional[str] = None,
exclusion_radius: float = 1e-3,
Expand All @@ -348,6 +388,7 @@ def __init__(
output_prefix=output_prefix,
output_variable=output_variable,
decision_rtol=decision_rtol,
random_fraction=random_fraction,
excluded_comps_variables=excluded_comps_variables,
excluded_comps_dim=excluded_comps_dim,
exclusion_radius=exclusion_radius,
Expand Down Expand Up @@ -465,6 +506,7 @@ def __init__(
output_prefix: Optional[str] = None,
output_variable: str = "next_samples",
decision_rtol: float = 0.05,
random_fraction: float = 0.0,
excluded_comps_variables: Optional[List[str]] = None,
excluded_comps_dim: Optional[str] = None,
exclusion_radius: float = 1e-3,
Expand All @@ -478,6 +520,7 @@ def __init__(
output_prefix=output_prefix,
output_variable=output_variable,
decision_rtol=decision_rtol,
random_fraction=random_fraction,
excluded_comps_variables=excluded_comps_variables,
excluded_comps_dim=excluded_comps_dim,
exclusion_radius=exclusion_radius,
Expand Down Expand Up @@ -610,6 +653,7 @@ def __init__(
output_prefix: Optional[str] = None,
output_variable: str = "next_samples",
decision_rtol: float = 0.05,
random_fraction: float = 0.0,
excluded_comps_variables: Optional[List[str]] = None,
excluded_comps_dim: Optional[str] = None,
exclusion_radius: float = 1e-3,
Expand All @@ -623,6 +667,7 @@ def __init__(
output_prefix=output_prefix,
output_variable=output_variable,
decision_rtol=decision_rtol,
random_fraction=random_fraction,
excluded_comps_variables=excluded_comps_variables,
excluded_comps_dim=excluded_comps_dim,
exclusion_radius=exclusion_radius,
Expand Down Expand Up @@ -753,6 +798,7 @@ def __init__(
output_prefix: Optional[str] = None,
output_variable: str = "next_samples",
decision_rtol: float = 0.05,
random_fraction: float = 0.0,
excluded_comps_variables: Optional[List[str]] = None,
excluded_comps_dim: Optional[str] = None,
exclusion_radius: float = 1e-3,
Expand All @@ -767,6 +813,7 @@ def __init__(
output_prefix=output_prefix,
output_variable=output_variable,
decision_rtol=decision_rtol,
random_fraction=random_fraction,
excluded_comps_variables=excluded_comps_variables,
excluded_comps_dim=excluded_comps_dim,
exclusion_radius=exclusion_radius,
Expand Down
4 changes: 1 addition & 3 deletions AFL/double_agent/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,9 +576,7 @@ def plot_scatter_plotly(
fig = go.Figure()

# Define marker symbols for discrete labels
plotly_symbols = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up',
'triangle-down', 'triangle-left', 'triangle-right', 'pentagon',
'hexagon', 'star', 'hexagram', 'star-triangle-up', 'star-triangle-down']
plotly_symbols = ['circle', 'circle-open', 'cross', 'x', 'diamond', 'diamond-open', 'square', 'square-open']

if len(components) == 3 and ternary:
# Extract coordinates
Expand Down
Loading