Skip to content

Commit 857a168

Browse files
committed
Plot: Replace magic with DEFAULT_GRID_SIZE + fix a small arg bug
1 parent cb6e60f commit 857a168

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

sambo/_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,6 @@ def test_plot_objective(self):
117117
_plot_objective(self.RESULT_SMBO, plot_max_points=5)
118118
with self.assertWarns(UserWarning):
119119
_plot_objective(self.RESULT_SCEUA)
120-
with self.assertWarns(UserWarning):
121-
_plot_objective(self.RESULT_SCEUA, estimator='gp')
122-
with self.assertWarns(UserWarning):
123-
_plot_objective(self.RESULT_SHGO, estimator='gp')
124120

125121
def test_plot_evaluations(self):
126122
plot_evaluations(self.RESULT_SMBO)

sambo/plot.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ._util import OptimizeResult, _SklearnLikeRegressor
3333

3434
_MARKER_SEQUENCE = 'osxdvP^'
35+
_DEFAULT_SUBPLOT_SIZE = 2.1
3536

3637

3738
def plot_convergence(
@@ -245,7 +246,7 @@ def _format_scatter_plot_axes(fig, axs, space, plot_dims=None, dim_labels=None,
245246
if dim_labels is None:
246247
dim_labels = [fr"$\mathbf{{x_{{{i}}}}}$" for i in plot_dims]
247248

248-
nticks = int((1 + np.log10(size / (base_figsize := 2))) * (base_nticks := 6)) # noqa: F841
249+
nticks = int((1 + np.log10(size / (base_figsize := _DEFAULT_SUBPLOT_SIZE))) * (base_nticks := 6)) # noqa: F841
249250
fontsize = 10
250251

251252
_MaxNLocator = partial(MaxNLocator, nbins=nticks)
@@ -441,9 +442,9 @@ def _subplots_grid(n_dims, size, title):
441442
_watermark(fig)
442443
if add_figure_title:
443444
fig.suptitle(title)
444-
margins = dict(left=(m := 3 / n_dims * size / 2 * .07), bottom=m, right=1 - m,
445+
margins = dict(left=(m := 3 / n_dims * size / _DEFAULT_SUBPLOT_SIZE * .07), bottom=m, right=1 - m,
445446
top=1 - (2 if add_figure_title else 1.1) * m)
446-
fig.subplots_adjust(**margins, hspace=.15, wspace=.15)
447+
fig.subplots_adjust(**margins, hspace=.1, wspace=.1)
447448
return fig, axs
448449

449450

@@ -470,7 +471,7 @@ def plot_objective(
470471
resolution: int = 16,
471472
n_samples: int = 250,
472473
estimator: Optional[str | _SklearnLikeRegressor] = None,
473-
size: float = 2,
474+
size: float = _DEFAULT_SUBPLOT_SIZE,
474475
zscale: Literal['linear', 'log'] = 'linear',
475476
names: Optional[list[str]] = None,
476477
true_minimum: Optional[list[float] | list[list[float]]] = None,
@@ -609,8 +610,9 @@ def plot_objective(
609610
if estimator is None and result_estimator is not None:
610611
estimator = result_estimator
611612
else:
613+
_estimator_arg = estimator
612614
estimator = _estimator_factory(estimator, bounds, rng=0)
613-
if result_estimator is None:
615+
if result_estimator is None and _estimator_arg is None:
614616
warnings.warn(
615617
'The optimization result process does not appear to have been '
616618
'driven by a model. You can still still observe partial dependence '
@@ -654,7 +656,7 @@ def plot_evaluations(
654656
names: Optional[list[str]] = None,
655657
plot_dims: Optional[list[int]] = None,
656658
jitter: float = .02,
657-
size: int = 2,
659+
size: int = _DEFAULT_SUBPLOT_SIZE,
658660
cmap: str = 'summer',
659661
) -> Figure:
660662
"""Visualize the order in which points were evaluated during optimization.

0 commit comments

Comments
 (0)