Skip to content

Commit 42d11fb

Browse files
cajchristianPicoCentauri
authored andcommitted
Adding validate_data calls and updated tags to Ridge2FoldCV
1 parent 941a77f commit 42d11fb

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/skmatter/linear_model/_ridge.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import numpy as np
22
from joblib import Parallel, delayed
3-
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
3+
from sklearn.base import RegressorMixin, MultiOutputMixin, BaseEstimator
44
from sklearn.metrics import check_scoring
55
from sklearn.model_selection import KFold, check_cv
6-
from sklearn.utils import check_array
7-
from sklearn.utils.validation import check_is_fitted
6+
from sklearn.utils.validation import check_is_fitted, validate_data
87

98

10-
class Ridge2FoldCV(BaseEstimator, MultiOutputMixin, RegressorMixin):
9+
class Ridge2FoldCV(RegressorMixin, MultiOutputMixin, BaseEstimator):
1110
r"""Ridge regression with an efficient 2-fold cross-validation method using the SVD
1211
solver.
1312
@@ -20,7 +19,7 @@ class Ridge2FoldCV(BaseEstimator, MultiOutputMixin, RegressorMixin):
2019
while the alpha value is determined with a 2-fold cross-validation from a list of
2120
alpha values. It is more efficient version than doing 2-fold cross-validation
2221
naively The algorithmic trick is to reuse the matrices obtained by SVD for each
23-
regularization paramater :param alpha: The 2-fold CV can be broken donw to
22+
regularization paramater :param alpha: The 2-fold CV can be broken down to
2423
2524
.. math::
2625
@@ -136,6 +135,11 @@ def __init__(
136135
self.shuffle = shuffle
137136
self.n_jobs = n_jobs
138137

138+
def __sklearn_tags__(self):
139+
tags = super().__sklearn_tags__()
140+
tags.target_tags.single_output = False
141+
return tags
142+
139143
def _more_tags(self):
140144
return {"multioutput_only": True}
141145

@@ -195,7 +199,7 @@ def predict(self, X):
195199
Training data, where n_samples is the number of samples
196200
and n_features is the number of features.
197201
"""
198-
X = check_array(X)
202+
X = validate_data(self, X, reset=False)
199203

200204
check_is_fitted(self, ["coef_"])
201205

0 commit comments

Comments
 (0)