44
55# Authors: scikit-learn-contrib developers
66# License: BSD 3 clause
7+ # mypy: ignore-errors
78
89import numpy as np
910from sklearn .base import BaseEstimator , ClassifierMixin , TransformerMixin , _fit_context
1011from sklearn .metrics import euclidean_distances
1112from sklearn .utils .multiclass import check_classification_targets
12- from sklearn .utils .validation import check_is_fitted
13+ from sklearn .utils .validation import check_is_fitted , validate_data
1314
1415
1516class TemplateEstimator (BaseEstimator ):
@@ -73,12 +74,14 @@ def fit(self, X, y):
7374 self : object
7475 Returns self.
7576 """
76- # `_validate_data` is defined in the `BaseEstimator` class .
77+ # `_validate_data` is defined in the sklearn.utils.validation module .
7778 # It allows to:
7879 # - run different checks on the input data;
7980 # - define some attributes associated to the input data: `n_features_in_` and
8081 # `feature_names_in_`.
81- X , y = self ._validate_data (X , y , accept_sparse = True )
82+
83+ X , y = validate_data (self , X , y , accept_sparse = True )
84+
8285 self .is_fitted_ = True
8386 # `fit` should always return `self`
8487 return self
@@ -100,7 +103,7 @@ def predict(self, X):
100103 check_is_fitted (self )
101104 # We need to set reset=False because we don't want to overwrite `n_features_in_`
102105 # `feature_names_in_` but only check that the shape is consistent.
103- X = self . _validate_data ( X , accept_sparse = True , reset = False )
106+ X = validate_data ( self , X , accept_sparse = True , reset = False )
104107 return np .ones (X .shape [0 ], dtype = np .int64 )
105108
106109
@@ -182,7 +185,7 @@ def fit(self, X, y):
182185 # - run different checks on the input data;
183186 # - define some attributes associated to the input data: `n_features_in_` and
184187 # `feature_names_in_`.
185- X , y = self . _validate_data ( X , y )
188+ X , y = validate_data ( self , X , y )
186189 # We need to make sure that we have a classification task
187190 check_classification_targets (y )
188191
@@ -216,7 +219,7 @@ def predict(self, X):
216219 # Input validation
217220 # We need to set reset=False because we don't want to overwrite `n_features_in_`
218221 # `feature_names_in_` but only check that the shape is consistent.
219- X = self . _validate_data ( X , reset = False )
222+ X = validate_data ( self , X , reset = False )
220223
221224 closest = np .argmin (euclidean_distances (X , self .X_ ), axis = 1 )
222225 return self .y_ [closest ]
@@ -272,7 +275,7 @@ def fit(self, X, y=None):
272275 self : object
273276 Returns self.
274277 """
275- X = self . _validate_data ( X , accept_sparse = True )
278+ X = validate_data ( self , X , accept_sparse = True )
276279
277280 # Return the transformer
278281 return self
@@ -297,7 +300,7 @@ def transform(self, X):
297300 # Input validation
298301 # We need to set reset=False because we don't want to overwrite `n_features_in_`
299302 # `feature_names_in_` but only check that the shape is consistent.
300- X = self . _validate_data ( X , accept_sparse = True , reset = False )
303+ X = validate_data ( self , X , accept_sparse = True , reset = False )
301304 return np .sqrt (X )
302305
303306 def _more_tags (self ):
0 commit comments