@@ -87,6 +87,7 @@ def __init__(
8787 output_prefix : Optional [str ] = None ,
8888 output_variable : str = "next_compositions" ,
8989 decision_rtol : float = 0.05 ,
90+ random_fraction : float = 0.0 ,
9091 excluded_comps_variables : Optional [List [str ]] = None ,
9192 excluded_comps_dim : Optional [str ] = None ,
9293 exclusion_radius : float = 1e-3 ,
@@ -108,6 +109,11 @@ def __init__(
108109 self .grid_variable = grid_variable
109110 self .grid_dim = grid_dim
110111 self .decision_rtol = decision_rtol
112+ if random_fraction < 0.0 or random_fraction > 1.0 :
113+ raise ValueError (
114+ f"random_fraction must be within [0, 1], got { random_fraction } ."
115+ )
116+ self .random_fraction = random_fraction
111117 self .exclusion_radius = exclusion_radius
112118
113119 def calculate (self , dataset : xr .Dataset ) -> Self :
@@ -225,8 +231,9 @@ def get_next_samples(self, dataset: xr.Dataset) -> None:
225231 """Choose the next compositions by evaluating the decision surface.
226232
227233 This method finds all compositions that are within decision_rtol of the maximum values
228- of the decision surface. From this set of compositions, it randomly chooses count
229- compositions as the next sample compositions.
234+ of the decision surface, then performs epsilon-greedy sampling without replacement:
235+ with probability random_fraction, a pick comes from all finite-valued points; otherwise
236+ from the top decision_rtol set, or from all finite-valued points once that set is empty.
230237
231238 Parameters
232239 ----------
@@ -252,28 +259,60 @@ def get_next_samples(self, dataset: xr.Dataset) -> None:
252259 {self .grid_dim : np .arange (dataset .sizes [self .grid_dim ])}
253260 )
254261
262+ decision_values = dataset ["decision_surface" ].values
263+ valid_mask = np .isfinite (decision_values )
264+ if not np .any (valid_mask ):
265+ raise AcquisitionError (
266+ "Decision surface does not contain any finite values."
267+ )
268+
269+ all_indices = dataset [self .grid_dim ].values
270+ valid_indices = all_indices [valid_mask ]
271+ if len (valid_indices ) < self .count :
272+ raise AcquisitionError (
273+ (
274+ "Unable to find enough valid gridpoints in decision surface to "
275+ f"sample { self .count } points."
276+ )
277+ )
278+
255279 # find indices of all samples within self.decision_rtol of the maximum
256280 close_mask = np .isclose (
257- dataset . decision_surface ,
258- dataset . decision_surface . max (),
281+ decision_values ,
282+ np . max (decision_values [ valid_mask ] ),
259283 rtol = self .decision_rtol ,
260284 atol = 0 ,
261285 )
262- indices = dataset .sel ({self .grid_dim : close_mask })[self .grid_dim ].values
286+ close_mask &= valid_mask
287+ close_indices = all_indices [close_mask ]
263288
264- if len (indices ) < self .count :
265- raise AcquisitionError (
266- (
267- """Unable to find gridpoint in decision surface that satisfies all constraints. """
268- f"""This often occurs when acquisition_rtol (currently { self .decision_rtol } ) """
269- f"""is too low or when the exclusion_radius (currently { self .exclusion_radius } ) """
270- """is too high for the current problem state."""
289+ all_pool = set (valid_indices .tolist ())
290+ top_pool = set (close_indices .tolist ())
291+
292+ next_indices = []
293+ for _ in range (self .count ):
294+ if not all_pool :
295+ raise AcquisitionError (
296+ "Unable to find enough valid gridpoints to satisfy requested sample count."
297+ )
298+
299+ choose_random = np .random .random () < self .random_fraction
300+ pool = all_pool if choose_random else top_pool
301+ if not pool :
302+ pool = all_pool if all_pool else top_pool
303+ if not pool :
304+ raise AcquisitionError (
305+ (
306+ "Unable to find gridpoint in decision surface that satisfies all constraints. "
307+ f"This can occur when acquisition_rtol (currently { self .decision_rtol } ) is "
308+ f"too low or exclusion_radius (currently { self .exclusion_radius } ) is too high."
309+ )
271310 )
272- )
273311
274- # randomly shuffle and gather the requested number of indices and compositions
275- np .random .shuffle (indices )
276- next_indices = indices [: self .count ]
312+ selected_index = int (np .random .choice (list (pool )))
313+ next_indices .append (selected_index )
314+ all_pool .discard (selected_index )
315+ top_pool .discard (selected_index )
277316
278317 next_samples = dataset .sel ({self .grid_dim : next_indices }).comp_grid
279318 next_samples = next_samples .rename ({self .grid_dim : "AF_sample" })
@@ -335,6 +374,7 @@ def __init__(
335374 output_prefix : Optional [str ] = None ,
336375 output_variable : str = "next_samples" ,
337376 decision_rtol : float = 0.05 ,
377+ random_fraction : float = 0.0 ,
338378 excluded_comps_variables : Optional [str ] = None ,
339379 excluded_comps_dim : Optional [str ] = None ,
340380 exclusion_radius : float = 1e-3 ,
@@ -348,6 +388,7 @@ def __init__(
348388 output_prefix = output_prefix ,
349389 output_variable = output_variable ,
350390 decision_rtol = decision_rtol ,
391+ random_fraction = random_fraction ,
351392 excluded_comps_variables = excluded_comps_variables ,
352393 excluded_comps_dim = excluded_comps_dim ,
353394 exclusion_radius = exclusion_radius ,
@@ -465,6 +506,7 @@ def __init__(
465506 output_prefix : Optional [str ] = None ,
466507 output_variable : str = "next_samples" ,
467508 decision_rtol : float = 0.05 ,
509+ random_fraction : float = 0.0 ,
468510 excluded_comps_variables : Optional [List [str ]] = None ,
469511 excluded_comps_dim : Optional [str ] = None ,
470512 exclusion_radius : float = 1e-3 ,
@@ -478,6 +520,7 @@ def __init__(
478520 output_prefix = output_prefix ,
479521 output_variable = output_variable ,
480522 decision_rtol = decision_rtol ,
523+ random_fraction = random_fraction ,
481524 excluded_comps_variables = excluded_comps_variables ,
482525 excluded_comps_dim = excluded_comps_dim ,
483526 exclusion_radius = exclusion_radius ,
@@ -610,6 +653,7 @@ def __init__(
610653 output_prefix : Optional [str ] = None ,
611654 output_variable : str = "next_samples" ,
612655 decision_rtol : float = 0.05 ,
656+ random_fraction : float = 0.0 ,
613657 excluded_comps_variables : Optional [List [str ]] = None ,
614658 excluded_comps_dim : Optional [str ] = None ,
615659 exclusion_radius : float = 1e-3 ,
@@ -623,6 +667,7 @@ def __init__(
623667 output_prefix = output_prefix ,
624668 output_variable = output_variable ,
625669 decision_rtol = decision_rtol ,
670+ random_fraction = random_fraction ,
626671 excluded_comps_variables = excluded_comps_variables ,
627672 excluded_comps_dim = excluded_comps_dim ,
628673 exclusion_radius = exclusion_radius ,
@@ -753,6 +798,7 @@ def __init__(
753798 output_prefix : Optional [str ] = None ,
754799 output_variable : str = "next_samples" ,
755800 decision_rtol : float = 0.05 ,
801+ random_fraction : float = 0.0 ,
756802 excluded_comps_variables : Optional [List [str ]] = None ,
757803 excluded_comps_dim : Optional [str ] = None ,
758804 exclusion_radius : float = 1e-3 ,
@@ -767,6 +813,7 @@ def __init__(
767813 output_prefix = output_prefix ,
768814 output_variable = output_variable ,
769815 decision_rtol = decision_rtol ,
816+ random_fraction = random_fraction ,
770817 excluded_comps_variables = excluded_comps_variables ,
771818 excluded_comps_dim = excluded_comps_dim ,
772819 exclusion_radius = exclusion_radius ,
0 commit comments