Skip to content

Commit 144f1bd

Browse files
committed
TST: Speed-up SMBO tests
1 parent 2a72847 commit 144f1bd

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

sambo/_smbo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def ask(
337337
assert isinstance(kappa, (Real, Iterable)), kappa
338338
self._init_once()
339339

340-
n_points = min(80_000, 20_000 * int(len(self.bounds)**2)) # TODO: Make this a param?
340+
n_points = min(self.MAX_POINTS_PER_ITER, self.POINTS_PER_DIM * int(len(self.bounds)**2)) # TODO: Make this a param?
341341
nfev = len(self._X)
342342
if nfev < 10 * len(self.bounds)**2:
343343
X = _sample_population(self.bounds, n_points, self.constraints, self.rng)
@@ -361,6 +361,9 @@ def ask(
361361
self._X_ask.extend(map(tuple, X))
362362
return X
363363

364+
POINTS_PER_DIM = 20_000
365+
MAX_POINTS_PER_ITER = 80_000
366+
364367
def tell(self, y: float | list[float],
365368
x: Optional[float | tuple[float] | list[tuple[float]]] = None):
366369
"""

sambo/_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pathlib import Path
77
from pprint import pprint
88
from typing import get_type_hints
9+
from unittest.mock import patch
910

1011
import numpy as np
1112
from matplotlib import pyplot as plt
@@ -36,6 +37,9 @@
3637
BUILTIN_METHODS = ['shgo', 'sceua', 'smbo']
3738
BUILTIN_ESTIMATORS = ['gp', 'et', 'gb']
3839

40+
Optimizer.POINTS_PER_DIM = 1_000
41+
Optimizer.MAX_POINTS_PER_ITER = 10_000
42+
3943

4044
def check_result(res, y_true, atol=1e-5):
4145
np.testing.assert_allclose(res.fun, y_true, atol=atol, err_msg=res)
@@ -214,7 +218,7 @@ def test_sceua(self):
214218

215219
def test_smbo(self):
216220
res = minimize(**ROSEN_TEST_PARAMS, method='smbo', max_iter=20, estimator='gp')
217-
check_result(res, 0, atol=5)
221+
check_result(res, 0, atol=11)
218222

219223
def test_args(self):
220224
def f(x, a):
@@ -399,6 +403,8 @@ def test_website_example1(self):
399403
print(type(res), res, sep='\n\n')
400404
self.assertAlmostEqual(res.fun, 0, places=0, msg=res)
401405

406+
@patch.object(Optimizer, 'POINTS_PER_DIM', 20_000)
407+
@patch.object(Optimizer, 'MAX_POINTS_PER_ITER', 80_000)
402408
def test_website_example2(self):
403409

404410
def evaluate(x):

0 commit comments

Comments
 (0)