Skip to content

Commit d38d375

Browse files
cajchristianPicoCentauri
authored andcommitted
Added validate_data calls to KPCovR
1 parent 29c5a9d commit d38d375

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/skmatter/decomposition/_kernel_pcovr.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from sklearn.kernel_ridge import KernelRidge
1010
from sklearn.linear_model._base import LinearModel
1111
from sklearn.metrics.pairwise import pairwise_kernels
12-
from sklearn.utils import check_array, check_random_state
12+
from sklearn.utils import check_random_state
1313
from sklearn.utils._arpack import _init_arpack_v0
1414
from sklearn.utils.extmath import randomized_svd, stable_cumsum, svd_flip
15-
from sklearn.utils.validation import check_is_fitted, check_X_y
15+
from sklearn.utils.validation import check_is_fitted, validate_data
1616

1717
from ..preprocessing import KernelNormalizer
1818
from ..utils import check_krr_fit, pcovr_kernel
@@ -270,7 +270,7 @@ def fit(self, X, Y, W=None):
270270
):
271271
raise ValueError("Regressor must be an instance of `KernelRidge`")
272272

273-
X, Y = check_X_y(X, Y, y_numeric=True, multi_output=True)
273+
X, Y = validate_data(self, X, Y, y_numeric=True, multi_output=True)
274274
self.X_fit_ = X.copy()
275275

276276
if self.n_components is None:
@@ -387,7 +387,7 @@ def predict(self, X=None):
387387
"""Predicts the property values"""
388388
check_is_fitted(self, ["pky_", "pty_"])
389389

390-
X = check_array(X)
390+
X = validate_data(self, X, reset=False)
391391
K = self._get_kernel(X, self.X_fit_)
392392
if self.center:
393393
K = self.centerer_.transform(K)
@@ -408,7 +408,7 @@ def transform(self, X):
408408
"""
409409
check_is_fitted(self, ["pkt_", "X_fit_"])
410410

411-
X = check_array(X)
411+
X = validate_data(self, X, reset=False)
412412
K = self._get_kernel(X, self.X_fit_)
413413

414414
if self.center:
@@ -440,7 +440,7 @@ def inverse_transform(self, T):
440440
"""
441441
return T @ self.ptx_
442442

443-
def score(self, X, Y):
443+
def score(self, X, y):
444444
r"""Computes the (negative) loss values for KernelPCovR on the given predictor
445445
and response variables. The loss in :math:`\mathbf{K}`, as explained in
446446
[Helfrecht2020]_ does not correspond to a traditional Gram loss
@@ -474,7 +474,7 @@ def score(self, X, Y):
474474
"""
475475
check_is_fitted(self, ["pkt_", "X_fit_"])
476476

477-
X = check_array(X)
477+
X, y = validate_data(self, X, y, reset=False)
478478

479479
K_NN = self._get_kernel(self.X_fit_, self.X_fit_)
480480
K_VN = self._get_kernel(X, self.X_fit_)
@@ -485,8 +485,8 @@ def score(self, X, Y):
485485
K_VN = self.centerer_.transform(K_VN)
486486
K_VV = self.centerer_.transform(K_VV)
487487

488-
y = K_VN @ self.pky_
489-
Lkrr = np.linalg.norm(Y - y) ** 2 / np.linalg.norm(Y) ** 2
488+
ypred = K_VN @ self.pky_
489+
Lkrr = np.linalg.norm(y - ypred) ** 2 / np.linalg.norm(y) ** 2
490490

491491
t_n = K_NN @ self.pkt_
492492
t_v = K_VN @ self.pkt_

0 commit comments

Comments
 (0)