@@ -226,7 +226,7 @@ def _fit(self):
226226 estimator = self .estimator
227227 if self .n_models > 1 and hasattr (estimator , 'random_state' ):
228228 estimator = clone (self .estimator )
229- estimator .random_state = np . random .randint (10000000 )
229+ estimator .random_state = self . rng .randint (10000000 )
230230 estimator .fit (self ._X , self ._y )
231231
232232 self .estimators .append (estimator )
@@ -238,12 +238,14 @@ def _fit(self):
238238 def _predict (self , X ):
239239 means , stds , masks = [], [], []
240240 for estimator in self .estimators :
241+ X_batched = [X [i :i + 10_000 ] for i in range (0 , len (X ), 10_000 )]
241242 try :
242- mean , std = estimator .predict (X , return_std = True )
243+ mean , std = np .concatenate (
244+ [estimator .predict (X , return_std = True ) for X in X_batched ], axis = 1 )
243245 except TypeError as exc :
244246 if 'return_std' not in exc .args [0 ]:
245247 raise
246- mean , std = estimator .predict (X ), 0
248+ mean , std = np . concatenate ([ estimator .predict (X ) for X in X_batched ] ), 0
247249 mask = np .ones_like (mean , dtype = bool )
248250 else :
249251 # Only suggest new/unknown points
@@ -336,7 +338,11 @@ def ask(
336338 X = _sample_population (self .bounds , n_points , self .constraints , self .rng )
337339 X , mean , std = self ._predict (X )
338340 criterion = acq_func (mean = mean , std = std , kappa = kappa )
339- best_indices = np .argsort (criterion )[:, :n_candidates ].flatten ('F' )
341+ n_candidates = min (n_candidates , criterion .shape [1 ])
342+ best_indices = np .take_along_axis (
343+ partitioned_inds := np .argpartition (criterion , n_candidates - 1 )[:, :n_candidates ],
344+ np .argsort (np .take_along_axis (criterion , partitioned_inds , axis = 1 )),
345+ axis = 1 ).flatten ('F' )
340346 X = X [best_indices ]
341347 X = X [:n_candidates ]
342348 self ._X_ask .extend (map (tuple , X ))
0 commit comments