Skip to content

Commit 056f1f4

Browse files
committed
REF: SMBO: Minor improve performance, fix rng bug
1 parent 46e92eb commit 056f1f4

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

sambo/_smbo.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _fit(self):
226226
estimator = self.estimator
227227
if self.n_models > 1 and hasattr(estimator, 'random_state'):
228228
estimator = clone(self.estimator)
229-
estimator.random_state = np.random.randint(10000000)
229+
estimator.random_state = self.rng.randint(10000000)
230230
estimator.fit(self._X, self._y)
231231

232232
self.estimators.append(estimator)
@@ -238,12 +238,14 @@ def _fit(self):
238238
def _predict(self, X):
239239
means, stds, masks = [], [], []
240240
for estimator in self.estimators:
241+
X_batched = [X[i:i+10_000] for i in range(0, len(X), 10_000)]
241242
try:
242-
mean, std = estimator.predict(X, return_std=True)
243+
mean, std = np.concatenate(
244+
[estimator.predict(X, return_std=True) for X in X_batched], axis=1)
243245
except TypeError as exc:
244246
if 'return_std' not in exc.args[0]:
245247
raise
246-
mean, std = estimator.predict(X), 0
248+
mean, std = np.concatenate([estimator.predict(X) for X in X_batched]), 0
247249
mask = np.ones_like(mean, dtype=bool)
248250
else:
249251
# Only suggest new/unknown points
@@ -336,7 +338,11 @@ def ask(
336338
X = _sample_population(self.bounds, n_points, self.constraints, self.rng)
337339
X, mean, std = self._predict(X)
338340
criterion = acq_func(mean=mean, std=std, kappa=kappa)
339-
best_indices = np.argsort(criterion)[:, :n_candidates].flatten('F')
341+
n_candidates = min(n_candidates, criterion.shape[1])
342+
best_indices = np.take_along_axis(
343+
partitioned_inds := np.argpartition(criterion, n_candidates - 1)[:, :n_candidates],
344+
np.argsort(np.take_along_axis(criterion, partitioned_inds, axis=1)),
345+
axis=1).flatten('F')
340346
X = X[best_indices]
341347
X = X[:n_candidates]
342348
self._X_ask.extend(map(tuple, X))

sambo/_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def evaluate(x):
409409
optimizer = Optimizer(fun=None, bounds=[(-2, 2)]*4, estimator=estimator, rng=0)
410410

411411
for i in range(30):
412-
suggested_x = optimizer.ask(n_candidates=2)
412+
suggested_x = optimizer.ask(n_candidates=1)
413413
y = [evaluate(x) for x in suggested_x]
414414
optimizer.tell(y)
415415

0 commit comments

Comments
 (0)