Skip to content

Commit 3b3c785

Browse files
committed
TST: Speed-up SMBO tests
1 parent e24d6e1 commit 3b3c785

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

sambo/_smbo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ def ask(
339339
assert isinstance(kappa, (Real, Iterable)), kappa
340340
self._init_once()
341341

342-
n_points = min(80_000, 20_000 * int(len(self.bounds)**2)) # TODO: Make this a param?
342+
n_points = min(self.MAX_POINTS_PER_ITER,
343+
self.POINTS_PER_DIM * int(len(self.bounds)**2)) # TODO: Make this a param?
343344
nfev = len(self._X)
344345
if nfev < 10 * len(self.bounds)**2:
345346
X = _sample_population(self.bounds, n_points, self.constraints, self.rng)
@@ -363,6 +364,9 @@ def ask(
363364
self._X_ask.extend(map(tuple, X))
364365
return X
365366

367+
POINTS_PER_DIM = 20_000
368+
MAX_POINTS_PER_ITER = 80_000
369+
366370
def tell(self, y: float | list[float],
367371
x: Optional[float | tuple[float] | list[tuple[float]]] = None):
368372
"""

sambo/_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pathlib import Path
77
from pprint import pprint
88
from typing import get_type_hints
9+
from unittest.mock import patch
910

1011
import numpy as np
1112
from matplotlib import pyplot as plt
@@ -36,6 +37,9 @@
3637
BUILTIN_METHODS = ['shgo', 'sceua', 'smbo']
3738
BUILTIN_ESTIMATORS = ['gp', 'et', 'gb']
3839

40+
Optimizer.POINTS_PER_DIM = 1_000
41+
Optimizer.MAX_POINTS_PER_ITER = 10_000
42+
3943

4044
def check_result(res, y_true, atol=1e-5):
4145
np.testing.assert_allclose(res.fun, y_true, atol=atol, err_msg=res)
@@ -214,7 +218,7 @@ def test_sceua(self):
214218

215219
def test_smbo(self):
216220
res = minimize(**ROSEN_TEST_PARAMS, method='smbo', max_iter=20, estimator='gp')
217-
check_result(res, 0, atol=5)
221+
check_result(res, 0, atol=11)
218222

219223
def test_args(self):
220224
def f(x, a):
@@ -399,6 +403,8 @@ def test_website_example1(self):
399403
print(type(res), res, sep='\n\n')
400404
self.assertAlmostEqual(res.fun, 0, places=0, msg=res)
401405

406+
@patch.object(Optimizer, 'POINTS_PER_DIM', 20_000)
407+
@patch.object(Optimizer, 'MAX_POINTS_PER_ITER', 80_000)
402408
def test_website_example2(self):
403409

404410
def evaluate(x):

0 commit comments

Comments
 (0)