Skip to content

Commit 7ac80ee

Browse files
committed
ENH: Better init SMBO 'gp' estimator
Improves benchmark position
1 parent 056f1f4 commit 7ac80ee

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

sambo/_estimators.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from numbers import Integral, Real
23
from typing import Literal, Optional
34

@@ -9,11 +10,10 @@
910
raise ImportError(
1011
'Missing package sklearn. Please install scikit-learn '
1112
'(`pip install scikit-learn`).') from None
12-
from sklearn.ensemble import (
13-
ExtraTreesRegressor,
14-
GradientBoostingRegressor,
15-
)
16-
from sklearn.gaussian_process import GaussianProcessRegressor
13+
from sklearn.ensemble import ExtraTreesRegressor, GradientBoostingRegressor
14+
from sklearn.exceptions import ConvergenceWarning
15+
from sklearn.gaussian_process import GaussianProcessRegressor as _GaussianProcessRegressor
16+
from sklearn.gaussian_process.kernels import ConstantKernel, RBF
1717
from sklearn.model_selection._search import BaseSearchCV
1818

1919
from sambo._util import _SklearnLikeRegressor, lru_cache
@@ -47,6 +47,17 @@ def predict(self, X, return_std=False):
4747
return y_pred, std
4848

4949

50+
class GaussianProcessRegressor(_GaussianProcessRegressor):
51+
def fit(self, X, y):
52+
with warnings.catch_warnings(action='ignore', category=ConvergenceWarning):
53+
return super().fit(X, y)
54+
55+
def predict(self, X, **kwargs):
56+
with warnings.catch_warnings():
57+
warnings.filterwarnings('ignore', 'Predicted variances smaller than 0', UserWarning)
58+
return super().predict(X, **kwargs)
59+
60+
5061
class ExtraTreesRegressorWithStd(_RegressorWithStdMixin, ExtraTreesRegressor):
5162
"""
5263
Like `ExtraTreesRegressor` from scikit-learn, but with
@@ -75,6 +86,9 @@ def _estimator_factory(estimator, bounds, rng):
7586

7687
if estimator == 'gp':
7788
return GaussianProcessRegressor(
89+
kernel=(ConstantKernel(constant_value=1, constant_value_bounds=(1e-1, 1e1)) *
90+
RBF(length_scale=np.repeat(1, len(bounds)), length_scale_bounds=(1e-2, 1e2))),
91+
alpha=1e-14,
7892
copy_X_train=False,
7993
normalize_y=True,
8094
random_state=rng,

0 commit comments

Comments
 (0)