Skip to content

Commit d8aa5d6

Browse files
committed
updated validate_data to versions > 1.6
1 parent fc9f82e commit d8aa5d6

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

.gitignore

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,10 @@ target/
7575
.LSOverride
7676

7777
# auto-generated files
78-
skltemplate/_version.py
78+
skltemplate/_version.py
79+
80+
# linters and formatters
81+
.vscode/
82+
.idea/
83+
.mypy_cache/
84+
.ruff_cache/

skltemplate/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Authors: scikit-learn-contrib developers
22
# License: BSD 3 clause
33

4+
from sklearn import __version__
5+
46
from ._template import TemplateClassifier, TemplateEstimator, TemplateTransformer
5-
from ._version import __version__
67

78
__all__ = [
89
"TemplateEstimator",

skltemplate/_template.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
# Authors: scikit-learn-contrib developers
66
# License: BSD 3 clause
7+
# mypy: ignore-errors
78

89
import numpy as np
910
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, _fit_context
1011
from sklearn.metrics import euclidean_distances
1112
from sklearn.utils.multiclass import check_classification_targets
12-
from sklearn.utils.validation import check_is_fitted
13+
from sklearn.utils.validation import check_is_fitted, validate_data
1314

1415

1516
class TemplateEstimator(BaseEstimator):
@@ -73,12 +74,14 @@ def fit(self, X, y):
7374
self : object
7475
Returns self.
7576
"""
76-
# `_validate_data` is defined in the `BaseEstimator` class.
77+
# `_validate_data` is defined in the sklearn.utils.validation module.
7778
# It allows to:
7879
# - run different checks on the input data;
7980
# - define some attributes associated to the input data: `n_features_in_` and
8081
# `feature_names_in_`.
81-
X, y = self._validate_data(X, y, accept_sparse=True)
82+
83+
X, y = validate_data(self, X, y, accept_sparse=True)
84+
8285
self.is_fitted_ = True
8386
# `fit` should always return `self`
8487
return self
@@ -100,7 +103,7 @@ def predict(self, X):
100103
check_is_fitted(self)
101104
# We need to set reset=False because we don't want to overwrite `n_features_in_`
102105
# `feature_names_in_` but only check that the shape is consistent.
103-
X = self._validate_data(X, accept_sparse=True, reset=False)
106+
X = validate_data(self, X, accept_sparse=True, reset=False)
104107
return np.ones(X.shape[0], dtype=np.int64)
105108

106109

@@ -182,7 +185,7 @@ def fit(self, X, y):
182185
# - run different checks on the input data;
183186
# - define some attributes associated to the input data: `n_features_in_` and
184187
# `feature_names_in_`.
185-
X, y = self._validate_data(X, y)
188+
X, y = validate_data(self, X, y)
186189
# We need to make sure that we have a classification task
187190
check_classification_targets(y)
188191

@@ -216,7 +219,7 @@ def predict(self, X):
216219
# Input validation
217220
# We need to set reset=False because we don't want to overwrite `n_features_in_`
218221
# `feature_names_in_` but only check that the shape is consistent.
219-
X = self._validate_data(X, reset=False)
222+
X = validate_data(self, X, reset=False)
220223

221224
closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
222225
return self.y_[closest]
@@ -272,7 +275,7 @@ def fit(self, X, y=None):
272275
self : object
273276
Returns self.
274277
"""
275-
X = self._validate_data(X, accept_sparse=True)
278+
X = validate_data(self, X, accept_sparse=True)
276279

277280
# Return the transformer
278281
return self
@@ -297,7 +300,7 @@ def transform(self, X):
297300
# Input validation
298301
# We need to set reset=False because we don't want to overwrite `n_features_in_`
299302
# `feature_names_in_` but only check that the shape is consistent.
300-
X = self._validate_data(X, accept_sparse=True, reset=False)
303+
X = validate_data(self, X, accept_sparse=True, reset=False)
301304
return np.sqrt(X)
302305

303306
def _more_tags(self):

0 commit comments

Comments
 (0)