|
| 1 | +import warnings |
1 | 2 | from numbers import Integral, Real |
2 | 3 | from typing import Literal, Optional |
3 | 4 |
|
|
9 | 10 | raise ImportError( |
10 | 11 | 'Missing package sklearn. Please install scikit-learn ' |
11 | 12 | '(`pip install scikit-learn`).') from None |
12 | | -from sklearn.ensemble import ( |
13 | | - ExtraTreesRegressor, |
14 | | - GradientBoostingRegressor, |
15 | | -) |
16 | | -from sklearn.gaussian_process import GaussianProcessRegressor |
| 13 | +from sklearn.ensemble import ExtraTreesRegressor, GradientBoostingRegressor |
| 14 | +from sklearn.exceptions import ConvergenceWarning |
| 15 | +from sklearn.gaussian_process import GaussianProcessRegressor as _GaussianProcessRegressor |
| 16 | +from sklearn.gaussian_process.kernels import ConstantKernel, RBF |
17 | 17 | from sklearn.model_selection._search import BaseSearchCV |
18 | 18 |
|
19 | 19 | from sambo._util import _SklearnLikeRegressor, lru_cache |
@@ -47,6 +47,17 @@ def predict(self, X, return_std=False): |
47 | 47 | return y_pred, std |
48 | 48 |
|
49 | 49 |
|
| 50 | +class GaussianProcessRegressor(_GaussianProcessRegressor): |
| 51 | + def fit(self, X, y): |
| 52 | + with warnings.catch_warnings(action='ignore', category=ConvergenceWarning): |
| 53 | + return super().fit(X, y) |
| 54 | + |
| 55 | + def predict(self, X, **kwargs): |
| 56 | + with warnings.catch_warnings(): |
| 57 | + warnings.filterwarnings('ignore', 'Predicted variances smaller than 0', UserWarning) |
| 58 | + return super().predict(X, **kwargs) |
| 59 | + |
| 60 | + |
50 | 61 | class ExtraTreesRegressorWithStd(_RegressorWithStdMixin, ExtraTreesRegressor): |
51 | 62 | """ |
52 | 63 | Like `ExtraTreesRegressor` from scikit-learn, but with |
@@ -75,6 +86,9 @@ def _estimator_factory(estimator, bounds, rng): |
75 | 86 |
|
76 | 87 | if estimator == 'gp': |
77 | 88 | return GaussianProcessRegressor( |
| 89 | + kernel=(ConstantKernel(constant_value=1, constant_value_bounds=(1e-1, 1e1)) * |
| 90 | + RBF(length_scale=np.repeat(1, len(bounds)), length_scale_bounds=(1e-2, 1e2))), |
| 91 | + alpha=1e-14, |
78 | 92 | copy_X_train=False, |
79 | 93 | normalize_y=True, |
80 | 94 | random_state=rng, |
|
0 commit comments