|
6 | 6 | from pathlib import Path |
7 | 7 | from pprint import pprint |
8 | 8 | from typing import get_type_hints |
| 9 | +from unittest.mock import patch |
9 | 10 |
|
10 | 11 | import numpy as np |
11 | 12 | from matplotlib import pyplot as plt |
|
36 | 37 | BUILTIN_METHODS = ['shgo', 'sceua', 'smbo'] |
37 | 38 | BUILTIN_ESTIMATORS = ['gp', 'et', 'gb'] |
38 | 39 |
|
| 40 | +Optimizer.POINTS_PER_DIM = 1_000 |
| 41 | +Optimizer.MAX_POINTS_PER_ITER = 10_000 |
| 42 | + |
39 | 43 |
|
40 | 44 | def check_result(res, y_true, atol=1e-5): |
41 | 45 | np.testing.assert_allclose(res.fun, y_true, atol=atol, err_msg=res) |
@@ -214,7 +218,7 @@ def test_sceua(self): |
214 | 218 |
|
215 | 219 | def test_smbo(self): |
216 | 220 | res = minimize(**ROSEN_TEST_PARAMS, method='smbo', max_iter=20, estimator='gp') |
217 | | - check_result(res, 0, atol=5) |
| 221 | + check_result(res, 0, atol=11) |
218 | 222 |
|
219 | 223 | def test_args(self): |
220 | 224 | def f(x, a): |
@@ -399,6 +403,8 @@ def test_website_example1(self): |
399 | 403 | print(type(res), res, sep='\n\n') |
400 | 404 | self.assertAlmostEqual(res.fun, 0, places=0, msg=res) |
401 | 405 |
|
| 406 | + @patch.object(Optimizer, 'POINTS_PER_DIM', 20_000) |
| 407 | + @patch.object(Optimizer, 'MAX_POINTS_PER_ITER', 80_000) |
402 | 408 | def test_website_example2(self): |
403 | 409 |
|
404 | 410 | def evaluate(x): |
|
0 commit comments