Skip to content

Commit 0453e35

Browse files
committed
TST: Speed up some tests
Fix tests / website plots
1 parent 61a8d88 commit 0453e35

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

sambo/_test.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def callback(res):
158158
nonlocal self
159159
assert self.CONSTRAINTS(res.x)
160160
assert self.CONSTRAINTS(res.xv[0])
161+
return True
161162

162163
for method in BUILTIN_METHODS:
163164
with self.subTest(method=method):
@@ -255,7 +256,7 @@ def constraints(x):
255256
was_called = True
256257
return np.all(x > 0)
257258

258-
_minimize = partial(minimize, x0=[0] * 2, constraints=constraints, max_iter=10)
259+
_minimize = partial(minimize, x0=[0] * 2, constraints=constraints, max_iter=5)
259260
for method in BUILTIN_METHODS:
260261
with self.subTest(method=method):
261262
was_called = False
@@ -376,16 +377,17 @@ def test_make_doc_plots(self):
376377
KWARGS = {
377378
plot_regret: dict(true_minimum=0),
378379
plot_convergence: dict(yscale='log', true_minimum=0),
380+
plot_objective: dict(resolution=12)
379381
}
380-
with chdir((p := Path(__file__).parent).parent / f'html/{p.stem}'):
382+
with chdir(Path(__file__).parent.parent / 'www'):
381383
for plot_func in PLOT_FUNCS:
382384
name = plot_func.__name__.removeprefix("plot_")
383385
with self.subTest(plot=name):
384386
try:
385387
fig = plot_func(*zip(BUILTIN_METHODS, results),
386388
**KWARGS.get(plot_func, {}))
387389
except TypeError:
388-
fig = plot_func(results[0]) # FIXME: plot (1, 3) subplots
390+
fig = plot_func(results[0], **KWARGS.get(plot_func, {})) # FIXME: plot (1, 3) subplots
389391
fig.savefig(f'{name}.svg')
390392
plt.show()
391393

@@ -406,17 +408,18 @@ def evaluate(x):
406408

407409
results = []
408410
for estimator in BUILTIN_ESTIMATORS:
409-
optimizer = Optimizer(fun=None, bounds=[(-2, 2)]*2, estimator=estimator, rng=0)
411+
optimizer = Optimizer(fun=None, bounds=[(-2, 2)]*4, estimator=estimator, rng=0)
410412

411-
for i in range(100):
412-
suggested_x = optimizer.ask(n_candidates=1)
413+
for i in range(30):
414+
suggested_x = optimizer.ask(n_candidates=2)
413415
y = [evaluate(x) for x in suggested_x]
414416
optimizer.tell(y)
415417

416418
result = optimizer.run()
417419
results.append(result)
418-
fig = plot_convergence(*[(f'estimator={e!r}', r) for e, r in zip(BUILTIN_ESTIMATORS, results)], yscale='log', true_minimum=0)
419-
with chdir(Path(__file__).parent.parent / 'html'):
420+
named_results = [(f'estimator={e!r}', r) for e, r in zip(BUILTIN_ESTIMATORS, results)]
421+
fig = plot_convergence(*named_results, true_minimum=0)
422+
with chdir(Path(__file__).parent.parent / 'www'):
420423
fig.savefig('convergence2.svg')
421424
plt.show()
422425

@@ -428,22 +431,22 @@ def test_website_example3(self):
428431
X, y = load_breast_cancer(return_X_y=True)
429432
clf = DecisionTreeClassifier(random_state=0)
430433
param_grid = {
431-
'max_depth': list(range(1, 30)),
432-
'min_samples_split': [2, 5, 10, 20, 50, 100],
433-
'min_samples_leaf': list(range(1, 20)),
434+
'max_depth': list(range(1, 20)),
435+
'min_samples_split': [2, 5, 10, 50, 100],
436+
'min_samples_leaf': list(range(1, 10)),
434437
'criterion': ['gini', 'entropy'],
435438
'max_features': [None, 'sqrt', 'log2'],
436439
}
437-
search = GridSearchCV(clf, param_grid, cv=2, n_jobs=-1)
438-
# Trying all ~20k combinations may take a long time ...
440+
search = GridSearchCV(clf, param_grid, cv=2, n_jobs=1)
441+
# Trying all ~6k combinations may take a long time ...
439442
search.fit(X, y)
440443
pprint(dict(sorted(search.best_params_.items())))
441444
print(search.best_score_)
442445

443446
# Alternatively ...
444447
from sambo import SamboSearchCV
445448

446-
search = SamboSearchCV(clf, param_grid, max_iter=100, cv=2, n_jobs=-1, rng=0)
449+
search = SamboSearchCV(clf, param_grid, max_iter=100, cv=2, n_jobs=1, rng=0)
447450
search.fit(X, y) # Fast, good enough
448451
pprint(dict(sorted(search.best_params_.items())))
449452
print(search.best_score_)

sambo/plot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def plot_convergence(
6464
6565
Example
6666
-------
67-
.. image:: convergence.svg
67+
.. image:: /convergence.svg
6868
"""
6969
assert results, results
7070

@@ -158,7 +158,7 @@ def plot_regret(
158158
159159
Example
160160
-------
161-
.. image:: regret.svg
161+
.. image:: /regret.svg
162162
"""
163163
assert results, results
164164

@@ -568,7 +568,7 @@ def plot_objective(
568568
569569
Example
570570
-------
571-
.. image:: objective.svg
571+
.. image:: /objective.svg
572572
"""
573573
result = _check_result(result)
574574
space = _check_space(result)
@@ -708,7 +708,7 @@ def plot_evaluations(
708708
709709
Example
710710
-------
711-
.. image:: evaluations.svg
711+
.. image:: /evaluations.svg
712712
"""
713713
result = _check_result(result)
714714
space = _check_space(result)

0 commit comments

Comments
 (0)